Transformer Models with PyTorch
James Chapman
Curriculum Manager, DataCamp
# x: decoder information flow, becomes cross-attention query
# y: encoder output, becomes cross-attention key and values
class DecoderLayer(nn.Module):
    """Decoder layer with a self-attention and a cross-attention sub-layer.

    Parts of this listing are elided; they are kept below as ``...``.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        # Elided in this listing. forward() reads self.norm1, self.norm2 and
        # self.dropout, so the missing part has to create them; d_ff and
        # dropout are presumably consumed there as well -- TODO confirm.
        ...

    def forward(self, x, y, tgt_mask, cross_mask):
        """Run self-attention on ``x``, then cross-attention over ``y``.

        Args:
            x: Decoder-side input. Used as query, key and value for
                self-attention and as the query for cross-attention.
            y: Encoder output. Used as key and value for cross-attention.
            tgt_mask: Mask handed to the self-attention call.
            cross_mask: Mask handed to the cross-attention call.
        """
        # Self-attention: x attends to itself under tgt_mask.
        self_attn_output = self.self_attn(x, x, x, tgt_mask)
        # Residual connection around the sub-layer, then normalisation.
        x = self.norm1(x + self.dropout(self_attn_output))

        # Cross-attention: queries come from the decoder stream, keys and
        # values from the encoder output.
        cross_attn_output = self.cross_attn(x, y, y, cross_mask)
        x = self.norm2(x + self.dropout(cross_attn_output))

        # Remainder of the layer is elided in this listing.
        ...
class TransformerDecoder(nn.Module):
    """Decoder stack whose layers receive only the sequence and its mask.

    This variant takes no encoder output: every layer is called as
    ``layer(x, tgt_mask)``.
    """

    # Constructor elided in this listing. forward() expects it to provide
    # self.embedding, self.positional_encoding, self.layers and self.fc.
    ...

    def forward(self, x, tgt_mask):
        """Embed ``x``, run it through every layer and return log-probabilities.

        Args:
            x: Input passed to ``self.embedding``.
            tgt_mask: Mask forwarded unchanged to each layer.

        Returns:
            ``log_softmax`` over the last dimension of the ``self.fc`` output.
        """
        x = self.embedding(x)
        x = self.positional_encoding(x)
        for layer in self.layers:
            x = layer(x, tgt_mask)
        x = self.fc(x)
        return F.log_softmax(x, dim=-1)
class TransformerDecoder(nn.Module):
    """Decoder stack whose layers also receive the encoder output ``y``."""

    # Constructor elided in this listing. forward() expects it to provide
    # self.embedding, self.positional_encoding, self.layers and self.fc.
    ...

    def forward(self, x, y, tgt_mask, cross_mask):
        """Embed ``x``, run every layer against ``y`` and return log-probabilities.

        Args:
            x: Input passed to ``self.embedding``.
            y: Second input handed to every layer (the encoder output when
                called from an encoder-decoder model).
            tgt_mask: Mask forwarded unchanged to each layer.
            cross_mask: Mask forwarded unchanged to each layer.

        Returns:
            ``log_softmax`` over the last dimension of the ``self.fc`` output.
        """
        x = self.embedding(x)
        x = self.positional_encoding(x)
        for layer in self.layers:
            x = layer(x, y, tgt_mask, cross_mask)
        x = self.fc(x)
        return F.log_softmax(x, dim=-1)
For other tasks, different activations may be required
# Outline of the model's components. Every body is elided (``...``) in this
# listing; the implementations are not visible here.
class InputEmbeddings(nn.Module):
    ...


class PositionalEncoding(nn.Module):
    ...


class MultiHeadAttention(nn.Module):
    ...


class FeedForwardSubLayer(nn.Module):
    ...


class EncoderLayer(nn.Module):
    ...


class DecoderLayer(nn.Module):
    ...


class TransformerEncoder(nn.Module):
    ...


class TransformerDecoder(nn.Module):
    ...


class ClassificationHead(nn.Module):
    ...
class Transformer(nn.Module):
    """Encoder-decoder model built from TransformerEncoder and TransformerDecoder."""

    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff,
                 max_seq_len, dropout):
        super().__init__()
        # NOTE(review): the arguments are passed positionally, with dropout
        # before max_seq_len (the reverse of this signature). Confirm the
        # order matches the encoder/decoder constructors, which are not
        # visible in this listing.
        self.encoder = TransformerEncoder(
            vocab_size, d_model, num_heads, num_layers, d_ff, dropout,
            max_seq_len)
        self.decoder = TransformerDecoder(
            vocab_size, d_model, num_heads, num_layers, d_ff, dropout,
            max_seq_len)

    def forward(self, x, src_mask, tgt_mask, cross_mask, tgt=None):
        """Encode ``x`` and decode against the encoder output.

        Args:
            x: Source sequence, passed to the encoder.
            src_mask: Mask passed to the encoder.
            tgt_mask: Mask passed to the decoder.
            cross_mask: Mask passed to the decoder.
            tgt: Optional decoder input. When omitted the decoder is fed
                ``x``, which is what this method did before the parameter
                existed, so existing calls are unaffected.

        Returns:
            Whatever ``self.decoder`` returns.
        """
        encoder_output = self.encoder(x, src_mask)
        # The decoder normally consumes the target sequence rather than the
        # source; fall back to x only when no target is supplied.
        decoder_input = x if tgt is None else tgt
        decoder_output = self.decoder(
            decoder_input, encoder_output, tgt_mask, cross_mask)
        return decoder_output
Transformer Models with PyTorch