Scalable AI Models with PyTorch Lightning
Sergiy Tkachuk
Director, GenAI Productivity
Independent of Python
Efficient in production
Examples:
Two methods for conversion:
torch.jit.trace: Uses example inputs to trace execution
torch.jit.script: Compiles the model by analyzing Python code
When to use:
trace for simpler models
script for models with control flow (e.g., loops)

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def forward(self, x):
        return x * 2

model = SimpleModel()
scripted_model = torch.jit.script(model)
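For comparison, the same model could be converted with torch.jit.trace, which records the operations executed for one example input (a minimal sketch; the example tensors here are illustrative, not from the original):

```python
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def forward(self, x):
        return x * 2

model = SimpleModel()
# trace runs the model once on the example input and records the ops
example_input = torch.tensor([1.0, 2.0, 3.0])
traced_model = torch.jit.trace(model, example_input)
print(traced_model(torch.tensor([4.0])))  # tensor([8.])
```

Because tracing only records what actually executed for that one input, any Python control flow is frozen into a single path, which is why script is preferred for models with loops or branches.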
torch.jit.save: Save the scripted model to a file
torch.jit.load: Load the model back for inference
# Save the model
torch.jit.save(scripted_model, "model.pt")
# Load the model
loaded_model = torch.jit.load("model.pt")
Example Input:
[1.0, 2.0, 3.0]
Example Output:
[2.0, 4.0, 6.0]
# Perform inference
input_arr = [1.0, 2.0, 3.0]
input_tensor = torch.tensor(input_arr)
output = loaded_model(input_tensor)
print(output)
torch.jit.trace: Works for static models
torch.jit.script: Handles dynamic control flow
torch.jit.save: Saves the scripted model
torch.jit.load: Reloads for inference
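To illustrate why dynamic control flow needs script rather than trace, here is a small sketch (the model and its loop condition are hypothetical, chosen only to demonstrate a data-dependent loop):

```python
import torch
import torch.nn as nn

class LoopModel(nn.Module):
    def forward(self, x):
        # Data-dependent loop: the number of iterations depends on the
        # input values, so tracing would freeze a fixed iteration count,
        # while scripting preserves the loop in the compiled graph.
        while x.sum() < 10:
            x = x * 2
        return x

scripted = torch.jit.script(LoopModel())
print(scripted(torch.tensor([1.0])))  # tensor([16.])
```

With an input of 1.0 the loop doubles the value four times (1 → 2 → 4 → 8 → 16) before the sum reaches 10; a different input would take a different number of iterations, which only the scripted graph can express.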