Scalable AI Models with PyTorch Lightning
Sergiy Tkachuk
Director, GenAI Productivity
Key methods:
prepare_data: download and set up the data (runs once, on a single process)
setup: split the data into train, validation, and test sets (runs on every process)

import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision import datasets

class ImageDataModule(pl.LightningDataModule):
    def __init__(self, data_dir="./data", batch_size=32):
        super().__init__()
        ...  # store data_dir and batch_size, define the self.transform used below

    def prepare_data(self):
        # Download MNIST once; no state should be assigned here
        datasets.MNIST(self.data_dir, train=True, download=True)

    def setup(self, stage=None):
        dataset = datasets.MNIST(self.data_dir, train=True, transform=self.transform)
        self.train_data, self.val_data = random_split(dataset, [55000, 5000])
        self.test_data = datasets.MNIST(self.data_dir, train=False, transform=self.transform)
    def train_dataloader(self):
        return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_data, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_data, batch_size=self.batch_size)
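For a quick sanity check outside of a Trainer run, the DataModule can also be driven by hand. A minimal sketch, assuming self.transform was set to transforms.ToTensor() in __init__; the batch size and printed shape are illustrative.

dm = ImageDataModule(data_dir="./data", batch_size=64)
dm.prepare_data()                      # download MNIST once
dm.setup(stage="fit")                  # build the train/val splits
images, labels = next(iter(dm.train_dataloader()))
print(images.shape)                    # e.g. torch.Size([64, 1, 28, 28])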
LightningDataModule pairs with LightningModule.
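The two are wired together through the Trainer. A minimal sketch, assuming a simple MNISTClassifier LightningModule that is illustrative and not part of the original example.

import pytorch_lightning as pl
import torch
from torch import nn
import torch.nn.functional as F

class MNISTClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.cross_entropy(self.net(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

model = MNISTClassifier()
trainer = pl.Trainer(max_epochs=3)
# Trainer calls prepare_data, setup, and the dataloaders on the DataModule
trainer.fit(model, datamodule=ImageDataModule())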