Managing data with LightningDataModule

Scalable AI Models with PyTorch Lightning

Sergiy Tkachuk

Director, GenAI Productivity

Data preparation for model training

  • Poorly prepared data results in training issues
    • Slow training speeds
    • Frequent interruptions
    • Convergence failure

[Figure: data preparation for training]

Why use LightningDataModule?

  • 📂 Centralizes dataset handling
  • 📊 Standardizes data preparation workflows
  • 🚀 Simplifies training and evaluation phases

Managing data with LightningDataModule

Key methods:

  • prepare_data: Download and set up data
  • setup: Split data into train, validation, and test sets
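import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision import datasets

class ImageDataModule(pl.LightningDataModule):
    def __init__(self, data_dir="./data", batch_size=32):
        super().__init__()
        ...

    def prepare_data(self):
        # Download the data once; do not assign state here
        datasets.MNIST(self.data_dir, train=True, download=True)

    def setup(self, stage=None):
        # Split the 60,000 training images into train and validation sets
        dataset = datasets.MNIST(self.data_dir, train=True, transform=self.transform)
        self.train_data, self.val_data = random_split(dataset, [55000, 5000])
        self.test_data = datasets.MNIST(self.data_dir, train=False, transform=self.transform)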

Creating the train DataLoader

  • Supplies batches of training data
  • Helps optimize GPU utilization
  • Enables efficient iteration over large datasets
def train_dataloader(self):
    return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)
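
To see what the training loader hands to the model, the same `DataLoader` call can be run on its own. This is a small sketch, assuming 100 random tensors shaped like MNIST images as a stand-in for `self.train_data`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for self.train_data: 100 fake 28x28 grayscale images.
train_data = TensorDataset(torch.randn(100, 1, 28, 28), torch.randint(0, 10, (100,)))

# shuffle=True draws a new sample order on every epoch.
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)

# Collect the shape of every image batch in one epoch.
batch_shapes = [tuple(images.shape) for images, labels in train_loader]
print(len(train_loader))                  # number of batches per epoch
print(batch_shapes[0], batch_shapes[-1])  # the last batch is smaller
```

`DataLoader` arguments such as `num_workers` and `pin_memory` are the usual knobs for keeping a GPU supplied with data; suitable values depend on the machine.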

Creating the validation DataLoader

  • Supplies data for model validation
  • Helps monitor generalization performance
  • Keeps evaluation runs consistent by leaving shuffling off (the DataLoader default)
def val_dataloader(self):
    return DataLoader(self.val_data, batch_size=self.batch_size)
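
The validation loader omits `shuffle`, which defaults to `False` in `DataLoader`, so every pass visits the samples in the same order. A small sketch, assuming a synthetic dataset whose label is simply the sample index:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for self.val_data: the label equals the sample index.
val_data = TensorDataset(torch.randn(10, 4), torch.arange(10))

# No shuffle argument, so samples come back in dataset order.
val_loader = DataLoader(val_data, batch_size=4)

# Iterate twice and record the order in which labels arrive.
first_pass = torch.cat([labels for _, labels in val_loader])
second_pass = torch.cat([labels for _, labels in val_loader])
print(first_pass.tolist())
print(torch.equal(first_pass, second_pass))
```

A fixed order means validation metrics from different epochs are computed over identical batches.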

Creating the test DataLoader

  • Supplies data for final model evaluation after training is completed
  • Simulates real-world performance assessment
  • Ensures unbiased performance measurement
def test_dataloader(self):
    return DataLoader(self.test_data, batch_size=self.batch_size)
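
A final evaluation pass over the test loader might look like the plain PyTorch loop below. This is a sketch, assuming random tensors for `self.test_data` and an untrained placeholder model, so the printed accuracy is meaningless beyond showing the mechanics.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Synthetic stand-in for self.test_data.
test_data = TensorDataset(torch.randn(64, 1, 28, 28), torch.randint(0, 10, (64,)))
test_loader = DataLoader(test_data, batch_size=16)

# Untrained placeholder classifier, for illustration only.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
model.eval()

correct = 0
total = 0
with torch.no_grad():  # no gradients are needed for evaluation
    for images, labels in test_loader:
        predictions = model(images).argmax(dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.numel()

accuracy = correct / total
print(f"accuracy on {total} held-out samples: {accuracy:.2f}")
```

Within Lightning, `trainer.test(model, datamodule=dm)` runs the equivalent loop using `test_dataloader` and the module's `test_step`.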

Connecting DataModule to LightningModule

  • Modular design separates data and model logic
  • LightningDataModule pairs with LightningModule
  • Standardized workflow enhances reproducibility

PyTorch Lightning diagram with DataModule and LightningModule

Let's practice!
