Attentionmechanismen voor tekstopwekking

Deep Learning voor tekst met PyTorch

Shubham Jain

Instructor

Ambiguïteit in tekstverwerking

  • "The monkey ate that banana because it was too hungry"

  • Waar verwijst het woord "it" naar?

Mens vs machine

Deep Learning voor tekst met PyTorch

Attentionmechanismen

  • Wijs belangrijkheid toe aan woorden
  • Zorgt dat de interpretatie van de machine aansluit bij menselijk begrip

Attention-grafiek

1 Xie, Huiqiang & Qin, Zhijin & Li, Geoffrey & Juang, Biing-Hwang. (2020). Deep Learning Enabled Semantic Communication Systems
Deep Learning voor tekst met PyTorch

Self- en multi-head attention

  • Self-attention: kent gewicht toe aan woorden in een zin

    • "The cat, which was on the roof, was scared"
    • Verbindt "was scared" met "The cat"
  • Multi-head attention: meerdere spotlights voor verschillende aspecten

    • Begrijpen dat "was scared" kan verwijzen naar
    • "The cat", "the roof" of "was on"
Deep Learning voor tekst met PyTorch

Attentionmechanisme - vocabulaire en data instellen

data = ["the cat sat on the mat", ...]

vocab = set(' '.join(data).split())
word_to_ix = {word: i for i, word in enumerate(vocab)} ix_to_word = {i: word for word, i in word_to_ix.items()}
pairs = [sentence.split() for sentence in data] input_data = [[word_to_ix[word] for word in sentence[:-1]] for sentence in pairs] target_data = [word_to_ix[sentence[-1]] for sentence in pairs] inputs = [torch.tensor(seq, dtype=torch.long) for seq in input_data] targets = torch.tensor(target_data, dtype=torch.long)
Deep Learning voor tekst met PyTorch

Modeldefinitie

embedding_dim = 10
hidden_dim = 16

class RNNWithAttentionModel(nn.Module): def __init__(self): super(RNNWithAttentionModel, self).__init__()
self.embeddings = nn.Embedding(vocab_size, embedding_dim) self.rnn = nn.RNN(embedding_dim, hidden_dim, batch_first=True)
self.attention = nn.Linear(hidden_dim, 1)
self.fc = nn.Linear(hidden_dim, vocab_size)
Deep Learning voor tekst met PyTorch

Voorwaartse propagatie met attention

def forward(self, x):
    x = self.embeddings(x)
    out, _ = self.rnn(x)

attn_weights = torch.nn.functional.softmax(self.attention(out).squeeze(2), dim=1)
context = torch.sum(attn_weights.unsqueeze(2) * out, dim=1) out = self.fc(context) return out
def pad_sequences(batch): max_len = max([len(seq) for seq in batch]) return torch.stack([torch.cat([seq, torch.zeros(max_len-len(seq)).long()]) for seq in batch])
Deep Learning voor tekst met PyTorch

Training voorbereiden

criterion = nn.CrossEntropyLoss()

attention_model = RNNWithAttentionModel() optimizer = torch.optim.Adam(attention_model.parameters(), lr=0.01)
for epoch in range(300): attention_model.train() optimizer.zero_grad()
padded_inputs = pad_sequences(inputs) outputs = attention_model(padded_inputs)
loss = criterion(outputs, targets) loss.backward() optimizer.step()
Deep Learning voor tekst met PyTorch

Modelevaluatie

for input_seq, target in zip(input_data, target_data):
    input_test = torch.tensor(input_seq, dtype=torch.long).unsqueeze(0)

attention_model.eval() attention_output = attention_model(input_test)
attention_prediction = ix_to_word[torch.argmax(attention_output).item()]
print(f"\nInput: {' '.join([ix_to_word[ix] for ix in input_seq])}") print(f"Target: {ix_to_word[target]}") print(f"RNN with Attention prediction: {attention_prediction}")
Input: the cat sat on the
Target: mat
RNN with Attention prediction: mat
Deep Learning voor tekst met PyTorch

Laten we oefenen!

Deep Learning voor tekst met PyTorch

Preparing Video For Download...