Defining models with LightningModule

Scalable AI Models with PyTorch Lightning

Sergiy Tkachuk

Director, GenAI Productivity

LightningModule in focus

  1. Encapsulates your model architecture
  2. Organizes training logic into a single, manageable unit
  3. Provides a blueprint that brings order and clarity to deep learning projects

PyTorch LightningModule diagram


Defining the __init__ method

Key tasks:

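  • Initialize the model
  • Call super().__init__() first to set up the parent LightningModule, which provides:
    • Automated handling of training loops
    • Logging
    • Checkpointing
  • Define the model layers after the parent class is initialized
  • Keep the model modular and easy to maintain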
import lightning.pytorch as pl
import torch.nn as nn

class ClassificationModel(pl.LightningModule):
    def __init__(self, input_dim,
                 hidden_dim, num_class):
        # Initialize parent class
        super().__init__()

        # First layer
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        # Activation function
        self.relu = nn.ReLU()
        # Output layer
        self.layer2 = nn.Linear(hidden_dim, num_class)

Implementing the forward method

Key steps:

  • Define how data flows through the network
  • Pass the input through the layers in sequence:
    • Linear transformation
    • Activation
    • Final layer producing the output
import lightning.pytorch as pl
import torch.nn as nn

class ClassificationModel(pl.LightningModule):
    def __init__(self, input_dim,
                 hidden_dim, num_class):
        ...  # layer1, relu and layer2 defined as in __init__ above

    def forward(self, x):
        x = self.layer1(x)  # Pass input through the first layer
        x = self.relu(x)    # Apply activation (the nn.ReLU module from __init__)
        x = self.layer2(x)  # Compute output
        return x            # Return raw class scores

Example: classifying handwritten digits

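import lightning.pytorch as pl
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms

# Convert each 28x28 image to a tensor and flatten it to 784 values,
# matching the model's input_dim
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda img: img.view(-1)),
])
train_ds = MNIST(root='.', train=True, download=True, transform=transform)
test_ds = MNIST(root='.', train=False, download=True, transform=transform)
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=64)

model = ClassificationModel(input_dim=28*28, hidden_dim=128, num_class=10)

# trainer.fit also expects training_step() and configure_optimizers() on the model;
# test_loader is passed as the validation loader and is only used if validation_step() exists
trainer = pl.Trainer(max_epochs=3, accelerator='auto')
trainer.fit(model, train_loader, test_loader)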

Integrating the model with classification tasks


  • Focus on the classification use case
  • Entire flow defined within the LightningModule
  • Return raw outputs (logits), ready for a softmax activation
  • Integrates with the Lightning Trainer
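import lightning.pytorch as pl
import torch.nn as nn

class ClassificationModel(pl.LightningModule):
    def __init__(self, input_dim,
                 hidden_dim, output_dim):
        super().__init__()
        # Layer names must match the ones used in forward()
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.output = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.hidden(x)
        x = self.relu(x)
        x = self.output(x)
        return x  # raw logits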
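Because forward returns raw logits, converting them into class probabilities and predictions is a separate step. The sketch below applies torch.softmax to hand-picked logits for a hypothetical 3-class problem and takes the argmax as the predicted class. It also shows that F.cross_entropy expects the raw logits, because it applies log-softmax internally.

```python
import torch
import torch.nn.functional as F

# Hypothetical raw outputs (logits) for 2 samples and 3 classes
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])

probs = torch.softmax(logits, dim=1)  # each row now sums to 1
preds = probs.argmax(dim=1)           # most likely class per sample
print(preds.tolist())  # [0, 2]

# cross_entropy takes the raw logits; it is equivalent to
# log_softmax followed by the negative log-likelihood loss
targets = torch.tensor([0, 2])
loss = F.cross_entropy(logits, targets)
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)
print(torch.allclose(loss, manual))  # True
```

Keeping softmax out of forward avoids applying it twice when the model is trained with cross-entropy.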

Let's practice!
