Células LSTM y GRU

Aprendizaje profundo intermedio con PyTorch

Michal Oleszak

Machine Learning Engineer

Problema de memoria a corto plazo

  • Las células RNN mantienen la memoria a través de un estado oculto.
  • Esta memoria es muy a corto plazo.
  • Dos celdas más potentes resuelven el problema:
    • Célula LSTM (memoria a corto y largo plazo)
    • Célula GRU (unidad recurrente con puerta)

Esquema de la neurona recurrente. En el paso de tiempo 2, recibe las entradas h2 y x2, y produce las salidas y2 y h3.

Aprendizaje profundo intermedio con PyTorch

célula RNN

Esquema de la célula RNN.

  • Dos entradas:
    • datos de entrada actuales x
    • estado oculto anterior h
  • Dos salidas:
    • salida de corriente y
    • siguiente estado oculto h
Aprendizaje profundo intermedio con PyTorch

celda LSTM

Esquema de la celda LSTM.

  • Los resultados h y y son los mismos.
  • Tres entradas y salidas (dos estados ocultos):

    • h: estado a corto plazo
    • c: estado a largo plazo
  • Tres «puertas»:

    • Olvídate de la puerta: qué eliminar de la memoria a largo plazo
    • Puerta de entrada: qué guardar en la memoria a largo plazo
    • Puerta de salida: qué devolver en el paso de tiempo actual.
Aprendizaje profundo intermedio con PyTorch

LSTM en PyTorch

class Net(nn.Module):
    def __init__(self, input_size):
        super().__init__()

self.lstm = nn.LSTM( input_size=1, hidden_size=32, num_layers=2, batch_first=True, ) self.fc = nn.Linear(32, 1)
def forward(self, x): h0 = torch.zeros(2, x.size(0), 32) c0 = torch.zeros(2, x.size(0), 32)
out, _ = self.lstm(x, (h0, c0))
out = self.fc(out[:, -1, :]) return out
  • __init__():
    • Sustituir nn.RNN por nn.LSTM
  • forward():
    • Añade otro estado oculto c
    • Inicializa c y h con ceros.
    • Pasa ambos estados ocultos a una cap lstm.
Aprendizaje profundo intermedio con PyTorch

célula GRU

Esquema de la célula GRU.

  • Versión simplificada de la celda LSTM
  • Solo un estado oculto
  • Sin salida de puerta
Aprendizaje profundo intermedio con PyTorch

GRU en PyTorch

class Net(nn.Module):
    def __init__(self, input_size):
        super().__init__()

self.gru = nn.GRU( input_size=1, hidden_size=32, num_layers=2, batch_first=True, ) self.fc = nn.Linear(32, 1)
def forward(self, x): h0 = torch.zeros(2, x.size(0), 32) out, _ = self.gru(x, h0) out = self.fc(out[:, -1, :]) return out
  • __init__():
    • Sustituir nn.RNN por nn.GRU
  • forward():
    • Utiliza la capa « gru » (Ajustes globales).
Aprendizaje profundo intermedio con PyTorch

¿Debería usar RNN, LSTM o GRU?

  • Las RNN ya no se utilizan mucho.
  • GRU es más sencillo que LSTM = menos computación
  • El rendimiento relativo varía según el caso de uso.
  • Prueba ambos y compara.

Esquemas de las células LSTM y GRU.

Aprendizaje profundo intermedio con PyTorch

¡Vamos a practicar!

Aprendizaje profundo intermedio con PyTorch

Preparing Video For Download...