Implementing model pruning techniques

Scalable AI Models with PyTorch Lightning

Sergiy Tkachuk

Director, GenAI Productivity

When to use pruning?

  • 📱 Helpful when deploying models to edge or embedded systems
  • ➕ Can be combined with quantization for compound efficiency gains
  • ⚡ Use when reducing latency or model size is a priority

What is model pruning?

  • Removes less important connections in neural networks
  • Leads to sparse models that can be cheaper to store and compute, given sparse storage formats or sparsity-aware kernels
  • A common approach is L1 unstructured pruning
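
To make the idea concrete, here is a minimal sketch of what magnitude-based (L1) unstructured pruning does, written with plain tensor operations rather than the pruning utilities. The weight values and the 40% amount are hypothetical and chosen only for illustration.

```python
import torch

# Hypothetical weight vector, used only for illustration.
weights = torch.tensor([0.9, -0.05, 0.4, -0.6, 0.1])

amount = 0.4                                  # fraction of entries to prune
n_prune = round(amount * weights.numel())     # 0.4 * 5 = 2 entries

# Positions of the n_prune entries with the smallest absolute value.
_, smallest = torch.topk(weights.abs(), k=n_prune, largest=False)

# Binary mask: 0 at pruned positions, 1 everywhere else.
mask = torch.ones_like(weights)
mask[smallest] = 0

# Applying the mask zeroes the least important connections.
pruned = weights * mask
print(mask)
print(pruned)
```

The entries -0.05 and 0.1 have the smallest magnitudes, so they are the ones masked out; the large negative weight -0.6 is kept because only the absolute value matters.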

[Figure: pruning example]

What is model pruning?

import torch.nn.utils.prune as prune

print(model.fc.weight.data)  # Before pruning
tensor([[ 0.25, -0.13,  0.05,  0.70],
        [-0.88,  0.31, -0.02,  0.44]])

# Zero out the 40% of weights with the smallest absolute value
prune.l1_unstructured(model.fc, name="weight",
                      amount=0.4)

print(model.fc.weight.data)  # After pruning: 3 of 8 weights (about 40%) set to 0
tensor([[ 0.25,  0.00,  0.00,  0.70],
        [-0.88,  0.31,  0.00,  0.44]])
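
The snippet above assumes a `model` with an `fc` layer already exists. A self-contained version might look like the sketch below; the one-layer `nn.Sequential` model and the hand-set weights are assumptions made so the example can run on its own.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# Hypothetical stand-in model: a single 4 -> 2 linear layer named "fc".
model = nn.Sequential()
model.add_module("fc", nn.Linear(4, 2))

# Set the weights by hand so the result is easy to follow.
with torch.no_grad():
    model.fc.weight.copy_(torch.tensor([[ 0.25, -0.13,  0.05, 0.70],
                                        [-0.88,  0.31, -0.02, 0.44]]))

# Prune the 40% of entries with the smallest absolute value.
prune.l1_unstructured(model.fc, name="weight", amount=0.4)

pruned = model.fc.weight
sparsity = (pruned == 0).float().mean().item()
print(pruned)
print(f"sparsity: {sparsity:.3f}")  # 3 of 8 entries are zero
```

With 8 weights, `amount=0.4` is rounded to 3 pruned entries, so the measured sparsity is 0.375 rather than exactly 0.4.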

Understanding pruning masks

  • Pruning adds a binary mask to each targeted weight tensor.
  • Mask = 1 --> weight kept.
  • Mask = 0 --> weight set to 0 at forward time.
  • The original weights (weight_orig) and the mask (weight_mask) both stay in memory until the mask is removed.
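
The mask mechanics can be checked directly on a pruned layer. The sketch below uses a hypothetical `nn.Linear(4, 2)` layer with random weights; `weight_orig` and `weight_mask` are the attribute names that `torch.nn.utils.prune` registers for a tensor called `weight`.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(4, 2)                    # hypothetical layer
original = layer.weight.detach().clone()   # copy of the weights before pruning

prune.l1_unstructured(layer, name="weight", amount=0.5)

mask = layer.weight_mask                   # binary mask stored as a buffer
kept = int(mask.sum())                     # number of weights with mask = 1

# The original values are untouched and still held in memory.
orig_unchanged = torch.equal(layer.weight_orig.detach(), original)

# The weight used in the forward pass is the masked original.
is_masked_product = torch.equal(layer.weight, layer.weight_orig * mask)

print(mask)
print(kept, orig_unchanged, is_masked_product)
```

Because the original values survive in `weight_orig`, pruning at this stage is reversible in principle, and the layer holds more tensors than before, not fewer.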

Making pruning permanent

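  • By default, the layer keeps the original weights (weight_orig) and a mask (weight_mask), and recomputes weight from them
  • To finalize pruning, remove the reparametrization with prune.remove
  • This turns the pruned layer back into a standard layer whose weight has the zeros baked in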
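import torch.nn.utils.prune as prune

# Before prune.remove: original weights and mask are stored separately
print(sorted(name for name, _ in model.fc.named_parameters()))
['bias', 'weight_orig']
print(sorted(name for name, _ in model.fc.named_buffers()))
['weight_mask']

prune.remove(model.fc, 'weight')

# After prune.remove: a single standard weight parameter, zeros included
print(sorted(name for name, _ in model.fc.named_parameters()))
['bias', 'weight']

# Print model structure (the printed structure is the same before and after)
print(model)
Sequential(
  (fc): Linear(in_features=128, out_features=64, bias=True)
)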
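
One way to confirm that pruning has been made permanent is to check that the zeros survive `prune.remove` and that the layer is no longer flagged as pruned. This is a minimal sketch on a hypothetical standalone `nn.Linear(128, 64)` layer, matching the shape shown above.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(128, 64)   # hypothetical layer with the same shape as model.fc

prune.l1_unstructured(layer, name="weight", amount=0.4)
pruned_before = prune.is_pruned(layer)          # pruning hook is attached
zeros_before = int((layer.weight == 0).sum())

prune.remove(layer, "weight")
pruned_after = prune.is_pruned(layer)           # hook and mask are gone
zeros_after = int((layer.weight == 0).sum())

print(pruned_before, pruned_after)
print(zeros_before, zeros_after)
print(layer.weight.numel())   # dense tensor keeps its full size
```

Note that the weight tensor is still dense after `prune.remove`: it has the same number of elements, with zeros stored explicitly. Any reduction in file size depends on how the model is saved or compressed afterwards.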

Evaluating pruning impact

  • Compare original and pruned model performance on the same validation data
  • Expect slight drops in accuracy; size and memory savings appear once the zeroed weights are stored sparsely or compressed
  • Helps assess if the trade-off is acceptable for deployment
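
A comparison along these lines can be sketched with two small helpers. Everything here is hypothetical: the untrained two-layer classifier, the random data, and the helper names `sparsity` and `accuracy`. The sketch only shows the mechanics of the comparison, not realistic accuracy numbers.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

def sparsity(model):
    """Fraction of Linear-layer weight entries that are exactly zero."""
    zeros, total = 0, 0
    for module in model.modules():
        if isinstance(module, nn.Linear):
            zeros += int((module.weight == 0).sum())
            total += module.weight.numel()
    return zeros / total

def accuracy(model, x, y):
    """Share of samples whose arg-max prediction matches the label."""
    model.eval()
    with torch.no_grad():
        return (model(x).argmax(dim=1) == y).float().mean().item()

# Hypothetical untrained classifier and random "validation" data.
original = nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, 3))
x, y = torch.randn(256, 20), torch.randint(0, 3, (256,))

# Prune a copy so the original stays available for comparison.
pruned = copy.deepcopy(original)
for module in pruned.modules():
    if isinstance(module, nn.Linear):
        prune.l1_unstructured(module, name="weight", amount=0.5)
        prune.remove(module, "weight")   # make the zeros permanent

print(f"original: sparsity={sparsity(original):.2f}, acc={accuracy(original, x, y):.3f}")
print(f"pruned:   sparsity={sparsity(pruned):.2f}, acc={accuracy(pruned, x, y):.3f}")
```

In practice the same two numbers would be computed on a trained model and a held-out validation set, and the pruning amount adjusted until the accuracy drop is acceptable for the deployment target.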

[Figure: pruning model trade-off]

Let's practice!
