Werken met vooraf getrainde modellen

Deep Learning voor afbeeldingen met PyTorch

Michal Oleszak

Machine Learning Engineer

Vooraf getrainde modellen benutten

  • Modellen vanaf nul trainen:

    • Duurt lang
    • Vereist veel data
  • Vooraf getrainde modellen - al getraind voor een taak

    • Direct herbruikbaar voor een nieuwe taak
    • Vereisen afstemming op de nieuwe taak (transfer learning)
  • Stappen om vooraf getrainde modellen te gebruiken:

    • Modellen lokaal opslaan & laden
    • torchvision-modellen downloaden
Deep Learning voor afbeeldingen met PyTorch

Een volledig PyTorch-model opslaan

  • torch.save()
  • Model-extensie: .pt of .pth
  • Modelgewichten opslaan met .state_dict()
    torch.save(model.state_dict(), "BinaryCNN.pth")
    
Deep Learning voor afbeeldingen met PyTorch

PyTorch-modellen laden

  • Nieuw model instantiëren

    new_model = BinaryCNN()
    
  • Opgeslagen parameters laden

    new_model.load_state_dict(torch.load('BinaryCNN.pth'))
    
Deep Learning voor afbeeldingen met PyTorch

`torchvision`-modellen downloaden

from torchvision.models import (
    resnet18, ResNet18_Weights
)


weights = ResNet18_Weights.DEFAULT
model = resnet18(weights=weights)
transforms = weights.transforms()
  • resnet-architectuur en -gewichten importeren
  • Gewichten ophalen
  • Model instantieren met gewichten
  • Vereiste datatransforms opslaan
Deep Learning voor afbeeldingen met PyTorch

Nieuwe invoerafbeeldingen voorbereiden

from PIL import Image

image = Image.open("cat013.jpg")

image_tensor = transform(image)
image_reshaped = image_tensors.unsqueeze(0)

 

kattenafbeelding

  • Afbeelding laden
  • Afbeelding transformeren
  • Afbeelding reshapen
Deep Learning voor afbeeldingen met PyTorch

Een nieuwe voorspelling genereren

model.eval()


with torch.no_grad():
pred = model(image_reshaped).squeeze(0)
pred_cls = pred.softmax(0)
cls_id = pred_cls.argmax().item()
cls_name = weights.meta["categories"][cls_id]
print(cls_name)
Egyptian cat
  • Evaluatiemodus voor inferentie
  • Gradients uitschakelen
  • Afbeelding door model, batchdimensie verwijderen
  • Softmax toepassen
  • Klasse met hoogste kans kiezen en index ophalen
  • Klasse-index naar label mappen
  • Klasselabel printen
Deep Learning voor afbeeldingen met PyTorch

Laten we oefenen!

Deep Learning voor afbeeldingen met PyTorch

Preparing Video For Download...