Attention mechanisms for text generation

Deep Learning for Text with PyTorch

Shubham Jain

Instructor

The ambiguity in text processing

  • "The monkey ate that banana because it was too hungry"

  • What does the word "it" refer to?

Figure: human vs. machine interpretation


Attention mechanisms

  • Assigns importance to the words in a sequence (a minimal sketch follows the chart below)
  • Ensures that the machine's interpretation aligns with human understanding

Figure: attention chart¹

1 Xie, Huiqiang & Qin, Zhijin & Li, Geoffrey & Juang, Biing-Hwang. (2020). Deep Learning Enabled Semantic Communication Systems
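
A minimal sketch of the idea, assuming random vectors stand in for learned hidden states: each word's importance is the softmax-normalized similarity between its vector and the vector of the word being interpreted ("it"). The names (tokens, hidden, query) are illustrative, not from the slides.

import torch
import torch.nn.functional as F

# Illustrative only: random vectors stand in for learned hidden states
torch.manual_seed(0)
tokens = "the monkey ate that banana because it was too hungry".split()
hidden = torch.randn(len(tokens), 16)      # one 16-dim vector per word

query = hidden[tokens.index("it")]         # vector for the word "it"
scores = hidden @ query                    # similarity of every word to "it"
weights = F.softmax(scores, dim=0)         # importance assigned to each word

for token, weight in zip(tokens, weights):
    print(f"{token:>8}: {weight.item():.2f}")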

Self and multi-head attention

  • Self-Attention: assigns significance to words within a sentence

    • "The cat, which was on the roof, was scared"
    • Linking "was scared" to "The cat"
  • Multi-Head Attention: like having multiple spotlights, each capturing different facets (see the sketch after this list)

    • Understanding "was scared" can relate to
    • "The cat", "the roof", or "was on"

Attention mechanism - setting vocabulary and data

data = ["the cat sat on the mat", ...]

vocab = set(' '.join(data).split())
word_to_ix = {word: i for i, word in enumerate(vocab)} ix_to_word = {i: word for word, i in word_to_ix.items()}
pairs = [sentence.split() for sentence in data] input_data = [[word_to_ix[word] for word in sentence[:-1]] for sentence in pairs] target_data = [word_to_ix[sentence[-1]] for sentence in pairs] inputs = [torch.tensor(seq, dtype=torch.long) for seq in input_data] targets = torch.tensor(target_data, dtype=torch.long)
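
For instance (illustrative output; the exact indices depend on Python's set ordering), the prepared data can be inspected like this:

print(word_to_ix)   # e.g. {'the': 0, 'cat': 1, ...}
print(inputs[0])    # indices for "the cat sat on the"
print(targets[0])   # index of "mat"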

Model definition

import torch.nn as nn

embedding_dim = 10
hidden_dim = 16
vocab_size = len(word_to_ix)  # assumed from the data setup; not shown on the slide

class RNNWithAttentionModel(nn.Module):
    def __init__(self):
        super(RNNWithAttentionModel, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.RNN(embedding_dim, hidden_dim, batch_first=True)
        self.attention = nn.Linear(hidden_dim, 1)    # one attention score per time step
        self.fc = nn.Linear(hidden_dim, vocab_size)  # context vector -> next-word scores

Forward propagation with attention

    def forward(self, x):
        x = self.embeddings(x)   # (batch, seq_len, embedding_dim)
        out, _ = self.rnn(x)     # (batch, seq_len, hidden_dim)
        # One attention score per time step, normalized across the sequence
        attn_weights = torch.nn.functional.softmax(self.attention(out).squeeze(2), dim=1)
        # Context vector: attention-weighted sum of the RNN outputs
        context = torch.sum(attn_weights.unsqueeze(2) * out, dim=1)
        out = self.fc(context)
        return out

def pad_sequences(batch):
    # Right-pad with zeros so sequences can be stacked into one tensor
    max_len = max([len(seq) for seq in batch])
    return torch.stack([torch.cat([seq, torch.zeros(max_len - len(seq)).long()]) for seq in batch])
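
A quick sanity check, assuming the data preparation and model definition above have run (model here is a fresh, untrained instance used only for the shape check):

padded = pad_sequences(inputs)
model = RNNWithAttentionModel()
print(padded.shape)          # (num_sentences, max_len)
print(model(padded).shape)   # (num_sentences, vocab_size)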

Training preparation

criterion = nn.CrossEntropyLoss()

attention_model = RNNWithAttentionModel()
optimizer = torch.optim.Adam(attention_model.parameters(), lr=0.01)

for epoch in range(300):
    attention_model.train()
    optimizer.zero_grad()
    padded_inputs = pad_sequences(inputs)
    outputs = attention_model(padded_inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
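    # Optional (not on the original slide): report the loss every 50 epochs to monitor training
    if (epoch + 1) % 50 == 0:
        print(f"Epoch {epoch + 1}, loss: {loss.item():.4f}")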

Model evaluation

for input_seq, target in zip(input_data, target_data):
    input_test = torch.tensor(input_seq, dtype=torch.long).unsqueeze(0)

    attention_model.eval()
    attention_output = attention_model(input_test)
    attention_prediction = ix_to_word[torch.argmax(attention_output).item()]

    print(f"\nInput: {' '.join([ix_to_word[ix] for ix in input_seq])}")
    print(f"Target: {ix_to_word[target]}")
    print(f"RNN with Attention prediction: {attention_prediction}")
Input: the cat sat on the
Target: mat
RNN with Attention prediction: mat

Let's practice!

