Multi-input models

Intermediate Deep Learning with PyTorch

Michal Oleszak

Machine Learning Engineer

Why multi-input?

Using more information

A schema of a model that takes two car images as inputs and produces one output.

Multi-modal models

A schema of a model that takes an image and a piece of text as inputs and produces a text output.

Metric learning

A schema of a model that takes two face images as inputs and predicts whether they show the same person.

Self-supervised learning

A schema of a model that takes two augmented versions of the same image as inputs and learns that they are the same.

Omniglot dataset

A sample of images from the Omniglot dataset.

1 Lake, B. M., Salakhutdinov, R., and Tenenbaum, J. B. (2015). Human-level concept learning through probabilistic program induction. Science, 350(6266), 1332-1338.

Character classification

Model schema: a character image is passed to a neural network.

Character classification

Model schema: a one-hot alphabet vector is passed to a neural network.

Character classification

Model schema: character and alphabet embeddings are combined.

Character classification

Model schema: from combined embeddings, a classifier predicts the character.

Two-input Dataset

from PIL import Image
from torch.utils.data import Dataset

class OmniglotDataset(Dataset):
    def __init__(self, transform, samples):
        # Store the image transform and the list of samples
        self.transform = transform
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, alphabet, label = self.samples[idx]
        # Load the image as grayscale ('L' mode), then transform it
        img = Image.open(img_path).convert('L')
        img = self.transform(img)
        return img, alphabet, label
  • Assign samples and transforms

    print(samples[0])

    (
      'omniglot_train/.../0459_14.png',
      array([1., 0., 0., ..., 0., 0., 0.]),
      0
    )
    
  • Implement __len__()

  • Load and transform image

  • Return both inputs and label
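
Each sample pairs an image path with a one-hot alphabet vector and an integer label. The slides do not show how that vector is built, so the following is only a minimal sketch in plain Python. The helper `one_hot`, the placeholder path, and the constant `NUM_ALPHABETS` are hypothetical names introduced for illustration; the value 30 is taken from the `nn.Linear(30, 8)` layer used in the architecture.

```python
# Hypothetical helper, not part of the course code: build a one-hot vector
# marking which alphabet a character belongs to.
def one_hot(index, num_classes):
    vector = [0.0] * num_classes
    vector[index] = 1.0
    return vector

NUM_ALPHABETS = 30  # assumed from the nn.Linear(30, 8) input size

# Placeholder path and label, for illustration only
sample = ("path/to/character.png", one_hot(0, NUM_ALPHABETS), 0)

# Same three-way unpacking as in OmniglotDataset.__getitem__
img_path, alphabet, label = sample
print(len(alphabet), sum(alphabet), alphabet.index(1.0))  # 30 1.0 0
```

In the real dataset the vector is a NumPy array rather than a list, but the structure of the tuple is the same.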

Tensor concatenation

x = torch.tensor([
  [1, 2, 3],
])

y = torch.tensor([
  [4, 5, 6],
])

Concatenation along axis 0

torch.cat((x, y), dim=0)
[[1, 2, 3],
 [4, 5, 6]]

Concatenation along axis 1

torch.cat((x, y), dim=1)
[[1, 2, 3, 4, 5, 6]]
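
Checking shapes is a quick way to see what each axis does. The sketch below repeats the two concatenations and then applies `dim=1` to a batch of embeddings. The batch size of 4 is an arbitrary assumption; the widths 128 and 8 mirror the embedding sizes used in the two-input architecture.

```python
import torch

x = torch.tensor([[1, 2, 3]])  # shape (1, 3)
y = torch.tensor([[4, 5, 6]])  # shape (1, 3)

rows = torch.cat((x, y), dim=0)  # stacks rows: shape (2, 3)
cols = torch.cat((x, y), dim=1)  # extends columns: shape (1, 6)
print(rows.shape, cols.shape)

# Batched case (assumed batch size 4): one 128-dim and one 8-dim
# embedding per sample are joined into a single 136-dim vector.
a = torch.zeros(4, 128)
b = torch.zeros(4, 8)
combined = torch.cat((a, b), dim=1)
print(combined.shape)  # torch.Size([4, 136])
```

Concatenating along `dim=1` keeps the batch dimension intact, which is why it is the axis used to combine per-sample embeddings.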

Two-input architecture

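import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Image branch: conv -> pool -> ELU -> flatten -> 128-dim embedding
        self.image_layer = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=2),
            nn.ELU(),
            nn.Flatten(),
            nn.Linear(16*32*32, 128)
        )
        # Alphabet branch: 30-dim one-hot vector -> 8-dim embedding
        self.alphabet_layer = nn.Sequential(
            nn.Linear(30, 8),
            nn.ELU(),
        )
        # Classifier over the concatenated embeddings (964 output classes)
        self.classifier = nn.Sequential(
            nn.Linear(128 + 8, 964),
        )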
  • Define image processing layer
  • Define alphabet processing layer
  • Define classifier layer
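
The `16*32*32` input size of the image branch's linear layer follows from the shapes flowing through the convolution and pooling layers. The worked arithmetic below assumes 64x64 input images; the slides do not state the image size, and 64 is inferred here because it is the size that yields 32x32 feature maps after pooling. The helper names are hypothetical.

```python
def conv2d_out(size, kernel_size, padding=0, stride=1):
    # Standard output-size formula for a convolution (dilation 1)
    return (size + 2 * padding - kernel_size) // stride + 1

def maxpool2d_out(size, kernel_size):
    # nn.MaxPool2d defaults to stride = kernel_size and no padding
    return (size - kernel_size) // kernel_size + 1

side = 64  # assumed input height and width; not stated in the slides
side = conv2d_out(side, kernel_size=3, padding=1)  # 64: padding=1 keeps the size
side = maxpool2d_out(side, kernel_size=2)          # 32: pooling halves each side

channels = 16  # output channels of nn.Conv2d(1, 16, ...)
flattened = channels * side * side
print(flattened)  # 16384, i.e. 16*32*32
```

If the images were resized to a different resolution, the first argument of that `nn.Linear` would need to change accordingly.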

Two-input architecture

    def forward(self, x_image, x_alphabet):
        x_image = self.image_layer(x_image)
        x_alphabet = self.alphabet_layer(x_alphabet)
        # Concatenate the two embeddings along the feature dimension
        x = torch.cat((x_image, x_alphabet), dim=1)
        return self.classifier(x)
  • Pass image through image layer
  • Pass alphabet through alphabet layer
  • Concatenate image and alphabet outputs
  • Pass the result through classifier
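
A forward pass on dummy tensors is a cheap sanity check that the two branches fit together. The sketch below restates the network from the slides so that it runs on its own. The batch size of 2 and the 64x64 image size are assumptions, chosen to match the `16*32*32` flattened size.

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.image_layer = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=2),
            nn.ELU(),
            nn.Flatten(),
            nn.Linear(16 * 32 * 32, 128),
        )
        self.alphabet_layer = nn.Sequential(nn.Linear(30, 8), nn.ELU())
        self.classifier = nn.Sequential(nn.Linear(128 + 8, 964))

    def forward(self, x_image, x_alphabet):
        x_image = self.image_layer(x_image)
        x_alphabet = self.alphabet_layer(x_alphabet)
        x = torch.cat((x_image, x_alphabet), dim=1)
        return self.classifier(x)

net = Net()

# Dummy inputs with assumed shapes: 2 grayscale 64x64 images
x_image = torch.rand(2, 1, 64, 64)
# One-hot alphabet vectors: both samples come from alphabet 0
x_alphabet = torch.zeros(2, 30)
x_alphabet[:, 0] = 1.0

with torch.no_grad():
    outputs = net(x_image, x_alphabet)

print(outputs.shape)  # torch.Size([2, 964])
```

Each row of `outputs` holds one logit per character class, which is the form `nn.CrossEntropyLoss` expects.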

Training loop

import torch.nn as nn
import torch.optim as optim

net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01)

for epoch in range(10):
    for img, alpha, labels in dataloader_train:
        optimizer.zero_grad()
        outputs = net(img, alpha)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
  • Training data consists of three items:
    • Image
    • Alphabet vector
    • Labels
  • The model receives both the images and the alphabet vectors
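
The slides do not show how `dataloader_train` is built. As a rough sketch, the example below uses a `TensorDataset` of random tensors as a stand-in for `OmniglotDataset` to show that a `DataLoader` over three-item samples yields three-item batches. All shapes, the batch size, and the number of samples are assumptions.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-ins for real data (assumed shapes, random values)
images = torch.rand(8, 1, 64, 64)
alphabets = torch.zeros(8, 30)
alphabets[:, 0] = 1.0                  # one-hot: every sample in alphabet 0
labels = torch.randint(0, 964, (8,))   # integer class ids

dataset = TensorDataset(images, alphabets, labels)
dataloader_train = DataLoader(dataset, batch_size=4, shuffle=True)

# Each batch unpacks into the same three items as in the training loop
for img, alpha, lab in dataloader_train:
    print(img.shape, alpha.shape, lab.shape)
```

With the real dataset, `OmniglotDataset(transform, samples)` would take the place of `TensorDataset`, and the loop body would be the training step from the slide.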

Let's practice!
