Exporting models with TorchScript

Scalable AI Models with PyTorch Lightning

Sergiy Tkachuk

Director, GenAI Productivity

What is TorchScript?

  • Independent of Python

  • Efficient in production

  • Examples:

    • Deployment on mobile devices
    • High-performance inference in production systems

A mobile phone and a laptop with gears representing a production environment

Converting models to TorchScript

Two methods for conversion:

  • torch.jit.trace: Uses example inputs to trace execution
  • torch.jit.script: Compiles the model by analyzing Python code

When to use:

  1. Use trace for simpler models whose execution path does not depend on the input
  2. Use script for models with control flow (e.g., loops, conditionals)

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def forward(self, x):
        return x * 2

model = SimpleModel()
scripted_model = torch.jit.script(model)
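
The difference between the two conversion methods shows up once a model branches on its input. The sketch below is illustrative only: `GatedModel` and the two sample tensors are hypothetical and not part of the course material. It traces and scripts the same module, then feeds both an input that takes the branch the tracer never saw.

```python
import torch
import torch.nn as nn


class GatedModel(nn.Module):
    """Hypothetical model with a data-dependent branch (an assumption for illustration)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The branch taken depends on the values in x
        if x.sum() > 0:
            return x * 2
        else:
            return x - 1


model = GatedModel()
positive = torch.tensor([1.0, 2.0, 3.0])
negative = torch.tensor([-1.0, -2.0, -3.0])

# Tracing runs the model once with `positive` and records only the `x * 2` path
# (PyTorch emits a TracerWarning about the tensor-to-bool conversion here)
traced_model = torch.jit.trace(model, positive)

# Scripting compiles the Python source, so both branches are kept
scripted_model = torch.jit.script(model)

print(traced_model(negative))    # still applies the recorded `x * 2` path
print(scripted_model(negative))  # takes the `x - 1` branch, like the eager model
```

For a model without such branches, both methods should give the same result, which is why tracing is often enough for simple feed-forward models.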

Saving and loading TorchScript models

  • Saving the model:
    • torch.jit.save: Save the scripted model to a file
  • Loading the model:
    • torch.jit.load: Load the model back for inference

# Save the model
torch.jit.save(scripted_model, "model.pt")

# Load the model
loaded_model = torch.jit.load("model.pt")
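
As a self-contained sketch of the save and load steps, the example below writes the scripted model to a temporary directory and reads it back. The temporary directory and the `map_location="cpu"` argument are choices made here for illustration, not something the slides prescribe.

```python
import os
import tempfile

import torch
import torch.nn as nn


class SimpleModel(nn.Module):
    def forward(self, x):
        return x * 2


scripted_model = torch.jit.script(SimpleModel())

# Use a temporary directory so the example leaves no files behind (an assumption)
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.pt")

    # Save the scripted model to disk
    torch.jit.save(scripted_model, path)
    file_was_written = os.path.exists(path)

    # Load it back; map_location="cpu" keeps the example runnable without a GPU
    loaded_model = torch.jit.load(path, map_location="cpu")

# The loaded object is a ScriptModule and no longer needs the SimpleModel class
print(type(loaded_model).__name__)
```

Because the saved file carries both the code and the weights, the loading side only needs `torch`, not the original Python class definition.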

Performing inference with TorchScript

  • Steps:
    • Load the TorchScript model
    • Pass inputs to the model for predictions
    • Outputs match the predictions of the original PyTorch model

Example Input:

  • Input Tensor: [1.0, 2.0, 3.0]

Example Output:

  • Output Tensor: [2.0, 4.0, 6.0]

# Perform inference
input_arr = [1.0, 2.0, 3.0]
input_tensor = torch.tensor(input_arr)

output = loaded_model(input_tensor)
print(output)
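
One way to check that the exported model agrees with the original is to run both on the same input and compare the results. The sketch below assumes the `SimpleModel` from earlier and, purely for convenience, round-trips the model through an in-memory buffer rather than a file.

```python
import io

import torch
import torch.nn as nn


class SimpleModel(nn.Module):
    def forward(self, x):
        return x * 2


model = SimpleModel().eval()
scripted_model = torch.jit.script(model)

# Round-trip through an in-memory buffer instead of a file on disk (an assumption)
buffer = io.BytesIO()
torch.jit.save(scripted_model, buffer)
buffer.seek(0)
loaded_model = torch.jit.load(buffer)

input_tensor = torch.tensor([1.0, 2.0, 3.0])

# Disable gradient tracking for inference
with torch.inference_mode():
    eager_output = model(input_tensor)
    loaded_output = loaded_model(input_tensor)

# Compare the eager PyTorch output with the TorchScript output
outputs_match = torch.allclose(eager_output, loaded_output)
print(outputs_match)
```

For larger models it may be worth running this kind of comparison on several representative inputs before deploying the exported file.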

TorchScript in a nutshell

  • torch.jit.trace: Works for static models
  • torch.jit.script: Handles dynamic control flow
  • torch.jit.save: Saves the scripted model
  • torch.jit.load: Reloads for inference

Let's practice!