Working with pre-trained models

Deep Learning for Images with PyTorch

Michal Oleszak

Machine Learning Engineer

Leveraging pre-trained models

  • Training models from scratch:

    • Long process
    • Requires lots of data
  • Pre-trained models: models already trained on a task

    • Directly reusable on a new task
    • May require adjustment to the new task (transfer learning)
  • Steps to leverage pre-trained models:

    • Saving & loading models locally
    • Downloading torchvision models

Saving a complete PyTorch model

  • torch.save() serializes an object (here, the model's parameters) to disk
  • File extension: .pt or .pth
  • Save the model weights with .state_dict()
    torch.save(model.state_dict(), "BinaryCNN.pth")
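To see what a state_dict holds, here is a minimal sketch. The `BinaryCNN` class below is a hypothetical stand-in whose layer sizes are assumptions, not the course's definition. The file goes to a temporary directory.

```python
import os
import tempfile

import torch
import torch.nn as nn


class BinaryCNN(nn.Module):
    """Hypothetical stand-in for BinaryCNN; layer sizes are assumptions."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(16, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = torch.flatten(x, 1)
        return self.sigmoid(self.fc1(x))


model = BinaryCNN()

# The state_dict maps parameter names to tensors; ReLU, pooling, and
# sigmoid have no parameters, so they do not appear.
for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))

# Save only the weights (not the class definition) to a .pth file.
path = os.path.join(tempfile.mkdtemp(), "BinaryCNN.pth")
torch.save(model.state_dict(), path)
print(os.path.exists(path))  # True
```

Because only tensors are saved, the class definition must be available again at load time.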
    

Loading PyTorch models

  • Instantiate a new model with the same architecture as the saved one

    new_model = BinaryCNN()
    
  • Load saved parameters

    new_model.load_state_dict(torch.load('BinaryCNN.pth'))
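The following sketch shows a full save-and-load round trip and checks that the restored model reproduces the original's outputs. `BinaryCNN` is again a hypothetical definition. `map_location="cpu"` is an optional addition that lets weights saved on a GPU load on a CPU-only machine.

```python
import os
import tempfile

import torch
import torch.nn as nn


class BinaryCNN(nn.Module):
    """Hypothetical BinaryCNN; must match the architecture that was saved."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(16, 1)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        return torch.sigmoid(self.fc1(torch.flatten(x, 1)))


# Save a model's weights...
model = BinaryCNN()
path = os.path.join(tempfile.mkdtemp(), "BinaryCNN.pth")
torch.save(model.state_dict(), path)

# ...then restore them into a fresh instance of the same class.
new_model = BinaryCNN()
state_dict = torch.load(path, map_location="cpu")
result = new_model.load_state_dict(state_dict)  # strict=True: names and shapes must match
print(result)  # reports any missing or unexpected keys

# Both models now produce identical outputs.
model.eval()
new_model.eval()
x = torch.randn(2, 3, 32, 32)
with torch.no_grad():
    same = torch.allclose(model(x), new_model(x))
print(same)  # True
```

If the architectures differ, `load_state_dict` raises an error instead of silently loading partial weights.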
    

Downloading torchvision models

from torchvision.models import (
    resnet18, ResNet18_Weights
)


weights = ResNet18_Weights.DEFAULT
model = resnet18(weights=weights)
transforms = weights.transforms()
  • Import the ResNet-18 architecture and its weights
  • Select the default (best available) pre-trained weights
  • Instantiate the model, passing it the weights
  • Store the data transforms the weights require

Prepare new input images

from PIL import Image

image = Image.open("cat013.jpg")

image_tensor = transforms(image)
image_reshaped = image_tensor.unsqueeze(0)

 

[Image: cat]

  • Load image
  • Transform image
  • Reshape image: add a batch dimension with unsqueeze(0)

Generating a new prediction

model.eval()

with torch.no_grad():
    pred = model(image_reshaped).squeeze(0)
pred_cls = pred.softmax(0)
cls_id = pred_cls.argmax().item()
cls_name = weights.meta["categories"][cls_id]
print(cls_name)
Egyptian cat
  • Evaluation mode for inference
  • Disable gradients
  • Pass image to model and remove batch dimension
  • Apply softmax to turn logits into class probabilities
  • Select the highest-probability class and extract its index
  • Map class index to label
  • Print class label

Let's practice
