GAN's trainen

Deep Learning voor afbeeldingen met PyTorch

Michal Oleszak

Machine Learning Engineer

Generator-doel

 

 

GAN-workflowdiagram.

  • Doel: fakes genereren die de discriminator misleiden
  • Idee: gebruik de discriminator om de generatorprestatie te meten
  • Output van de generator geclassificeerd door de discriminator als:
    • Echt (label 1) - goed, lage loss
    • Fake (label 0) - slecht, hoge loss
Deep Learning voor afbeeldingen met PyTorch

Generator-loss

def gen_loss(gen, disc, num_images, z_dim):

noise = torch.randn(num_images, z_dim)
fake = gen(noise)
disc_pred = disc(fake)
criterion = nn.BCEWithLogitsLoss()
gen_loss = criterion( disc_pred, torch.ones_like(disc_pred) ) return gen_loss
  • Definieer willekeurige ruis
  • Genereer fake beeld
  • Haal discriminator-voorspelling voor de fake op
  • Gebruik binaire cross-entropy (BCE)
  • Generator-loss: BCE tussen discriminator-voorspellingen en een tensor van enen
Deep Learning voor afbeeldingen met PyTorch

Discriminator-doel

 

 

GAN-workflowdiagram.

  • Doel: fakes en echte beelden correct classificeren
  • Outputs van de generator moeten als fake worden geclassificeerd (label 0)
  • Echte beelden moeten als echt worden geclassificeerd (label 1)
Deep Learning voor afbeeldingen met PyTorch

Discriminator-loss

def disc_loss(gen, disc, real, num_images, z_dim):

criterion = nn.BCEWithLogitsLoss()
noise = torch.randn(num_images, z_dim)
fake = gen(noise)
disc_pred_fake = disc(fake)
fake_loss = criterion( disc_pred_fake, torch.zeros_like(disc_pred_fake) )
disc_pred_real = disc(real)
real_loss = criterion( disc_pred_real, torch.ones_like(disc_pred_real) )
disc_loss = (real_loss + fake_loss) / 2 return disc_loss
  • Definieer binaire cross-entropy-criterium
  • Genereer invoerruis voor de generator
  • Genereer fakes
  • Haal voorspellingen van de discriminator voor fakes op
  • Bereken de fake-losscomponent
  • Haal voorspellingen van de discriminator voor echte beelden op
  • Bereken de real-losscomponent
  • Eindverlies: gemiddelde van real- en fake-losscomponenten
Deep Learning voor afbeeldingen met PyTorch

GAN-trainingslus

for epoch in range(num_epochs):
    for real in dataloader:
        cur_batch_size = len(real)


disc_opt.zero_grad()
disc_loss = disc_loss( gen, disc, real, cur_batch_size, z_dim=16)
disc_loss.backward() disc_opt.step()
gen_opt.zero_grad()
gen_loss = gen_loss( gen, disc, cur_batch_size, z_dim=16)
gen_loss.backward() gen_opt.step()
  • Loop over epochs en batches echte data; bepaal batchgrootte
  • Reset de gradients van de discriminator-optimizer
  • Bereken discriminator-loss
  • Backprop en doe de optimizer-stap
  • Reset de gradients van de generator-optimizer
  • Bereken generator-loss
  • Backprop en doe de optimizer-stap
Deep Learning voor afbeeldingen met PyTorch

Laten we oefenen!

Deep Learning voor afbeeldingen met PyTorch

Preparing Video For Download...