Evaluating GANs

Deep Learning for Images with PyTorch

Michal Oleszak

Machine Learning Engineer

Generating images

import torch
import matplotlib.pyplot as plt

# gen: a trained generator mapping 16-dim noise vectors to images
num_images_to_generate = 9
noise = torch.randn(num_images_to_generate, 16)

with torch.no_grad():
    fake = gen(noise)
print(f"Generated shape: {fake.shape}")

Generated shape: torch.Size([9, 3, 96, 96])

for i in range(num_images_to_generate):
    image_tensor = fake[i, :, :, :]
    image_permuted = image_tensor.permute(1, 2, 0)
    plt.imshow(image_permuted)
    plt.show()
  • Create a random noise tensor
  • Pass the noise to the generator without tracking gradients
  • Iterate over the generated images
  • Slice fake to select the i-th image
  • Rearrange the dimensions from (C, H, W) to (H, W, C), as plt.imshow expects
  • Plot the image
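The loop above opens a separate figure for each image. A minimal alternative sketch, assuming a single 3x3 grid built with matplotlib's `plt.subplots` is acceptable, draws all nine samples in one figure. Because the trained `gen` is not shown on the slide, the code builds a hypothetical stand-in generator (a linear layer plus a sigmoid, reshaped to 3x96x96) only so that the snippet runs; replace it with your own trained generator. The output file name is also a placeholder.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

# Hypothetical stand-in for a trained generator: 16-dim noise -> 3x96x96 images in [0, 1].
gen = nn.Sequential(
    nn.Linear(16, 3 * 96 * 96),
    nn.Sigmoid(),                  # keeps pixel values in [0, 1], which imshow accepts for floats
    nn.Unflatten(1, (3, 96, 96)),  # (N, 3*96*96) -> (N, 3, 96, 96)
)

num_images_to_generate = 9
noise = torch.randn(num_images_to_generate, 16)

with torch.no_grad():  # inference only, no gradients needed
    fake = gen(noise)

# One 3x3 grid instead of nine separate figures.
fig, axes = plt.subplots(3, 3, figsize=(6, 6))
for i, ax in enumerate(axes.flat):
    # (C, H, W) -> (H, W, C), the layout matplotlib expects for RGB images
    ax.imshow(fake[i].permute(1, 2, 0).numpy())
    ax.axis("off")
fig.tight_layout()
fig.savefig("gan_samples.png")  # placeholder path; use plt.show() in an interactive session
```

Collecting the samples in one grid makes it easier to judge diversity at a glance, which is one of the two properties FID tries to capture.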

GAN generations

Sample of Pokemon images generated by a GAN.


Fréchet Inception Distance

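  • Inception: an image classification model, used here as a feature extractor
  • Fréchet distance: a distance measure between two probability distributions
  • Fréchet Inception Distance:
    1. Use Inception to extract features from samples of both real and fake images
    2. Calculate the mean and covariance of the features for the real and for the fake images
    3. Treat each set of features as a multivariate normal distribution with that mean and covariance, and calculate the Fréchet distance between the real and fake distributions
  • Low FID: fakes are similar to the training data and diverse
  • FID < 10: good (a rough rule of thumb; absolute values depend on the dataset and the Inception feature layer used)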
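For two normal distributions with means μ_r, μ_f and covariances Σ_r, Σ_f, the Fréchet distance in step 3 has a closed form:

$$\mathrm{FID} = \lVert \mu_r - \mu_f \rVert_2^2 + \operatorname{Tr}\left(\Sigma_r + \Sigma_f - 2\left(\Sigma_r \Sigma_f\right)^{1/2}\right)$$

The NumPy sketch below computes this from two arrays of already-extracted features (one row per image). It is an illustration under stated assumptions, not torchmetrics' implementation: the Inception feature extraction is omitted, the function names `frechet_distance` and `_sqrt_psd` are hypothetical, and the trace of the matrix square root is obtained from eigenvalues so that only NumPy is needed.

```python
import numpy as np


def _sqrt_psd(mat):
    """Symmetric square root of a positive semi-definite matrix via eigendecomposition."""
    eigvals, eigvecs = np.linalg.eigh(mat)
    eigvals = np.clip(eigvals, 0.0, None)  # drop tiny negative values caused by round-off
    return (eigvecs * np.sqrt(eigvals)) @ eigvecs.T


def frechet_distance(feats_real, feats_fake):
    """Fréchet distance between normal distributions fitted to two (N, D) feature arrays."""
    mu_r = feats_real.mean(axis=0)
    mu_f = feats_fake.mean(axis=0)
    # atleast_2d keeps the D == 1 case as a 1x1 matrix
    sigma_r = np.atleast_2d(np.cov(feats_real, rowvar=False))
    sigma_f = np.atleast_2d(np.cov(feats_fake, rowvar=False))

    # Tr((Sigma_r Sigma_f)^{1/2}) is the sum of square roots of the eigenvalues of
    # Sigma_r^{1/2} Sigma_f Sigma_r^{1/2}, a symmetric PSD matrix with the same eigenvalues.
    sqrt_r = _sqrt_psd(sigma_r)
    middle = sqrt_r @ sigma_f @ sqrt_r
    trace_sqrt = np.sqrt(np.clip(np.linalg.eigvalsh(middle), 0.0, None)).sum()

    mean_term = np.sum((mu_r - mu_f) ** 2)
    return float(mean_term + np.trace(sigma_r) + np.trace(sigma_f) - 2.0 * trace_sqrt)


# Random features standing in for Inception features
rng = np.random.default_rng(0)
real_feats = rng.normal(size=(500, 4))
same = frechet_distance(real_feats, real_feats)
shifted = frechet_distance(real_feats, real_feats + 1.0)
```

As a sanity check, identical feature sets give a distance of about 0, and shifting every one of the 4 features by 1 gives about 4: the covariances are unchanged, so only the squared mean difference contributes.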

FID in PyTorch

import torch
from torchmetrics.image.fid import FrechetInceptionDistance

# fake, real: float image batches of shape (N, 3, H, W) with values in [0, 1]
fid = FrechetInceptionDistance(feature=64)
fid.update((fake * 255).to(torch.uint8), real=False)
fid.update((real * 255).to(torch.uint8), real=True)
fid.compute()

tensor(7.5159)
  • Import FrechetInceptionDistance
  • Instantiate the FID metric
  • Update the metric with the fake images:
    • Multiply by 255
    • Cast to torch.uint8
  • Similarly, update the metric with the real images
  • Compute the metric value
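The `(x * 255).to(torch.uint8)` conversion assumes the images are floats in [0, 1]; values outside that range are not guaranteed to end up at 0 or 255 when cast to uint8. A small helper, sketched here under that assumption, clamps before casting. The name `to_uint8` and the `from_tanh` flag are hypothetical: the flag is meant for generators whose last layer is tanh (outputs in [-1, 1]), and whether you need it depends on your architecture.

```python
import torch


def to_uint8(images, from_tanh=False):
    """Convert float images to the uint8 [0, 255] format fed to the FID metric above.

    Assumes values in [0, 1], or in [-1, 1] when from_tanh=True (hypothetical option).
    """
    if from_tanh:
        images = (images + 1) / 2          # map [-1, 1] -> [0, 1]
    images = images.clamp(0.0, 1.0)        # keep values in range before the cast
    return (images * 255).to(torch.uint8)  # truncates toward zero, like the slide's conversion


example = to_uint8(torch.tensor([-0.1, 0.5, 1.0, 1.2]))
```

With the helper, the two update calls would become `fid.update(to_uint8(fake), real=False)` and `fid.update(to_uint8(real), real=True)`.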

Let's practice!
