Scalable AI Models with PyTorch Lightning
Sergiy Tkachuk
Director, GenAI Productivity
Standard PyTorch:
PyTorch Lightning:
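Here is a rough sketch of the contrast, assuming a simple classification task. In standard PyTorch you write the training loop yourself: you move tensors to the device, zero the gradients, call backward() and step the optimizer. In Lightning, the Trainer runs this loop, and the per-batch logic moves into LightningModule hooks. The function name train_plain_pytorch, the tiny network and the random data are placeholders for illustration.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_plain_pytorch(model, loader, epochs=1, lr=1e-3):
    # Standard PyTorch: device placement, the loop, backward and step are all manual
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    criterion = nn.CrossEntropyLoss()                          # Lightning: stored in the module
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)    # Lightning: configure_optimizers()
    losses = []
    for _ in range(epochs):                                    # Lightning: Trainer(max_epochs=...)
        model.train()
        for x, y in loader:
            x, y = x.to(device), y.to(device)                  # Lightning: handled by the Trainer
            optimizer.zero_grad()
            loss = criterion(model(x), y)                      # Lightning: training_step()
            loss.backward()                                    # Lightning: handled by the Trainer
            optimizer.step()                                   # Lightning: handled by the Trainer
            losses.append(loss.item())
    return losses


# Placeholder data: 64 random 28x28 grayscale "images" with labels 0-9
torch.manual_seed(0)
data = TensorDataset(torch.randn(64, 1, 28, 28), torch.randint(0, 10, (64,)))
loader = DataLoader(data, batch_size=16)
net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
losses = train_plain_pytorch(net, loader, epochs=2)
```

Every line marked "handled by the Trainer" is boilerplate that Lightning removes from user code. The lines marked with a hook name are the parts you keep writing inside the LightningModule.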
LightningModule and Trainer
from lightning.pytorch import LightningModule
from lightning.pytorch import Trainer
Key components:
Key points:
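__init__(): Defines the model architecture and stores the loss criterion and optimizer
forward(): Passes data through the model
training_step(): Defines one training step and returns the loss
validation_step(): Computes the loss on a validation batch
configure_optimizers(): Returns the optimizer the Trainer should step

import lightning.pytorch as pl


class LightClassifier(pl.LightningModule):
    def __init__(self, model, criterion, optimizer):
        super().__init__()
        self.model = model          # wrapped network (the architecture)
        self.criterion = criterion  # loss function, e.g. nn.CrossEntropyLoss()
        self.optimizer = optimizer  # optimizer built from model.parameters()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        return loss  # Lightning calls backward() and optimizer.step() on this

    def validation_step(self, batch, batch_idx):
        # Without this hook, the val_dataloader passed to trainer.fit is skipped
        x, y = batch
        self.log("val_loss", self.criterion(self(x), y))

    def configure_optimizers(self):
        # Without this hook the Trainer has no optimizer to step
        return self.optimizer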
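To make the division of labor concrete, below is a heavily simplified, hypothetical loop showing roughly how a trainer drives these hooks under automatic optimization. It is written in plain PyTorch and is not Lightning's actual implementation, which also handles devices, precision, logging, checkpointing and distributed training. The names TinyModule and simple_fit are invented for this sketch.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class TinyModule(nn.Module):
    """Hypothetical stand-in exposing the same hooks as a LightningModule."""

    def __init__(self):
        super().__init__()
        self.model = nn.Linear(4, 2)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return self.criterion(self(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


def simple_fit(module, loader, max_epochs=1):
    # Simplified automatic-optimization loop (illustration only, not Lightning's code)
    optimizer = module.configure_optimizers()  # user hook supplies the optimizer
    losses = []
    for _ in range(max_epochs):
        module.train()
        for batch_idx, batch in enumerate(loader):
            loss = module.training_step(batch, batch_idx)  # user hook computes the loss
            optimizer.zero_grad()
            loss.backward()   # the trainer, not the user, calls backward
            optimizer.step()  # ...and steps the optimizer
            losses.append(loss.item())
    return losses


torch.manual_seed(0)
data = TensorDataset(torch.randn(32, 4), torch.randint(0, 2, (32,)))
losses = simple_fit(TinyModule(), DataLoader(data, batch_size=8), max_epochs=3)
```

The user code only says *what* a step computes. The loop that decides *when* to call it stays in the trainer, which is what lets Lightning swap in multi-device or mixed-precision execution without changes to the module.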
Key points:
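import torch
from torch import nn

net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))  # example network (assumes 28x28 inputs, 10 classes)
model = LightClassifier(net, nn.CrossEntropyLoss(), torch.optim.Adam(net.parameters(), lr=1e-3))
trainer = Trainer(max_epochs=10, accelerator="gpu", devices=1)
trainer.fit(model, train_dataloader, val_dataloader)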
A set of synthetic MNIST-style datasets for four orthographies used in Afro-Asiatic and Niger-Congo languages: Ge'ez (Ethiopic), Vai, Osmanya, and N'Ko.
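The trainer.fit call expects train_dataloader and val_dataloader to already exist. The snippet below is a minimal sketch that builds both from random placeholder tensors, assuming "MNIST-style" means 28x28 single-channel images with ten classes. In practice you would substitute the real dataset's tensors or a Dataset class. The dataset size, the 80/20 split and the batch size are arbitrary illustrative choices.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

torch.manual_seed(0)
# Placeholder data: 1,000 random 28x28 grayscale "images" with labels 0-9
images = torch.randn(1000, 1, 28, 28)
labels = torch.randint(0, 10, (1000,))
dataset = TensorDataset(images, labels)

# Hold out 20% of the samples for validation, with a fixed seed for reproducibility
train_set, val_set = random_split(
    dataset, [800, 200], generator=torch.Generator().manual_seed(42)
)

train_dataloader = DataLoader(train_set, batch_size=64, shuffle=True)  # shuffle only training data
val_dataloader = DataLoader(val_set, batch_size=64)
```

Shuffling is enabled only for the training loader. Validation order does not affect the computed loss, so keeping it deterministic makes runs easier to compare.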