Deep Learning for Text with PyTorch
Shubham Jain
Instructor
"The monkey ate that banana because it was too hungry"
What does the word "it" refer to?
Self-Attention: assigns significance to words within a sentence
Multi-Head Attention: like having multiple spotlights, each capturing a different facet of the sentence
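As a rough sketch of the idea (illustrative only, not the course code; the tensor sizes are assumed): self-attention scores every word against every other word, then mixes the words using the softmaxed scores. PyTorch's nn.MultiheadAttention runs several such spotlights in parallel.

import torch
import torch.nn.functional as F

# Six words with 8-dim embeddings (assumed toy sizes)
words = torch.randn(6, 8)

# Score every word against every other word, scale, then normalize
scores = words @ words.T / (8 ** 0.5)   # (6, 6) similarity scores
weights = F.softmax(scores, dim=-1)     # each row sums to 1
contextual = weights @ words            # weighted mix per word

# Multiple "spotlights" in parallel via the built-in multi-head layer
mha = torch.nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
out, attn = mha(words.unsqueeze(0), words.unsqueeze(0), words.unsqueeze(0))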
data = ["the cat sat on the mat", ...]
vocab = set(' '.join(data).split())
word_to_ix = {word: i for i, word in enumerate(vocab)} ix_to_word = {i: word for word, i in word_to_ix.items()}
pairs = [sentence.split() for sentence in data] input_data = [[word_to_ix[word] for word in sentence[:-1]] for sentence in pairs] target_data = [word_to_ix[sentence[-1]] for sentence in pairs] inputs = [torch.tensor(seq, dtype=torch.long) for seq in input_data] targets = torch.tensor(target_data, dtype=torch.long)
embedding_dim = 10
hidden_dim = 16
class RNNWithAttentionModel(nn.Module):
    def __init__(self):
        super(RNNWithAttentionModel, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.RNN(embedding_dim, hidden_dim, batch_first=True)
        # Scores each RNN output step with a single learned weight vector
        self.attention = nn.Linear(hidden_dim, 1)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        x = self.embeddings(x)
        out, _ = self.rnn(x)
        # Normalize the per-step scores into attention weights over the sequence
        attn_weights = torch.nn.functional.softmax(self.attention(out).squeeze(2), dim=1)
        # Weighted sum of RNN outputs -> one context vector per sequence
        context = torch.sum(attn_weights.unsqueeze(2) * out, dim=1)
        out = self.fc(context)
        return out
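A quick shape sanity check, assuming the data setup above (this check is an addition, not part of the slides):

# Hypothetical check: batch of 2 sequences, 5 tokens each
dummy = torch.randint(0, vocab_size, (2, 5))
print(RNNWithAttentionModel()(dummy).shape)  # -> (2, vocab_size): one logit per word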
def pad_sequences(batch):
    max_len = max([len(seq) for seq in batch])
    return torch.stack([torch.cat([seq, torch.zeros(max_len - len(seq)).long()]) for seq in batch])
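For example, padding two sequences of unequal length (illustrative values; note that the pad value 0 doubles as a real word index in this vocabulary):

batch = [torch.tensor([3, 1, 4]), torch.tensor([1, 5])]
print(pad_sequences(batch))
# tensor([[3, 1, 4],
#         [1, 5, 0]])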
criterion = nn.CrossEntropyLoss()
attention_model = RNNWithAttentionModel()
optimizer = torch.optim.Adam(attention_model.parameters(), lr=0.01)
for epoch in range(300):
    attention_model.train()
    optimizer.zero_grad()
    # Train on the full padded batch each epoch
    padded_inputs = pad_sequences(inputs)
    outputs = attention_model(padded_inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
for input_seq, target in zip(input_data, target_data):
    input_test = torch.tensor(input_seq, dtype=torch.long).unsqueeze(0)
    attention_model.eval()
    attention_output = attention_model(input_test)
    attention_prediction = ix_to_word[torch.argmax(attention_output).item()]
    print(f"\nInput: {' '.join([ix_to_word[ix] for ix in input_seq])}")
    print(f"Target: {ix_to_word[target]}")
    print(f"RNN with Attention prediction: {attention_prediction}")
Input: the cat sat on the
Target: mat
RNN with Attention prediction: mat