LSTM- und GRU-Zellen

Deep Learning mit PyTorch für Fortgeschrittene

Michal Oleszak

Machine Learning Engineer

Kurzzeitgedächtnis-Problem

  • RNN-Zellen halten Speicher über den Hidden State
  • Dieser Speicher ist sehr kurzfristig
  • Zwei stärkere Zellen lösen das Problem:
    • LSTM (Long Short-Term Memory)
    • GRU (Gated Recurrent Unit)

Schema des rekurrenten Neurons. Zum Zeitschritt 2 erhält es Eingaben h2 und x2 und gibt y2 und h3 aus.

Deep Learning mit PyTorch für Fortgeschrittene

RNN-Zelle

Schema der RNN-Zelle.

  • Zwei Eingaben:
    • aktueller Input x
    • vorheriger Hidden State h
  • Zwei Ausgaben:
    • aktueller Output y
    • nächster Hidden State h
Deep Learning mit PyTorch für Fortgeschrittene

LSTM-Zelle

Schema der LSTM-Zelle.

  • Ausgaben h und y sind identisch
  • Drei Eingaben und Ausgaben (zwei Hidden States):

    • h: Kurzzeitzustand
    • c: Langzeitzustand
  • Drei „Gates“:

    • Forget-Gate: was aus dem Langzeitgedächtnis entfernt wird
    • Input-Gate: was im Langzeitgedächtnis gespeichert wird
    • Output-Gate: was im aktuellen Zeitschritt ausgegeben wird
Deep Learning mit PyTorch für Fortgeschrittene

LSTM in PyTorch

class Net(nn.Module):
    def __init__(self, input_size):
        super().__init__()

self.lstm = nn.LSTM( input_size=1, hidden_size=32, num_layers=2, batch_first=True, ) self.fc = nn.Linear(32, 1)
def forward(self, x): h0 = torch.zeros(2, x.size(0), 32) c0 = torch.zeros(2, x.size(0), 32)
out, _ = self.lstm(x, (h0, c0))
out = self.fc(out[:, -1, :]) return out
  • __init__():
    • Ersetze nn.RNN durch nn.LSTM
  • forward():
    • Füge einen weiteren Hidden State c hinzu
    • Initialisiere c und h mit Nullen
    • Übergib beide Hidden States an die lstm-Schicht
Deep Learning mit PyTorch für Fortgeschrittene

GRU-Zelle

Schema der GRU-Zelle.

  • Vereinfachte Version der LSTM-Zelle
  • Nur ein Hidden State
  • Kein Output-Gate
Deep Learning mit PyTorch für Fortgeschrittene

GRU in PyTorch

class Net(nn.Module):
    def __init__(self, input_size):
        super().__init__()

self.gru = nn.GRU( input_size=1, hidden_size=32, num_layers=2, batch_first=True, ) self.fc = nn.Linear(32, 1)
def forward(self, x): h0 = torch.zeros(2, x.size(0), 32) out, _ = self.gru(x, h0) out = self.fc(out[:, -1, :]) return out
  • __init__():
    • Ersetze nn.RNN durch nn.GRU
  • forward():
    • Verwende die gru-Schicht
Deep Learning mit PyTorch für Fortgeschrittene

RNN, LSTM oder GRU?

  • RNN wird kaum noch genutzt
  • GRU ist einfacher als LSTM = weniger Rechenaufwand
  • Relative Performance hängt vom Use Case ab
  • Probiere beide und vergleiche

Schemata der LSTM- und GRU-Zellen.

Deep Learning mit PyTorch für Fortgeschrittene

Lass uns üben!

Deep Learning mit PyTorch für Fortgeschrittene

Preparing Video For Download...