Training GANs

Deep Learning for Images with PyTorch

Michal Oleszak

Machine Learning Engineer

Generator objective

GAN workflow diagram.

  • Objective: Generate fakes that fool the discriminator
  • Idea: Use the discriminator to measure the generator's performance
  • If the discriminator classifies the generator's output as:
    • Real (label 1) - good, small loss
    • Fake (label 0) - bad, large loss
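To make "small loss" and "large loss" concrete, here is a minimal sketch that scores two hypothetical discriminator logits for a single fake image against the generator's target label 1. The logit values 3.0 and -3.0 are made up for illustration; only `nn.BCEWithLogitsLoss` comes from the slides.

```python
import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()
target = torch.ones(1)  # the generator always aims for label 1 ("real")

# Hypothetical discriminator logits for one fake image
fooled = torch.tensor([3.0])   # discriminator leans "real"
caught = torch.tensor([-3.0])  # discriminator leans "fake"

loss_fooled = criterion(fooled, target).item()  # log(1 + e^-3), about 0.049
loss_caught = criterion(caught, target).item()  # log(1 + e^3), about 3.049
print(f"fooled: {loss_fooled:.3f}, caught: {loss_caught:.3f}")
```

This prints `fooled: 0.049, caught: 3.049`: the more confidently the discriminator rejects a fake, the larger the generator's loss.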

Generator loss

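import torch
import torch.nn as nn

def gen_loss(gen, disc, num_images, z_dim):
    # Sample random noise as generator input
    noise = torch.randn(num_images, z_dim)
    fake = gen(noise)
    # Discriminator's raw predictions (logits) on the fakes
    disc_pred = disc(fake)
    criterion = nn.BCEWithLogitsLoss()
    # Generator wants its fakes classified as real (label 1)
    gen_loss = criterion(disc_pred, torch.ones_like(disc_pred))
    return gen_loss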
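  • Sample random noise as generator input
  • Generate fake images
  • Get the discriminator's predictions (logits) on the fake images
  • Use the binary cross-entropy (BCE) criterion; BCEWithLogitsLoss applies the sigmoid internally
  • Generator loss: BCE between the discriminator's predictions and a tensor of ones, so the generator is rewarded when its fakes are classified as real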
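A quick way to check the function is to call it on tiny stand-in networks. The sketch below assumes hypothetical toy models in which "images" are flat vectors of 8 values and the discriminator outputs one logit per image; a real image GAN would use convolutional layers instead.

```python
import torch
import torch.nn as nn

def gen_loss(gen, disc, num_images, z_dim):
    noise = torch.randn(num_images, z_dim)
    fake = gen(noise)
    disc_pred = disc(fake)
    criterion = nn.BCEWithLogitsLoss()
    return criterion(disc_pred, torch.ones_like(disc_pred))

# Hypothetical toy models: 16-dim noise in, 8-value "images" out
z_dim, img_dim = 16, 8
gen = nn.Sequential(nn.Linear(z_dim, 32), nn.ReLU(), nn.Linear(32, img_dim))
disc = nn.Sequential(nn.Linear(img_dim, 32), nn.LeakyReLU(0.2), nn.Linear(32, 1))

loss = gen_loss(gen, disc, num_images=4, z_dim=z_dim)
loss.backward()
print(loss.shape)  # torch.Size([]): a single scalar
# Gradients flow back through the discriminator into the generator
print(all(p.grad is not None for p in gen.parameters()))  # True
```

The backward pass also leaves gradients on the discriminator's parameters, which is one reason each optimizer step is preceded by `zero_grad()`.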

Discriminator objective

GAN workflow diagram.

  • Objective: Correctly classify fake and real images
  • Generator's outputs should be classified as fake (label 0)
  • Real images should be classified as real (label 1)
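Written out as formulas, the two objectives become the following losses. Let $\sigma$ be the sigmoid that BCEWithLogitsLoss applies to the discriminator's logits $D(\cdot)$, $x_1, \dots, x_m$ a batch of real images and $z_1, \dots, z_m$ the noise vectors fed to the generator $G$. This sketch assumes BCEWithLogitsLoss's default mean reduction and the equal batch sizes used in this section's code. For a target $y = 1$ only the first BCE term survives, and for $y = 0$ only the second.

```latex
\ell(s, y) = -\big[\, y \log \sigma(s) + (1 - y) \log\big(1 - \sigma(s)\big) \big]

\mathcal{L}_G = -\frac{1}{m} \sum_{j=1}^{m} \log \sigma\big(D(G(z_j))\big)

\mathcal{L}_D = \frac{1}{2} \left[ -\frac{1}{m} \sum_{i=1}^{m} \log \sigma\big(D(x_i)\big)
  - \frac{1}{m} \sum_{j=1}^{m} \log\Big(1 - \sigma\big(D(G(z_j))\big)\Big) \right]
```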

Discriminator loss

import torch
import torch.nn as nn

def disc_loss(gen, disc, real, num_images, z_dim):
    criterion = nn.BCEWithLogitsLoss()
    # Generate fakes from random noise
    noise = torch.randn(num_images, z_dim)
    fake = gen(noise)
    # Fakes should be classified as fake (label 0)
    disc_pred_fake = disc(fake)
    fake_loss = criterion(disc_pred_fake, torch.zeros_like(disc_pred_fake))
    # Real images should be classified as real (label 1)
    disc_pred_real = disc(real)
    real_loss = criterion(disc_pred_real, torch.ones_like(disc_pred_real))
    # Average the two components
    disc_loss = (real_loss + fake_loss) / 2
    return disc_loss
  • Define binary cross-entropy criterion
  • Generate input noise for generator
  • Generate fakes
  • Get the discriminator's predictions for fake images
  • Calculate the fake loss component (target: zeros)
  • Get the discriminator's predictions for real images
  • Calculate the real loss component (target: ones)
  • Final loss is the average of the real and fake loss components
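The averaged loss can be checked against the BCE formula written by hand. This sketch assumes hypothetical toy models on 8-value "images" and reseeds the random generator so that the noise drawn inside `disc_loss` can be reproduced outside it. It relies on the identity $1 - \sigma(s) = \sigma(-s)$, so the label-0 term becomes `-F.logsigmoid(-s)`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def disc_loss(gen, disc, real, num_images, z_dim):
    criterion = nn.BCEWithLogitsLoss()
    noise = torch.randn(num_images, z_dim)
    fake = gen(noise)
    disc_pred_fake = disc(fake)
    fake_loss = criterion(disc_pred_fake, torch.zeros_like(disc_pred_fake))
    disc_pred_real = disc(real)
    real_loss = criterion(disc_pred_real, torch.ones_like(disc_pred_real))
    return (real_loss + fake_loss) / 2

# Hypothetical toy models and a batch of 4 "real" 8-value samples
z_dim, img_dim = 16, 8
gen = nn.Sequential(nn.Linear(z_dim, 32), nn.ReLU(), nn.Linear(32, img_dim))
disc = nn.Sequential(nn.Linear(img_dim, 32), nn.LeakyReLU(0.2), nn.Linear(32, 1))
real = torch.randn(4, img_dim) + 3.0

torch.manual_seed(0)  # fix the noise drawn inside disc_loss
loss = disc_loss(gen, disc, real, num_images=4, z_dim=z_dim)

# Same quantity by hand: label 1 -> -log(sigmoid(s)),
# label 0 -> -log(1 - sigmoid(s)) = -log(sigmoid(-s))
torch.manual_seed(0)  # reproduce the same noise
fake = gen(torch.randn(4, z_dim))
manual = (-F.logsigmoid(disc(real)).mean() - F.logsigmoid(-disc(fake)).mean()) / 2
print(torch.allclose(loss, manual))  # True
```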

GAN training loop

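# Assumes gen, disc, gen_opt, disc_opt, dataloader and num_epochs are already defined
for epoch in range(num_epochs):
    for real in dataloader:
        cur_batch_size = len(real)

        # Discriminator step; store the loss as d_loss so the
        # disc_loss function is not overwritten by a tensor
        disc_opt.zero_grad()
        d_loss = disc_loss(gen, disc, real, cur_batch_size, z_dim=16)
        d_loss.backward()
        disc_opt.step()

        # Generator step; g_loss likewise keeps gen_loss callable
        gen_opt.zero_grad()
        g_loss = gen_loss(gen, disc, cur_batch_size, z_dim=16)
        g_loss.backward()
        gen_opt.step()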
  • Loop over epochs and batches of real data, and compute the current batch size
  • Reset discriminator optimizer's gradients
  • Compute discriminator loss
  • Compute discriminator gradients and perform the optimization step
  • Reset generator optimizer's gradients
  • Compute generator loss
  • Compute generator gradients and perform the optimization step
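The loop can be run end to end on toy data. Everything below the two loss functions is an assumption for illustration: the "real" data is synthetic (256 vectors of 8 values centered at 3.0), the models are small MLPs, and the Adam optimizers with `lr=1e-3` and 3 epochs are arbitrary choices the slides do not specify.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

def gen_loss(gen, disc, num_images, z_dim):
    noise = torch.randn(num_images, z_dim)
    disc_pred = disc(gen(noise))
    return nn.BCEWithLogitsLoss()(disc_pred, torch.ones_like(disc_pred))

def disc_loss(gen, disc, real, num_images, z_dim):
    criterion = nn.BCEWithLogitsLoss()
    noise = torch.randn(num_images, z_dim)
    disc_pred_fake = disc(gen(noise))
    fake_loss = criterion(disc_pred_fake, torch.zeros_like(disc_pred_fake))
    disc_pred_real = disc(real)
    real_loss = criterion(disc_pred_real, torch.ones_like(disc_pred_real))
    return (real_loss + fake_loss) / 2

torch.manual_seed(0)
z_dim, img_dim, num_epochs = 16, 8, 3
# Hypothetical "real" data; a DataLoader over a plain tensor yields tensor batches
data = torch.randn(256, img_dim) + 3.0
dataloader = DataLoader(data, batch_size=32, shuffle=True)

# Hypothetical toy models and assumed optimizer settings
gen = nn.Sequential(nn.Linear(z_dim, 32), nn.ReLU(), nn.Linear(32, img_dim))
disc = nn.Sequential(nn.Linear(img_dim, 32), nn.LeakyReLU(0.2), nn.Linear(32, 1))
gen_opt = torch.optim.Adam(gen.parameters(), lr=1e-3)
disc_opt = torch.optim.Adam(disc.parameters(), lr=1e-3)

for epoch in range(num_epochs):
    for real in dataloader:
        cur_batch_size = len(real)

        # Discriminator update (loss kept in d_loss, not disc_loss)
        disc_opt.zero_grad()
        d_loss = disc_loss(gen, disc, real, cur_batch_size, z_dim)
        d_loss.backward()
        disc_opt.step()

        # Generator update (loss kept in g_loss, not gen_loss)
        gen_opt.zero_grad()
        g_loss = gen_loss(gen, disc, cur_batch_size, z_dim)
        g_loss.backward()
        gen_opt.step()
    print(f"epoch {epoch}: d_loss={d_loss.item():.3f}, g_loss={g_loss.item():.3f}")
```

Naming the loss tensors `d_loss` and `g_loss` matters: assigning the result to `disc_loss` or `gen_loss` at module level would replace the function with a tensor and fail on the second batch.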

Let's practice!

