LSTM and GRU cells

Intermediate Deep Learning with PyTorch

Michal Oleszak

Machine Learning Engineer

Short-term memory problem

  • RNN cells maintain memory via hidden state
  • This memory is very short-term
  • Two more powerful cells solve the problem:
    • LSTM (Long Short-Term Memory) cell
    • GRU (Gated Recurrent Unit) cell

Schema of the recurrent neuron. At time step 2, it receives inputs h2 and x2, and produces outputs y2 and h3.

RNN cell

Schema of the RNN cell.

  • Two inputs:
    • current input data x
    • previous hidden state h
  • Two outputs:
    • current output y
    • next hidden state h (see the sketch below)
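
A minimal sketch of this input/output pattern with nn.RNN; the batch of 8 sequences, 10 time steps, and 1 feature are assumed toy values, not part of the original slide:

import torch
import torch.nn as nn

rnn = nn.RNN(input_size=1, hidden_size=32, batch_first=True)
x = torch.randn(8, 10, 1)   # current input data: (batch, seq_len, features)
h0 = torch.zeros(1, 8, 32)  # previous hidden state: (num_layers, batch, hidden_size)
out, hn = rnn(x, h0)        # out: the output y at every step; hn: the next hidden state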

LSTM cell

Schema of the LSTM cell.

  • Outputs h and y are the same
  • Three inputs and three outputs (two of them hidden states):
    • h: short-term state
    • c: long-term state
  • Three "gates" control the memory (see the sketch below):
    • Forget gate: what to remove from long-term memory
    • Input gate: what to save to long-term memory
    • Output gate: what to return at the current time step
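
To make the two states concrete, here is a minimal single-step sketch using nn.LSTMCell; the batch size and shapes are assumed toy values:

import torch
import torch.nn as nn

cell = nn.LSTMCell(input_size=1, hidden_size=32)
x_t = torch.randn(8, 1)   # input at one time step: (batch, features)
h = torch.zeros(8, 32)    # short-term state
c = torch.zeros(8, 32)    # long-term state
h, c = cell(x_t, (h, c))  # the gates update both states internally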

LSTM in PyTorch

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        # Two stacked LSTM layers with 32 hidden units each
        self.lstm = nn.LSTM(
            input_size=1,
            hidden_size=32,
            num_layers=2,
            batch_first=True,
        )
        self.fc = nn.Linear(32, 1)

    def forward(self, x):
        # Initialize both hidden states with zeros:
        # shape (num_layers, batch_size, hidden_size)
        h0 = torch.zeros(2, x.size(0), 32)
        c0 = torch.zeros(2, x.size(0), 32)
        # Pass both hidden states to the LSTM as a tuple
        out, _ = self.lstm(x, (h0, c0))
        # Use only the last time step's output
        out = self.fc(out[:, -1, :])
        return out
  • __init__():
    • Replace nn.RNN with nn.LSTM
  • forward():
    • Add another hidden state c
    • Initialize c and h with zeros
    • Pass both hidden states to the lstm layer (usage sketch below)
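
A quick usage sketch of the model above; the batch of 8 sequences with 10 time steps and 1 feature is an assumed toy input:

model = Net(input_size=1)
x = torch.randn(8, 10, 1)
print(model(x).shape)  # torch.Size([8, 1])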

GRU cell

Schema of the GRU cell.

  • Simplified version of LSTM cell
  • Just one hidden state
  • No output gate (see the single-step sketch below)
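
A minimal single-step sketch with nn.GRUCell, mirroring the LSTM example above (toy shapes assumed):

import torch
import torch.nn as nn

cell = nn.GRUCell(input_size=1, hidden_size=32)
x_t = torch.randn(8, 1)  # input at one time step
h = torch.zeros(8, 32)   # the single hidden state
h = cell(x_t, h)         # one update step; no separate long-term state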

GRU in PyTorch

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        # Same setup as the LSTM model, with nn.GRU instead
        self.gru = nn.GRU(
            input_size=1,
            hidden_size=32,
            num_layers=2,
            batch_first=True,
        )
        self.fc = nn.Linear(32, 1)

    def forward(self, x):
        # A GRU keeps a single hidden state, initialized with zeros
        h0 = torch.zeros(2, x.size(0), 32)
        out, _ = self.gru(x, h0)
        out = self.fc(out[:, -1, :])
        return out
  • __init__():
    • Replace nn.RNN with nn.GRU
  • forward():
    • Initialize a single hidden state h and pass it to the gru layer

Should I use RNN, LSTM, or GRU?

  • Plain RNNs are rarely used anymore
  • GRU is simpler than LSTM, so it needs less computation
  • Relative performance varies per use case
  • Try both and compare (see the parameter-count sketch below)
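
One way to quantify the difference in computation is to count parameters; a small sketch, assuming the same two-layer, 32-unit configuration as the models above:

import torch.nn as nn

lstm = nn.LSTM(input_size=1, hidden_size=32, num_layers=2, batch_first=True)
gru = nn.GRU(input_size=1, hidden_size=32, num_layers=2, batch_first=True)

def n_params(m):
    return sum(p.numel() for p in m.parameters())

print(n_params(lstm))  # 12928
print(n_params(gru))   # 9696: three weight blocks per layer instead of four, ~25% fewer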

Schemas of the LSTM and GRU cells.

Let's practice!
