Overfitting bestrijden

Introductie tot Deep Learning met PyTorch

Jasmin Ludolf

Senior Data Science Content Developer, DataCamp

Oorzaken van overfitting

  • Overfitting: het model generaliseert niet naar onzichtbare data
    • Model onthoudt trainingsdata
    • Scoort goed op trainingsdata maar slecht op validatiedata
  • Mogelijke oorzaken:
Probleem Oplossingen
Dataset is niet groot genoeg Meer data verzamelen / data-augmentatie
Model heeft te veel capaciteit Model verkleinen / dropout toevoegen
Gewichten zijn te groot Weight decay
Introductie tot Deep Learning met PyTorch

Overfitting bestrijden

Strategieën:

  • Model verkleinen of een dropout-laag toevoegen
  • Weight decay gebruiken om parameters klein te houden
  • Nieuwe data verkrijgen of data augmenteren
Introductie tot Deep Learning met PyTorch

"Regularisatie" met een dropout-laag

  • Zet willekeurig elementen van de invoertensor op nul tijdens training
model = nn.Sequential(nn.Linear(8, 4),
                      nn.ReLU(),
                      nn.Dropout(p=0.5))
features = torch.randn((1, 8))
print(model(features))
tensor([[1.4655, 0.0000, 0.0000, 0.8456]], grad_fn=<MulBackward0>)
  • Dropout komt na de activatiefunctie
  • Gedraagt zich anders bij training vs. evaluatie - gebruik model.train() voor training en model.eval() om dropout uit te zetten tijdens evaluatie
Introductie tot Deep Learning met PyTorch

Regularisatie met weight decay

optimizer = optim.SGD(model.parameters(), lr=0.001, weight_decay=0.0001)

  • Bepaald door de weight_decay-parameter in de optimizer, meestal klein (bijv. 0.0001)
  • Weight decay stimuleert kleinere gewichten door een straf toe te voegen tijdens optimalisatie
  • Helpt overfitting te verminderen: houdt gewichten kleiner en verbetert generalisatie
Introductie tot Deep Learning met PyTorch

Data-augmentatie

voorbeelden van data-augmentatie

Introductie tot Deep Learning met PyTorch

Laten we oefenen!

Introductie tot Deep Learning met PyTorch

Preparing Video For Download...