Deep Learning for Images with PyTorch
Michal Oleszak
Machine Learning Engineer
Training models from scratch:
Pre-trained models: models that have already been trained on a task
Steps for leveraging pre-trained models:
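torchvision models
torch.save(): saves model weights to a file with the .pt or .pth extension
.state_dict(): returns the model's learned parameters as a dictionary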
import torch

# Save only the learned parameters (the state dict), not the model class
torch.save(model.state_dict(), "BinaryCNN.pth")
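A state dict maps each parameter name to a tensor. It stores weights only, not the architecture, which is why the model class is still needed when loading. The sketch below inspects the state dict of a small model; `TinyCNN` and its layer sizes are hypothetical and chosen only for illustration.

```python
import torch
import torch.nn as nn

# Hypothetical small model, used only to show what a state dict contains
class TinyCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.fc = nn.Linear(8, 1)

    def forward(self, x):
        x = torch.relu(self.conv(x))
        x = x.mean(dim=(2, 3))  # global average pooling: (batch, 8, H, W) -> (batch, 8)
        return self.fc(x)

model = TinyCNN()
state = model.state_dict()

# Keys follow the attribute names of the submodules, in registration order
for name, tensor in state.items():
    print(name, tuple(tensor.shape))
```

This prints four lines: `conv.weight (8, 3, 3, 3)`, `conv.bias (8,)`, `fc.weight (1, 8)` and `fc.bias (1,)`. Only tensors appear, with no trace of the `forward` logic.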
Instantiate a new model
new_model = BinaryCNN()
Load saved parameters
new_model.load_state_dict(torch.load('BinaryCNN.pth'))
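The save and load steps can be checked end to end with a round trip: save the parameters, load them into a fresh instance and confirm both models give identical outputs. This is a minimal sketch. The `BinaryCNN` below is a hypothetical stand-in, since its real layer layout is not shown here, and a temporary directory replaces the fixed file path.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in for BinaryCNN; the actual architecture may differ
class BinaryCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Linear(16, 1)

    def forward(self, x):
        x = self.features(x)
        x = x.mean(dim=(2, 3))  # global average pooling
        return torch.sigmoid(self.classifier(x))

torch.manual_seed(0)
model = BinaryCNN()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "BinaryCNN.pth")
    # Save only the parameters, not the Python class
    torch.save(model.state_dict(), path)

    # A fresh instance starts with its own random weights...
    new_model = BinaryCNN()
    # ...until the saved parameters are copied into it
    new_model.load_state_dict(torch.load(path))

# With identical weights and input, both models should agree exactly
model.eval()
new_model.eval()
x = torch.randn(2, 3, 32, 32)
with torch.no_grad():
    same = torch.equal(model(x), new_model(x))
print(same)
```

Because only the state dict is saved, the class definition must be importable wherever the file is loaded.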
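from torchvision.models import resnet18, ResNet18_Weights

# Best available ImageNet weights for ResNet-18
weights = ResNet18_Weights.DEFAULT
# Build the architecture and load the pre-trained weights (downloaded on first use)
model = resnet18(weights=weights)
# Preprocessing pipeline that matches these weights
transform = weights.transforms()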
resnet architecture and weights

from PIL import Image

image = Image.open("cat013.jpg")
image_tensor = transform(image)
# Add a batch dimension: (3, 224, 224) -> (1, 3, 224, 224)
image_reshaped = image_tensor.unsqueeze(0)
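model.eval()  # inference mode: batch norm uses its running statistics
with torch.no_grad():  # no gradient tracking needed for inference
    pred = model(image_reshaped).squeeze(0)  # logits, shape (1000,)
pred_cls = pred.softmax(0)  # convert logits to class probabilities
cls_id = pred_cls.argmax().item()  # index of the most probable class
cls_name = weights.meta["categories"][cls_id]
print(cls_name)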
Egyptian cat