Deep Learning for Images with PyTorch
Michal Oleszak
Machine Learning Engineer
1
- Discriminator rates the generated images as real (probability near 1): good for the generator, small loss.
def gen_loss(gen, disc, num_images, z_dim):
    """Compute the generator's adversarial loss on a freshly generated batch.

    The discriminator's logits on the generated images are scored with
    ``BCEWithLogitsLoss`` against an all-ones ("real") target. If the
    discriminator rates the fakes as real (probability near 1), the loss is
    small, which is good for the generator. If it rates them as fake
    (probability near 0), the loss is large, which is bad for the generator.

    Args:
        gen: generator module, called on noise of shape (num_images, z_dim).
        disc: discriminator module; its output is treated as raw logits.
        num_images: number of images to generate for this loss.
        z_dim: size of each noise vector fed to ``gen``.

    Returns:
        Scalar loss tensor (mean BCE over the discriminator's outputs).
    """
    bce = nn.BCEWithLogitsLoss()
    generated = gen(torch.randn(num_images, z_dim))
    logits = disc(generated)
    # The generator "wins" when the discriminator calls its output real (1).
    return bce(logits, torch.ones_like(logits))
Discriminator targets: generated (fake) images → 0, real images → 1.
def disc_loss(gen, disc, real, num_images, z_dim):
    """Compute the discriminator's loss on one batch of real and generated images.

    Generated images are scored against target 0 (fake) and real images
    against target 1 (real), both with ``BCEWithLogitsLoss``. The two losses
    are averaged.

    Args:
        gen: generator module, called on noise of shape (num_images, z_dim).
        disc: discriminator module; its output is treated as raw logits.
        real: batch of real images (a tensor), passed directly to ``disc``.
        num_images: number of fake images to generate.
        z_dim: size of each noise vector fed to ``gen``.

    Returns:
        Scalar loss tensor: ``(real_loss + fake_loss) / 2``.
    """
    criterion = nn.BCEWithLogitsLoss()

    # Sample noise on the same device as the real batch, so this also works
    # when data and models live on a GPU. CPU behavior is unchanged.
    noise = torch.randn(num_images, z_dim, device=real.device)
    # detach(): this loss only updates the discriminator. Without it,
    # backward() also traverses the generator and fills its .grad buffers.
    fake = gen(noise).detach()
    disc_pred_fake = disc(fake)
    fake_loss = criterion(disc_pred_fake, torch.zeros_like(disc_pred_fake))

    disc_pred_real = disc(real)
    real_loss = criterion(disc_pred_real, torch.ones_like(disc_pred_real))

    return (real_loss + fake_loss) / 2
# Alternating GAN training: one discriminator step, then one generator step,
# per batch.
# NOTE(review): num_epochs, dataloader, gen, disc, gen_opt and disc_opt are
# defined outside this view. Each batch from dataloader is assumed to be a
# tensor of real images, since it is passed straight to disc.

# Noise size fed to the generator; must match what gen expects as input.
z_dim = 16

for _ in range(num_epochs):
    for real in dataloader:
        cur_batch_size = len(real)

        # Discriminator update. Keep the loss under a different name than the
        # function: `disc_loss = disc_loss(...)` would rebind the function to
        # a tensor, and the next batch would fail with
        # "'Tensor' object is not callable".
        disc_opt.zero_grad()
        disc_loss_value = disc_loss(gen, disc, real, cur_batch_size, z_dim=z_dim)
        disc_loss_value.backward()
        disc_opt.step()

        # Generator update. zero_grad() first clears any gradients left in the
        # generator's parameters before this step's backward pass.
        gen_opt.zero_grad()
        gen_loss_value = gen_loss(gen, disc, cur_batch_size, z_dim=z_dim)
        gen_loss_value.backward()
        gen_opt.step()
Deep Learning for Images with PyTorch