Modelprestatie evalueren

Introductie tot Deep Learning met PyTorch

Jasmin Ludolf

Senior Data Science Content Developer, DataCamp

Training, validatie en testen

$$

  • Een dataset wordt meestal in drie subsets gesplitst:
Percentage data Rol
Training 80-90% Stelt modelparameters af
Validatie 10-20% Stemt hyperparameters af
Test 5-10% Beoordeelt uiteindelijke modelprestatie

$$

  • Volg loss en accuracy tijdens training en validatie
Introductie tot Deep Learning met PyTorch

Training loss berekenen

$$

Per epoch:

  • Sommeer de loss over alle batches in de dataloader
  • Bereken de gemiddelde training loss aan het einde van de epoch
training_loss = 0.0

for inputs, labels in trainloader: # Voer de forward pass uit outputs = model(inputs) # Bereken de loss loss = criterion(outputs, labels)
# Backpropagation loss.backward() # Bereken gradiënten optimizer.step() # Werk gewichten bij optimizer.zero_grad() # Reset gradiënten
# Bereken en tel de loss op training_loss += loss.item()
epoch_loss = training_loss / len(trainloader)
Introductie tot Deep Learning met PyTorch

Validatie loss berekenen

validation_loss = 0.0
model.eval() # Zet model in evaluatiemodus


with torch.no_grad(): # Schakel gradiënten uit voor efficiëntie
for inputs, labels in validationloader: # Voer de forward pass uit outputs = model(inputs) # Bereken de loss loss = criterion(outputs, labels) validation_loss += loss.item() epoch_loss = validation_loss / len(validationloader) # Bereken gemiddelde loss
model.train() # Schakel terug naar trainingmodus
Introductie tot Deep Learning met PyTorch

Overfitting

een voorbeeld van overfitting

Introductie tot Deep Learning met PyTorch

Accuracy berekenen met torchmetrics

import torchmetrics


# Maak accuracymetriek metric = torchmetrics.Accuracy(task="multiclass", num_classes=3)
for features, labels in dataloader: outputs = model(features) # Forward pass # Bereken batch-accuracy (met argmax voor one-hot labels) metric.update(outputs, labels.argmax(dim=-1))
# Bereken accuracy over de hele epoch accuracy = metric.compute()
# Reset metriek voor de volgende epoch metric.reset()
Introductie tot Deep Learning met PyTorch

Laten we oefenen!

Introductie tot Deep Learning met PyTorch

Preparing Video For Download...