Células LSTM e GRU

Aprendizagem profunda intermediária com PyTorch

Michal Oleszak

Machine Learning Engineer

Problema de memória de curto prazo

  • Células RNN mantêm memória via estado oculto
  • Essa memória é de curtíssimo prazo
  • Duas células mais fortes resolvem isso:
    • LSTM (Long Short-Term Memory)
    • GRU (Gated Recurrent Unit)

Esquema do neurônio recorrente. No passo 2, recebe h2 e x2 e produz y2 e h3.

Aprendizagem profunda intermediária com PyTorch

Célula RNN

Esquema da célula RNN.

  • Duas entradas:
    • dado atual x
    • estado oculto anterior h
  • Duas saídas:
    • saída atual y
    • próximo estado oculto h
Aprendizagem profunda intermediária com PyTorch

Célula LSTM

Esquema da célula LSTM.

  • Saídas h e y são iguais
  • Três entradas e saídas (dois estados ocultos):

    • h: curto prazo
    • c: longo prazo
  • Três “portas”:

    • Porta de esquecimento: o que remover da memória longa
    • Porta de entrada: o que salvar na memória longa
    • Porta de saída: o que retornar no passo atual
Aprendizagem profunda intermediária com PyTorch

LSTM no PyTorch

class Net(nn.Module):
    def __init__(self, input_size):
        super().__init__()

self.lstm = nn.LSTM( input_size=1, hidden_size=32, num_layers=2, batch_first=True, ) self.fc = nn.Linear(32, 1)
def forward(self, x): h0 = torch.zeros(2, x.size(0), 32) c0 = torch.zeros(2, x.size(0), 32)
out, _ = self.lstm(x, (h0, c0))
out = self.fc(out[:, -1, :]) return out
  • __init__():
    • Troca nn.RNN por nn.LSTM
  • forward():
    • Adiciona outro estado oculto c
    • Inicializa c e h com zeros
    • Passa ambos os estados para a camada lstm
Aprendizagem profunda intermediária com PyTorch

Célula GRU

Esquema da célula GRU.

  • Versão simplificada da LSTM
  • Só um estado oculto
  • Sem porta de saída
Aprendizagem profunda intermediária com PyTorch

GRU no PyTorch

class Net(nn.Module):
    def __init__(self, input_size):
        super().__init__()

self.gru = nn.GRU( input_size=1, hidden_size=32, num_layers=2, batch_first=True, ) self.fc = nn.Linear(32, 1)
def forward(self, x): h0 = torch.zeros(2, x.size(0), 32) out, _ = self.gru(x, h0) out = self.fc(out[:, -1, :]) return out
  • __init__():
    • Troca nn.RNN por nn.GRU
  • forward():
    • Usa a camada gru
Aprendizagem profunda intermediária com PyTorch

Uso RNN, LSTM ou GRU?

  • RNN quase não é mais usada
  • GRU é mais simples que LSTM = menos computação
  • Desempenho relativo varia por caso de uso
  • Testa as duas e compara

Esquemas das células LSTM e GRU.

Aprendizagem profunda intermediária com PyTorch

Vamos praticar!

Aprendizagem profunda intermediária com PyTorch

Preparing Video For Download...