Melatih GAN

Deep Learning untuk Gambar dengan PyTorch

Michal Oleszak

Machine Learning Engineer

Tujuan generator

 

 

Diagram alur kerja GAN.

  • Tujuan: Hasilkan data palsu yang menipu diskriminator
  • Ide: Gunakan diskriminator untuk menilai kinerja generator
  • Keluaran generator diklasifikasi oleh diskriminator sebagai:
    • Asli (label 1) — baik, loss kecil
    • Palsu (label 0) — buruk, loss besar
Deep Learning untuk Gambar dengan PyTorch

Loss generator

def gen_loss(gen, disc, num_images, z_dim):

noise = torch.randn(num_images, z_dim)
fake = gen(noise)
disc_pred = disc(fake)
criterion = nn.BCEWithLogitsLoss()
gen_loss = criterion( disc_pred, torch.ones_like(disc_pred) ) return gen_loss
  • Definisikan noise acak
  • Hasilkan gambar palsu
  • Dapatkan prediksi diskriminator pada gambar palsu
  • Gunakan kriteria binary cross-entropy (BCE)
  • Loss generator: BCE antara prediksi diskriminator dan tensor berisi satu
Deep Learning untuk Gambar dengan PyTorch

Tujuan diskriminator

 

 

Diagram alur kerja GAN.

  • Tujuan: Mengklasifikasi gambar palsu dan asli dengan benar
  • Keluaran generator harus diklasifikasi sebagai palsu (label 0)
  • Gambar asli harus diklasifikasi sebagai asli (label 1)
Deep Learning untuk Gambar dengan PyTorch

Loss diskriminator

def disc_loss(gen, disc, real, num_images, z_dim):

criterion = nn.BCEWithLogitsLoss()
noise = torch.randn(num_images, z_dim)
fake = gen(noise)
disc_pred_fake = disc(fake)
fake_loss = criterion( disc_pred_fake, torch.zeros_like(disc_pred_fake) )
disc_pred_real = disc(real)
real_loss = criterion( disc_pred_real, torch.ones_like(disc_pred_real) )
disc_loss = (real_loss + fake_loss) / 2 return disc_loss
  • Tentukan kriteria binary cross-entropy
  • Hasilkan noise input untuk generator
  • Hasilkan data palsu
  • Dapatkan prediksi diskriminator untuk gambar palsu
  • Hitung komponen loss palsu
  • Dapatkan prediksi diskriminator untuk gambar asli
  • Hitung komponen loss asli
  • Loss akhir = rata-rata loss asli dan palsu
Deep Learning untuk Gambar dengan PyTorch

Loop pelatihan GAN

for epoch in range(num_epochs):
    for real in dataloader:
        cur_batch_size = len(real)


disc_opt.zero_grad()
disc_loss = disc_loss( gen, disc, real, cur_batch_size, z_dim=16)
disc_loss.backward() disc_opt.step()
gen_opt.zero_grad()
gen_loss = gen_loss( gen, disc, cur_batch_size, z_dim=16)
gen_loss.backward() gen_opt.step()
  • Loop atas epoch dan batch data asli, hitung ukuran batch saat ini
  • Reset gradien optimizer diskriminator
  • Hitung loss diskriminator
  • Hitung gradien diskriminator dan lakukan langkah optimisasi
  • Reset gradien optimizer generator
  • Hitung loss generator
  • Hitung gradien generator dan lakukan langkah optimisasi
Deep Learning untuk Gambar dengan PyTorch

Ayo berlatih!

Deep Learning untuk Gambar dengan PyTorch

Preparing Video For Download...