Encoder-decoder transformers

Transformer Models with PyTorch

James Chapman

Curriculum Manager, DataCamp

Encoder meets decoder

Original transformer architecture

Encoder meets decoder

Encoder and decoder put together

Cross-attention mechanism

 

  1. Information processed throughout decoder
  2. Final hidden states from encoder block

 

Cross attention example

Decoder with cross-attention
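
The two inputs listed above can be made concrete with a minimal sketch. It is not the course's MultiHeadAttention: it uses a single head and no learned projections, and the function name cross_attention and the tensor sizes are assumptions chosen only for illustration. Queries come from the decoder states, keys and values from the encoder output.

```python
import math
import torch

def cross_attention(x, y):
    """Single-head cross-attention without learned projections (illustrative only).

    x: decoder states, shape (batch, tgt_len, d_model) -> queries
    y: encoder output, shape (batch, src_len, d_model) -> keys and values
    """
    d_model = x.size(-1)
    # Similarity of every decoder position to every encoder position
    scores = torch.matmul(x, y.transpose(-2, -1)) / math.sqrt(d_model)
    # Normalize over the encoder (source) positions
    weights = torch.softmax(scores, dim=-1)
    # Each decoder position receives a weighted mix of encoder states
    return torch.matmul(weights, y), weights

torch.manual_seed(0)
decoder_states = torch.randn(2, 5, 16)   # hypothetical sizes: batch=2, tgt_len=5
encoder_output = torch.randn(2, 7, 16)   # hypothetical sizes: batch=2, src_len=7
attended, weights = cross_attention(decoder_states, encoder_output)
print(attended.shape, weights.shape)
```

The output keeps the decoder's sequence length, while the attention weights have one column per encoder position, which is why the source and target sequences may differ in length.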

Modifying the DecoderLayer

 

  1. Information processed throughout decoder
  2. Final hidden states from encoder block

 

  • x: decoder information flow, becomes the cross-attention queries
  • y: encoder output, becomes the cross-attention keys and values
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(
                          d_model, num_heads)
        self.cross_attn = MultiHeadAttention(
                          d_model, num_heads)
        ...


    def forward(self, x, y, tgt_mask, cross_mask):
        self_attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(self_attn_output))
        cross_attn_output = self.cross_attn(x, y, y, cross_mask)
        x = self.norm2(x + self.dropout(cross_attn_output))
        ...
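
A runnable approximation of the DecoderLayer may help when experimenting. The sketch below makes several assumptions that are not on the slide: PyTorch's built-in nn.MultiheadAttention stands in for the course's MultiHeadAttention, the elided parts are filled with a plain two-layer feed-forward block and a third LayerNorm, and the class name SketchDecoderLayer is hypothetical. Note that nn.MultiheadAttention expects a boolean attn_mask in which True marks positions that must not be attended to, which may differ from the course's mask convention.

```python
import torch
import torch.nn as nn

class SketchDecoderLayer(nn.Module):
    """Illustrative decoder layer; built-in attention replaces the course class."""

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        # Assumed feed-forward sub-layer (the slide elides this part)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, y, tgt_mask=None, cross_mask=None):
        # Masked self-attention over the decoder sequence
        self_attn_output, _ = self.self_attn(x, x, x, attn_mask=tgt_mask)
        x = self.norm1(x + self.dropout(self_attn_output))
        # Cross-attention: queries from x, keys and values from encoder output y
        cross_attn_output, _ = self.cross_attn(x, y, y, attn_mask=cross_mask)
        x = self.norm2(x + self.dropout(cross_attn_output))
        return self.norm3(x + self.dropout(self.ff(x)))

torch.manual_seed(0)
layer = SketchDecoderLayer(d_model=16, num_heads=4, d_ff=32, dropout=0.0)
x = torch.randn(2, 5, 16)   # decoder input (hypothetical sizes)
y = torch.randn(2, 7, 16)   # encoder output (hypothetical sizes)
# Causal mask: True above the diagonal blocks attention to future positions
tgt_mask = torch.triu(torch.ones(5, 5, dtype=torch.bool), diagonal=1)
out = layer(x, y, tgt_mask=tgt_mask)

# Changing only the last decoder position should leave position 0 unchanged
x_changed = x.clone()
x_changed[:, -1] += 1.0
out_changed = layer(x_changed, y, tgt_mask=tgt_mask)
```

With the causal mask in place, the first output position depends only on the first decoder token and the encoder output, so it should be unaffected by edits to later tokens.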

Modifying DecoderTransformer

 

Decoder-only
class TransformerDecoder(nn.Module):
    ...
    def forward(self, x, tgt_mask):
        x = self.embedding(x)
        x = self.positional_encoding(x)
        for layer in self.layers:
            x = layer(x, tgt_mask)
        x = self.fc(x)
        return F.log_softmax(x, dim=-1)

 

Encoder-decoder
class TransformerDecoder(nn.Module):
    ...

    def forward(self, x, y, tgt_mask, cross_mask):
        x = self.embedding(x)
        x = self.positional_encoding(x)
        for layer in self.layers:
            x = layer(x, y, tgt_mask, cross_mask)
        x = self.fc(x)
        return F.log_softmax(x, dim=-1)

Encoder meets decoder

Encoder and decoder put together

Transformer head

 

Outputs example for translation

  • jugar (to play): 0.03
  • viajar (to travel): 0.96
  • dormir (to sleep): 0.01

For other tasks, different activations may be required

Decoder with transformer head
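
The translation example above can be sketched as a linear layer followed by log-softmax, mirroring the self.fc and F.log_softmax calls in the decoder. The three-word vocabulary is taken from the slide; the layer sizes, the random hidden states, and the variable names are assumptions, and the untrained weights mean the predicted words are arbitrary.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab = ["jugar", "viajar", "dormir"]    # toy vocabulary from the slide
d_model = 8                              # hypothetical model width

head = nn.Linear(d_model, len(vocab))    # plays the role of self.fc
hidden = torch.randn(1, 4, d_model)      # stand-in decoder output: 1 sequence, 4 positions

log_probs = F.log_softmax(head(hidden), dim=-1)
probs = log_probs.exp()                  # back to probabilities that sum to 1 per position

# Pick the most likely word at each position
predicted = [vocab[i] for i in probs[0].argmax(dim=-1).tolist()]
print(probs[0])
print(predicted)
```

Returning log-probabilities pairs naturally with nn.NLLLoss during training; exponentiating is only needed when readable probabilities like those on the slide are wanted.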

Everything brought together!

Overall encoder-decoder transformer

Everything brought together!

class InputEmbeddings(nn.Module):
  ...  
class PositionalEncoding(nn.Module):
  ...  
class MultiHeadAttention(nn.Module):
  ...
class FeedForwardSubLayer(nn.Module):
  ...  
class EncoderLayer(nn.Module):
  ...
class DecoderLayer(nn.Module):
  ...
class TransformerEncoder(nn.Module):
  ...
class TransformerDecoder(nn.Module):
  ...
class ClassificationHead(nn.Module):
  ...
class Transformer(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, 
                 num_layers, d_ff, max_seq_len, dropout):
        super().__init__()


        self.encoder = TransformerEncoder(vocab_size, d_model, num_heads,
                                          num_layers, d_ff, dropout, max_seq_len)
        self.decoder = TransformerDecoder(vocab_size, d_model, num_heads,
                                          num_layers, d_ff, dropout, max_seq_len)

    def forward(self, x, src_mask, tgt_mask, cross_mask):
        encoder_output = self.encoder(x, src_mask)
        # The slide passes the same token sequence x to both stacks; in a
        # translation setup the decoder would normally receive the target sequence.
        decoder_output = self.decoder(x, encoder_output, tgt_mask, cross_mask)
        return decoder_output
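
Because the component classes are elided here, the end-to-end data flow can be tried out with PyTorch's built-in nn.Transformer instead. This is a stand-in under stated assumptions, not the course's Transformer class: positional encoding is omitted for brevity, a single shared nn.Embedding is used for source and target, all sizes are hypothetical, and separate src and tgt token tensors are used rather than a single x.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical hyperparameters, named after the slide's constructor arguments
vocab_size, d_model, num_heads, num_layers, d_ff = 100, 32, 4, 2, 64

embedding = nn.Embedding(vocab_size, d_model)
model = nn.Transformer(
    d_model=d_model, nhead=num_heads,
    num_encoder_layers=num_layers, num_decoder_layers=num_layers,
    dim_feedforward=d_ff, dropout=0.0, batch_first=True,
)
head = nn.Linear(d_model, vocab_size)     # transformer head over the vocabulary

src = torch.randint(0, vocab_size, (2, 7))   # source token ids
tgt = torch.randint(0, vocab_size, (2, 5))   # target token ids

# Causal mask for the decoder: True blocks attention to future positions
tgt_len = tgt.size(1)
tgt_mask = torch.triu(torch.ones(tgt_len, tgt_len, dtype=torch.bool), diagonal=1)

# Encoder output is computed internally and fed to the decoder's cross-attention
decoder_output = model(embedding(src), embedding(tgt), tgt_mask=tgt_mask)
log_probs = torch.log_softmax(head(decoder_output), dim=-1)
print(log_probs.shape)
```

The result has one distribution over the vocabulary for each target position, matching what the decoder's final log_softmax returns in the slides.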

Let's practice!
