Deep Learning for Images with PyTorch
Michal Oleszak
Machine Learning Engineer

# Generator loss: the discriminator's prediction on fake images is scored
# against the label 1 ("real").
#   prediction close to 1 -> good for the generator, small loss
#   prediction close to 0 -> bad for the generator, large loss
def gen_loss(gen, disc, num_images, z_dim, *, device=None):
    """Compute the generator's adversarial loss for one batch of fakes.

    Samples ``num_images`` standard-normal noise vectors of size ``z_dim``
    and maps them through ``gen``. The results are scored with ``disc``, and
    the scores are compared against all-ones targets using
    ``nn.BCEWithLogitsLoss``. The generator is therefore rewarded when the
    discriminator labels its fakes as real.

    Args:
        gen: Generator, called on a ``(num_images, z_dim)`` noise tensor.
        disc: Discriminator. It must return raw logits, because
            ``BCEWithLogitsLoss`` applies the sigmoid itself.
        num_images: Number of fake images to generate.
        z_dim: Size of each noise vector.
        device: Device on which to sample the noise. ``None`` (the default)
            uses PyTorch's default device, as before. Pass the models'
            device when they are not on the default device, e.g. on a GPU.

    Returns:
        The mean BCE loss as a scalar tensor. Gradients flow back through
        ``disc`` into ``gen``.
    """
    noise = torch.randn(num_images, z_dim, device=device)
    fake = gen(noise)
    disc_pred = disc(fake)
    criterion = nn.BCEWithLogitsLoss()
    # Every fake gets the target "real" (1): the generator tries to fool disc.
    # The result is named `loss` so the function name `gen_loss` is not shadowed.
    loss = criterion(disc_pred, torch.ones_like(disc_pred))
    return loss

# Discriminator targets: fake images are labelled 0, real images are labelled 1.
def disc_loss(gen, disc, real, num_images, z_dim, *, device=None):
    """Compute the discriminator's loss on one fake batch and one real batch.

    Generates ``num_images`` fakes from standard-normal noise of size
    ``z_dim``. Fake predictions are scored against 0 targets and predictions
    on ``real`` against 1 targets, both with ``nn.BCEWithLogitsLoss``. The
    two losses are then averaged.

    Args:
        gen: Generator, called on a ``(num_images, z_dim)`` noise tensor.
        disc: Discriminator. It must return raw logits, because
            ``BCEWithLogitsLoss`` applies the sigmoid itself.
        real: Batch of real images passed directly to ``disc``.
        num_images: Number of fake images to generate.
        z_dim: Size of each noise vector.
        device: Device on which to sample the noise. Defaults to
            ``real.device`` when ``real`` has one, so the noise lives next to
            the data. Otherwise PyTorch's default device is used.

    Returns:
        ``(real_loss + fake_loss) / 2`` as a scalar tensor. The fake batch is
        detached, so backpropagating this loss only reaches ``disc``.
    """
    if device is None:
        device = getattr(real, "device", None)
    criterion = nn.BCEWithLogitsLoss()
    noise = torch.randn(num_images, z_dim, device=device)
    # detach(): this loss trains only the discriminator. Without detaching,
    # backward() would also compute (and then waste) gradients for the
    # generator's parameters.
    fake = gen(noise).detach()
    disc_pred_fake = disc(fake)
    fake_loss = criterion(disc_pred_fake, torch.zeros_like(disc_pred_fake))
    disc_pred_real = disc(real)
    real_loss = criterion(disc_pred_real, torch.ones_like(disc_pred_real))
    # The result is named `loss` so the function name `disc_loss` is not shadowed.
    loss = (real_loss + fake_loss) / 2
    return loss
# GAN training loop: for every batch, take one discriminator step and then
# one generator step. The loss results get their own names. Assigning to
# `disc_loss` / `gen_loss` would rebind the loss functions to tensors, and the
# next batch's call would fail with "'Tensor' object is not callable".
for epoch in range(num_epochs):
    for real in dataloader:
        cur_batch_size = len(real)

        # Discriminator update. z_dim=16 must match the generator's input size.
        disc_opt.zero_grad()
        disc_loss_value = disc_loss(gen, disc, real, cur_batch_size, z_dim=16)
        disc_loss_value.backward()
        disc_opt.step()

        # Generator update, using the same noise size as above.
        gen_opt.zero_grad()
        gen_loss_value = gen_loss(gen, disc, cur_batch_size, z_dim=16)
        gen_loss_value.backward()
        gen_opt.step()
Deep Learning for Images with PyTorch