Deep Learning for Text with PyTorch
Shubham Jain
Data Scientist
Why?
Example: Detecting sarcasm in a tweet
"I just love getting stuck in traffic."
# Import libraries
from torch.utils.data import Dataset, DataLoader


# Wrap the raw texts in a Dataset so a DataLoader can batch them.
class TextDataset(Dataset):
    """Map-style dataset that serves raw text samples by index.

    Args:
        text: An indexable, sized collection of samples (e.g. a list of
            strings). It is stored as-is, without copying or preprocessing.
    """

    def __init__(self, text):
        self.text = text

    def __len__(self):
        """Return the number of samples."""
        return len(self.text)

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        return self.text[idx]
# Example inference: run a trained model on a single sample text.
sample_tweet = "This movie had a great plot and amazing acting."
# Preprocess the text and convert it to a tensor (not shown for brevity).
# NOTE(review): `model` and `sample_tweet_tensor` are not defined in this
# snippet; they are assumed to come from the elided preprocessing/training code.
# ...
sentiment_prediction = model(sample_tweet_tensor)
Tweet: "Loved the cinematography, hated the dialogue. The acting was exceptional, but the plot fell flat."
LSTM architecture: Input gate, forget gate, and output gate
class LSTMModel(nn.Module):
    """Sequence classifier: an LSTM encoder followed by a linear head.

    The final hidden state of the last LSTM layer summarizes the sequence
    and is projected to ``output_size`` scores.

    Args:
        input_size: Number of features per time step.
        hidden_size: Size of the LSTM hidden state.
        output_size: Number of output scores produced by the linear head.
        num_layers: Number of stacked LSTM layers (default 1, the original
            behavior).
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
        super().__init__()
        # batch_first=True: batched inputs are (batch, seq_len, input_size).
        self.lstm = nn.LSTM(
            input_size, hidden_size, num_layers=num_layers, batch_first=True
        )
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return output scores for ``x``.

        Args:
            x: Tensor of shape (batch, seq_len, input_size), or
                (seq_len, input_size) for a single unbatched sequence.

        Returns:
            Tensor of shape (batch, output_size), or (output_size,) when
            ``x`` is unbatched.
        """
        # hidden is (num_layers, batch, hidden_size); the cell state is unused.
        _, (hidden, _) = self.lstm(x)
        # hidden[-1] is the last layer's final hidden state. Unlike
        # squeeze(0), this stays correct when num_layers > 1.
        return self.fc(hidden[-1])
Email subject: "Congratulations! You've won a free trip to Hawaii!"
class GRUModel(nn.Module):
    """Sequence classifier: a GRU encoder followed by a linear head.

    The final hidden state of the last GRU layer summarizes the sequence
    and is projected to ``output_size`` scores.

    Args:
        input_size: Number of features per time step.
        hidden_size: Size of the GRU hidden state.
        output_size: Number of output scores produced by the linear head.
        num_layers: Number of stacked GRU layers (default 1, the original
            behavior).
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
        super().__init__()
        # batch_first=True: batched inputs are (batch, seq_len, input_size).
        self.gru = nn.GRU(
            input_size, hidden_size, num_layers=num_layers, batch_first=True
        )
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return output scores for ``x``.

        Args:
            x: Tensor of shape (batch, seq_len, input_size), or
                (seq_len, input_size) for a single unbatched sequence.

        Returns:
            Tensor of shape (batch, output_size), or (output_size,) when
            ``x`` is unbatched.
        """
        # hidden is (num_layers, batch, hidden_size).
        _, hidden = self.gru(x)
        # hidden[-1] is the last layer's final hidden state. Unlike
        # squeeze(0), this stays correct when num_layers > 1.
        return self.fc(hidden[-1])
Deep Learning for Text with PyTorch