Multi-Input-Modelle

Deep Learning mit PyTorch für Fortgeschrittene

Michal Oleszak

Machine Learning Engineer

Warum Multi-Input?

Mehr Informationen nutzen

Schema eines Modells, das zwei Autobilder als Eingaben erhält und eine Ausgabe liefert.

Multimodale Modelle

Schema eines Modells, das ein Bild und einen Text als Eingaben erhält und einen Text ausgibt.

Metric Learning

Schema eines Modells, das zwei Gesichter als Eingaben erhält und vorhersagt, ob sie gleich sind.

Self-Supervised Learning

Schema eines Modells, das zwei augmentierte Versionen desselben Bildes als Eingaben erhält und lernt, dass sie gleich sind.

Deep Learning mit PyTorch für Fortgeschrittene

Omniglot-Datensatz

Eine Auswahl von Bildern aus dem Omniglot-Datensatz.

1 Lake, B. M., Salakhutdinov, R., and Tenenbaum, J. B. (2015). Human-level concept learning through probabilistic program induction. Science, 350(6266), 1332-1338.
Deep Learning mit PyTorch für Fortgeschrittene

Zeichenklassifikation

Modellschema: Zeichenbilder werden an ein neuronales Netz übergeben.

Deep Learning mit PyTorch für Fortgeschrittene

Zeichenklassifikation

Modellschema: Ein One-Hot-Alphabetvektor wird an ein neuronales Netz übergeben.

Deep Learning mit PyTorch für Fortgeschrittene

Zeichenklassifikation

Modellschema: Zeichen- und Alphabet-Embeddings werden kombiniert.

Deep Learning mit PyTorch für Fortgeschrittene

Zeichenklassifikation

Modellschema: Aus den kombinierten Embeddings sagt ein Klassifikator das Zeichen voraus.

Deep Learning mit PyTorch für Fortgeschrittene

Datensatz mit zwei Inputs

from PIL import Image

class OmniglotDataset(Dataset):

def __init__(self, transform, samples): self.transform = transform self.samples = samples
def __len__(self): return len(self.samples)
def __getitem__(self, idx): img_path, alphabet, label = self.samples[idx] img = Image.open(img_path).convert('L') img = self.transform(img) return img, alphabet, label
  • Samples und Transforms zuweisen

    print(samples[0])
    
    [(
      'omniglot_train/.../0459_14.png',
       array([1., 0., 0., ..., 0., 0., 0.]),
       0
     )]
    
  • __len__() implementieren

  • Bild laden und transformieren

  • Beide Inputs und Label zurückgeben
Deep Learning mit PyTorch für Fortgeschrittene

Tensor-Konkatenation

x = torch.tensor([
  [1, 2, 3],
])

y = torch.tensor([
  [4, 5, 6],
])

Konkatenation entlang Achse 0

torch.cat((x, y), dim=0)
[[1, 2, 3],
 [4, 5, 6]]

Konkatenation entlang Achse 1

torch.cat((x, y), dim=1)
[[1, 2, 3, 4, 5, 6]]
Deep Learning mit PyTorch für Fortgeschrittene

Zwei-Input-Architektur

class Net(nn.Module):
    def __init__(self):
        super().__init__()

self.image_layer = nn.Sequential( nn.Conv2d(1, 16, kernel_size=3, padding=1), nn.MaxPool2d(kernel_size=2), nn.ELU(), nn.Flatten(), nn.Linear(16*32*32, 128) )
self.alphabet_layer = nn.Sequential( nn.Linear(30, 8), nn.ELU(), )
self.classifier = nn.Sequential( nn.Linear(128 + 8, 964), )
  • Bildebene definieren
  • Alphabeteingabe-Ebene definieren
  • Klassifikator definieren
Deep Learning mit PyTorch für Fortgeschrittene

Zwei-Input-Architektur

def forward(self, x_image, x_alphabet):

x_image = self.image_layer(x_image)
x_alphabet = self.alphabet_layer(x_alphabet)
x = torch.cat((x_image, x_alphabet), dim=1)
return self.classifier(x)
  • Bild durch Bildebene schicken
  • Alphabet durch Alphabeteingabe-Ebene schicken
  • Bild- und Alphabet-Outputs konkatenieren
  • Ergebnis durch Klassifikator schicken
Deep Learning mit PyTorch für Fortgeschrittene

Trainingsschleife

net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01)

for epoch in range(10):
    for img, alpha, labels in dataloader_train:
        optimizer.zero_grad()
        outputs = net(img, alpha)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
  • Trainingsdaten bestehen aus drei Elementen:
    • Bild
    • Alphabetvektor
    • Labels
  • Wir geben dem Modell Bilder und Alphabete
Deep Learning mit PyTorch für Fortgeschrittene

Lass uns üben!

Deep Learning mit PyTorch für Fortgeschrittene

Preparing Video For Download...