Avaliação do desempenho do modelo

Introdução ao Aprendizado Profundo com o PyTorch

Jasmin Ludolf

Senior Data Science Content Developer, DataCamp

Treinamento, validação e testes

$$

  • Um conjunto de dados é dividido em três subconjuntos:
Porcentagem de dados Função
Treinamento 80-90% Ajusta os parâmetros do modelo
Validação 10-20% Ajusta os hiperparâmetros
Teste 5-10% Avalia o desempenho do modelo final

$$

  • Acompanha perda e precisão durante treinamento e validação
Introdução ao Aprendizado Profundo com o PyTorch

Cálculo da perda de treinamento

$$

Para cada época:

  • Soma a perda em todos os lotes no dataloader
  • Calcula a perda média de treinamento no final da época
training_loss = 0.0

for inputs, labels in trainloader: # Run the forward pass outputs = model(inputs) # Compute the loss loss = criterion(outputs, labels)
# Backpropagation loss.backward() # Compute gradients optimizer.step() # Update weights optimizer.zero_grad() # Reset gradients
# Calculate and sum the loss training_loss += loss.item()
epoch_loss = training_loss / len(trainloader)
Introdução ao Aprendizado Profundo com o PyTorch

Cálculo da perda de validação

validation_loss = 0.0
model.eval() # Put model in evaluation mode


with torch.no_grad(): # Disable gradients for efficiency
for inputs, labels in validationloader: # Run the forward pass outputs = model(inputs) # Calculate the loss loss = criterion(outputs, labels) validation_loss += loss.item() epoch_loss = validation_loss / len(validationloader) # Compute mean loss
model.train() # Switch back to training mode
Introdução ao Aprendizado Profundo com o PyTorch

Sobreajuste

um exemplo de sobreajuste

Introdução ao Aprendizado Profundo com o PyTorch

Cálculo da precisão com torchmetrics

import torchmetrics


# Create accuracy metric metric = torchmetrics.Accuracy(task="multiclass", num_classes=3)
for features, labels in dataloader: outputs = model(features) # Forward pass # Compute batch accuracy (keeping argmax for one-hot labels) metric.update(outputs, labels.argmax(dim=-1))
# Compute accuracy over the whole epoch accuracy = metric.compute()
# Reset metric for the next epoch metric.reset()
Introdução ao Aprendizado Profundo com o PyTorch

Vamos praticar!

Introdução ao Aprendizado Profundo com o PyTorch

Preparing Video For Download...