Evaluating GANs

Deep Learning for Images with PyTorch

Michal Oleszak

Machine Learning Engineer

Generating images

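import torch
import matplotlib.pyplot as plt

# gen: the trained generator from earlier, which takes 16-dimensional noise vectors
num_images_to_generate = 9
noise = torch.randn(num_images_to_generate, 16)

with torch.no_grad():
    fake = gen(noise)
print(f"Generated shape: {fake.shape}")
# Generated shape: torch.Size([9, 3, 96, 96])

for i in range(num_images_to_generate):
    image_tensor = fake[i, :, :, :]
    # (C, H, W) -> (H, W, C), the channels-last layout plt.imshow expects
    image_permuted = image_tensor.permute(1, 2, 0)
    plt.imshow(image_permuted)
    plt.show()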
  • Create a random noise tensor
  • Pass the noise to the generator
  • Iterate over the number of images
  • Slice fake to select the i-th image
  • Rearrange the image dimensions from (C, H, W) to (H, W, C)
  • Plot the image
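Instead of calling plt.imshow once per image, the nine samples can also be shown as a single 3×3 grid. The sketch below uses NumPy so it runs on its own; tile_images is a hypothetical helper, not part of PyTorch or Matplotlib. Its transpose step does the same job as permute(1, 2, 0) above, moving the channel axis last.

```python
import numpy as np

def tile_images(images: np.ndarray, ncols: int = 3) -> np.ndarray:
    """Arrange a batch of (N, C, H, W) images into one (rows*H, cols*W, C) grid.

    Hypothetical helper for illustration; blank (zero) tiles fill any gap.
    """
    n, c, h, w = images.shape
    nrows = -(-n // ncols)  # ceiling division
    # Pad the batch with blank images so it fills the grid exactly
    pad = nrows * ncols - n
    if pad:
        images = np.concatenate([images, np.zeros((pad, c, h, w), images.dtype)])
    # (rows, cols, C, H, W) -> (rows, H, cols, W, C): channels last, like permute(1, 2, 0)
    grid = images.reshape(nrows, ncols, c, h, w).transpose(0, 3, 1, 4, 2)
    return grid.reshape(nrows * h, ncols * w, c)
```

With the generated batch on the CPU, plt.imshow(tile_images(fake.numpy())) followed by plt.show() would draw all nine images in one figure (call fake.cpu() first if the generator ran on a GPU). torchvision.utils.make_grid does something similar but returns a channels-first tensor, so it would still need a permute before plotting.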

GAN generations

Sample of Pokemon images generated by a GAN.

Fréchet Inception Distance

  • Inception: an image classification model
  • Fréchet distance: a distance measure between two probability distributions
  • Fréchet Inception Distance (FID):
    1. Use Inception to extract features from samples of both real and fake images
    2. Calculate the means and covariances of the features for real and fake images
    3. Treat each feature set as a multivariate normal distribution with those means and covariances, and calculate the Fréchet distance between the real and fake distributions
  • Low FID = fakes are similar to the training data and diverse
  • Rule of thumb: FID < 10 = good
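Steps 2 and 3 can be sketched in NumPy on any two feature matrices. In FID the rows would be Inception features from step 1, which this sketch does not compute; frechet_distance is a hypothetical helper, not the torchmetrics implementation. It returns the squared Fréchet distance between the Gaussians N(μ_r, Σ_r) and N(μ_f, Σ_f):

d² = ||μ_r − μ_f||² + Tr(Σ_r + Σ_f − 2(Σ_r Σ_f)^(1/2))

Common implementations compute the matrix square root with scipy.linalg.sqrtm; this sketch instead uses the fact that, for covariance matrices, Tr((Σ_r Σ_f)^(1/2)) equals the sum of the square roots of the eigenvalues of Σ_r Σ_f.

```python
import numpy as np

def frechet_distance(feats_real: np.ndarray, feats_fake: np.ndarray) -> float:
    """Squared Fréchet distance between Gaussians fitted to two feature sets.

    Hypothetical helper; feats_* have shape (num_samples, num_features).
    """
    # Step 2: means and covariances (rows are samples, hence rowvar=False)
    mu_r, mu_f = feats_real.mean(axis=0), feats_fake.mean(axis=0)
    sigma_r = np.cov(feats_real, rowvar=False)
    sigma_f = np.cov(feats_fake, rowvar=False)

    # Step 3, first term: squared distance between the means
    mean_term = np.sum((mu_r - mu_f) ** 2)

    # Tr((sigma_r @ sigma_f)^(1/2)) via eigenvalues, which are real and
    # non-negative up to rounding error, so tiny negatives are clipped to 0
    eigvals = np.linalg.eigvals(sigma_r @ sigma_f)
    trace_sqrt = np.sum(np.sqrt(np.clip(eigvals.real, 0.0, None)))

    return float(mean_term + np.trace(sigma_r) + np.trace(sigma_f) - 2.0 * trace_sqrt)
```

Two identical feature sets give a distance close to 0, and shifting every one of 4 features by 1 leaves the covariances unchanged while adding 4 to the mean term. With few samples relative to the feature dimension (for example 9 images and 64 features), the covariance estimates are rank-deficient and noisy, which is one reason FID is usually computed over many images.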

FID in PyTorch

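import torch
from torchmetrics.image.fid import FrechetInceptionDistance
# Requires the torch-fidelity package, e.g. pip install torchmetrics[image]

# fake: generated images, real: a batch of training images,
# both float tensors of shape (N, 3, H, W) with values in [0, 1]
fid = FrechetInceptionDistance(feature=64)
fid.update((fake * 255).to(torch.uint8), real=False)
fid.update((real * 255).to(torch.uint8), real=True)
fid.compute()
# tensor(7.5159)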
  • Import FrechetInceptionDistance
  • Instantiate the FID metric
  • Update the metric with fake images:
    • Multiply by 255
    • Cast to torch.uint8
  • Similarly, update the metric with real images
  • Compute metric value
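Multiplying by 255 and casting assumes the images already lie in [0, 1]: the cast truncates toward zero and is not guaranteed to clamp values outside that range. A slightly more defensive conversion, which is our assumption rather than what the slides use, clips and rounds first. It is shown in NumPy so it runs standalone; to_uint8 is a hypothetical helper, and the PyTorch equivalent would be (fake.clamp(0, 1) * 255).round().to(torch.uint8).

```python
import numpy as np

def to_uint8(images: np.ndarray) -> np.ndarray:
    """Convert float images in [0, 1] to uint8 in [0, 255] (hypothetical helper)."""
    # Clip first so values slightly outside [0, 1] saturate at 0 or 255
    clipped = np.clip(images, 0.0, 1.0)
    # Round to the nearest integer instead of truncating, then cast
    return np.round(clipped * 255.0).astype(np.uint8)
```

For inputs already in [0, 1], this differs from the slides' version by at most one intensity level per pixel, so its effect on FID should be small.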

Let's practice!
