Modelos com múltiplas entradas

Aprendizagem profunda intermediária com PyTorch

Michal Oleszak

Machine Learning Engineer

Por que múltiplas entradas?

Usando mais informação

Esquema de um modelo que recebe duas imagens de carros como entrada e produz uma saída.

Modelos multimodais

Esquema de um modelo que recebe uma imagem e um texto como entradas e produz uma saída de texto.

Aprendizado por métrica

Esquema de um modelo que recebe duas imagens de rosto e prevê se são a mesma pessoa.

Aprendizado autossupervisionado

Esquema de um modelo que recebe duas versões aumentadas da mesma imagem e aprende que são iguais.

Aprendizagem profunda intermediária com PyTorch

Dataset Omniglot

Uma amostra de imagens do dataset Omniglot.

1 Lake, B. M., Salakhutdinov, R., and Tenenbaum, J. B. (2015). Human-level concept learning through probabilistic program induction. Science, 350(6266), 1332-1338.
Aprendizagem profunda intermediária com PyTorch

Classificação de caracteres

Esquema do modelo: imagens de caracteres são enviadas a uma rede neural.

Aprendizagem profunda intermediária com PyTorch

Classificação de caracteres

Esquema do modelo: vetor one-hot do alfabeto é enviado a uma rede neural.

Aprendizagem profunda intermediária com PyTorch

Classificação de caracteres

Esquema do modelo: embeddings do caractere e do alfabeto são combinados.

Aprendizagem profunda intermediária com PyTorch

Classificação de caracteres

Esquema do modelo: a partir dos embeddings combinados, um classificador prevê o caractere.

Aprendizagem profunda intermediária com PyTorch

Dataset com duas entradas

from PIL import Image

class OmniglotDataset(Dataset):

def __init__(self, transform, samples): self.transform = transform self.samples = samples
def __len__(self): return len(self.samples)
def __getitem__(self, idx): img_path, alphabet, label = self.samples[idx] img = Image.open(img_path).convert('L') img = self.transform(img) return img, alphabet, label
  • Atribuir samples e transforms

    print(samples[0])
    
    [(
      'omniglot_train/.../0459_14.png',
       array([1., 0., 0., ..., 0., 0., 0.]),
       0
     )]
    
  • Implementar __len__()

  • Carregar e transformar a imagem

  • Retornar as duas entradas e o rótulo
Aprendizagem profunda intermediária com PyTorch

Concatenação de tensores

x = torch.tensor([
  [1, 2, 3],
])

y = torch.tensor([
  [4, 5, 6],
])

Concatenação ao longo do eixo 0

torch.cat((x, y), dim=0)
[[1, 2, 3],
 [4, 5, 6]]

Concatenação ao longo do eixo 1

torch.cat((x, y), dim=1)
[[1, 2, 3, 4, 5, 6]]
Aprendizagem profunda intermediária com PyTorch

Arquitetura com duas entradas

class Net(nn.Module):
    def __init__(self):
        super().__init__()

self.image_layer = nn.Sequential( nn.Conv2d(1, 16, kernel_size=3, padding=1), nn.MaxPool2d(kernel_size=2), nn.ELU(), nn.Flatten(), nn.Linear(16*32*32, 128) )
self.alphabet_layer = nn.Sequential( nn.Linear(30, 8), nn.ELU(), )
self.classifier = nn.Sequential( nn.Linear(128 + 8, 964), )
  • Definir a camada de imagem
  • Definir a camada do alfabeto
  • Definir a camada do classificador
Aprendizagem profunda intermediária com PyTorch

Arquitetura com duas entradas

def forward(self, x_image, x_alphabet):

x_image = self.image_layer(x_image)
x_alphabet = self.alphabet_layer(x_alphabet)
x = torch.cat((x_image, x_alphabet), dim=1)
return self.classifier(x)
  • Passar a imagem pela camada de imagem
  • Passar o alfabeto pela camada do alfabeto
  • Concatenar as saídas de imagem e alfabeto
  • Passar o resultado pelo classificador
Aprendizagem profunda intermediária com PyTorch

Loop de treino

net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01)

for epoch in range(10):
    for img, alpha, labels in dataloader_train:
        optimizer.zero_grad()
        outputs = net(img, alpha)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
  • Dados de treino têm três itens:
    • Imagem
    • Vetor do alfabeto
    • Rótulos
  • Passamos imagens e alfabetos ao modelo
Aprendizagem profunda intermediária com PyTorch

Vamos praticar!

Aprendizagem profunda intermediária com PyTorch

Preparing Video For Download...