Deep Learning for Images with PyTorch
Michal Oleszak
Machine Learning Engineer
# Generate a batch of fake images from the trained generator `gen`.
# NOTE(review): `gen` is defined elsewhere; the noise dimension (16) must
# match the generator's expected input size -- confirm against its definition.
num_images_to_generate = 9
noise = torch.randn(num_images_to_generate, 16)

# Inference only: disable autograd tracking so no computation graph is built.
with torch.no_grad():
    fake = gen(noise)

print(f"Generated shape: {fake.shape}")
Generated shape: torch.Size([9, 3, 96, 96])
# Display each generated image in turn.
# Each image in `fake` is laid out channels-first (C, H, W); the printed batch
# shape is e.g. [9, 3, 96, 96]. matplotlib's imshow expects channels-last
# (H, W, C), so move the channel axis to the end before plotting.
# NOTE(review): imshow expects float pixels in [0, 1]; if the generator
# outputs another range (e.g. [-1, 1]), rescale first -- confirm.
for image_tensor in fake:
    image_permuted = image_tensor.permute(1, 2, 0)
    # .cpu() is a no-op for CPU tensors and avoids a failure if `gen` ran on a GPU.
    plt.imshow(image_permuted.cpu())
    plt.show()
from torchmetrics.image.fid import FrechetInceptionDistance

# Frechet Inception Distance between generated and real images (lower is
# better). feature=64 selects the 64-dimensional Inception feature layer.
# NOTE(review): with only a handful of images per side (9 fake images are
# generated above), the 64x64 feature covariance is estimated from very few
# samples, so the score is noisy -- treat it as illustrative.
fid = FrechetInceptionDistance(feature=64)

# By default the metric expects uint8 images with values in [0, 255].
# Clamp before casting: converting out-of-range floats to uint8 does not
# saturate, so slightly out-of-range pixels would become garbage values.
# NOTE(review): assumes `fake` and `real` (defined elsewhere) hold floats in
# [0, 1] -- confirm against the generator output and the data pipeline.
fid.update((fake * 255).clamp(0, 255).to(torch.uint8), real=False)
fid.update((real * 255).clamp(0, 255).to(torch.uint8), real=True)

print(fid.compute())
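Output: tensor(7.5159) -- the FID score; lower values mean the generated images are statistically closer to the real ones.
`FrechetInceptionDistance` expects images as `torch.uint8` tensors with pixel values in [0, 255] by default, which is why both batches are scaled by 255 and cast to `torch.uint8` before calling `update`.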
Deep Learning for Images with PyTorch