Mixed precision training with 8-bit Adam

Efficient AI Model Training with PyTorch

Dennis Lee

Data Engineer

Optimizers for training efficiency

Diagram showing the tradeoffs between number of parameters and precision for AdamW, Adafactor, and 8-bit Adam.

How does 8-bit Adam work?

Diagram showing the steps of 8-bit Adam.

  • Store the two optimizer states in 8-bit precision; dequantize to FP32 to compute each update
  • EMA: exponential moving average
  • Compute the EMA of the gradients and squared gradients (see the sketch after this list)
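
To make the EMA step concrete, here is a minimal full-precision sketch of the two moment updates Adam maintains per parameter. The function below is illustrative, not the bitsandbytes implementation; the 8-bit variant additionally stores both moments at 1 byte per value between steps.

import torch

def adam_moment_update(grad, exp_avg, exp_avg_sq, beta1=0.9, beta2=0.999):
    """One EMA step for Adam's first and second moments (full-precision sketch).
    8-bit Adam keeps exp_avg and exp_avg_sq quantized to 1 byte per value
    and dequantizes them to FP32 before applying this update."""
    exp_avg = beta1 * exp_avg + (1 - beta1) * grad            # EMA of gradients
    exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad ** 2  # EMA of squared gradients
    return exp_avg, exp_avg_sq

# Example: one update for a toy gradient tensor
grad = torch.randn(4)
exp_avg, exp_avg_sq = adam_moment_update(grad, torch.zeros(4), torch.zeros(4))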

How does 8-bit Adam save memory?

Diagram depicting parameter gradients, EMA of gradients, and EMA of squared gradients.

  • Each square is a parameter, and each color is a state
  • Memory per parameter = 2 bytes = 1 byte per state * 2 states
  • Total memory = Memory per parameter (2 bytes) * Number of parameters

Estimate memory usage of 8-bit Adam

from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-cased", return_dict=True)
num_parameters = sum(p.numel() for p in model.parameters())
print(f"Number of model parameters: {num_parameters:,}")
Number of model parameters: 65,783,042

# 8-bit Adam stores two 1-byte states per parameter: 2 bytes per parameter
estimated_memory = num_parameters * 2 / (1024 ** 2)
print(f"Estimated memory usage of 8-bit Adam: {estimated_memory:.0f} MB")
Estimated memory usage of 8-bit Adam: 125 MB
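
For comparison, standard AdamW keeps the same two states in FP32, i.e. 4 bytes per state or 8 bytes per parameter. A quick back-of-the-envelope check under that assumption reproduces the AdamW figure reported at the end of this lesson:

# AdamW: 2 optimizer states per parameter, each stored in FP32 (4 bytes each)
estimated_memory_adamw = num_parameters * 2 * 4 / (1024 ** 2)
print(f"Estimated memory usage of AdamW: {estimated_memory_adamw:.0f} MB")
Estimated memory usage of AdamW: 502 MB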

Set up the 8-bit Adam optimizer

import bitsandbytes as bnb
from torch import nn
from transformers import TrainingArguments
from transformers.trainer_pt_utils import get_parameter_names

args = TrainingArguments(output_dir="./results")
# Parameter names eligible for weight decay: skip LayerNorm layers and biases
decay_parameters = get_parameter_names(model, [nn.LayerNorm])
decay_parameters = [name for name in decay_parameters if "bias" not in name]
  • Weight decay prevents overfitting
  • decay_parameters: specify parameters for applying weight decay
  • get_parameter_names: return parameter names; ignore nn.LayerNorm layers
  • Remove bias parameters from decay_parameters
  • Don't apply weight decay to normalization layers and biases (see the sketch after this list)
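
As a quick sanity check on this grouping (an illustrative snippet, not part of the original lesson), you can count how many parameter tensors fall into each group:

# Illustrative check: how the model's parameter tensors split across the two groups
decay = [n for n, _ in model.named_parameters() if n in decay_parameters]
no_decay = [n for n, _ in model.named_parameters() if n not in decay_parameters]
print(f"{len(decay)} tensors receive weight decay; {len(no_decay)} do not")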

Set up the 8-bit Adam optimizer

optimizer_grouped_parameters = [
    {"params": [p for n, p in model.named_parameters() if n in decay_parameters],
     "weight_decay": args.weight_decay},
    {"params": [p for n, p in model.named_parameters() if n not in decay_parameters],
     "weight_decay": 0.0},
]

adam_bnb_optim = bnb.optim.Adam8bit(optimizer_grouped_parameters,
                                    betas=(args.adam_beta1, args.adam_beta2),
                                    eps=args.adam_epsilon,
                                    lr=args.learning_rate)
  • optimizer_grouped_parameters: One group applies weight decay; the other does not
  • beta1, beta2: EMA decay rates of the 1st and 2nd moments; higher values are more stable but adapt more slowly

Implement 8-bit Adam with Trainer

from transformers import Trainer

trainer = Trainer(model=model,
                  args=args,
                  train_dataset=train_dataset,
                  eval_dataset=validation_dataset,
                  # (optimizer, lr_scheduler); None lets Trainer build its default scheduler
                  optimizers=(adam_bnb_optim, None),
                  compute_metrics=compute_metrics)

trainer.train()
{'epoch': 1.0, 'eval_loss': 0.63, 'eval_accuracy': 0.67, 'eval_f1': 0.62}
{'epoch': 2.0, 'eval_loss': 0.61, 'eval_accuracy': 0.71, 'eval_f1': 0.66}
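
compute_metrics is not shown in this lesson; a minimal sketch of what such a function could look like for the accuracy and F1 values reported above (the metric choices and helper names are assumptions):

import numpy as np
import evaluate

accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")

def compute_metrics(eval_pred):
    """Convert logits to class predictions and report accuracy and F1."""
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {
        "accuracy": accuracy_metric.compute(predictions=predictions,
                                            references=labels)["accuracy"],
        "f1": f1_metric.compute(predictions=predictions, references=labels)["f1"],
    }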

Implement 8-bit Adam with Accelerator

# accelerator: a Hugging Face Accelerator instance created beforehand
model, adam_bnb_optim, train_dataloader, lr_scheduler = \
    accelerator.prepare(model, adam_bnb_optim, train_dataloader, lr_scheduler)

for batch in train_dataloader:
    inputs, targets = batch["input_ids"], batch["labels"]
    outputs = model(inputs, labels=targets)
    loss = outputs.loss
    accelerator.backward(loss)
    adam_bnb_optim.step()
    lr_scheduler.step()
    adam_bnb_optim.zero_grad()
    print(f"Loss = {loss}")
Loss = 0.75

Compute memory usage of 8-bit Adam

total_size_megabytes, total_num_elements = \
    compute_optimizer_size(trainer.optimizer.state.values())
print(f"Number of 8-bit Adam parameters: {total_num_elements:,}")
print(f"8-bit Adam size: {total_size_megabytes:.0f} MB")
Number of 8-bit Adam parameters: 131,566,188
8-bit Adam size: 128 MB
  • Compared with AdamW, 8-bit Adam uses roughly 1/4 of the memory:
Number of AdamW parameters: 131,566,188
AdamW size: 502 MB
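
compute_optimizer_size is a helper from earlier in the course and is not shown on this slide; a minimal sketch of what it might look like, assuming it simply sums element counts and byte sizes over every tensor in the optimizer state:

import torch

def compute_optimizer_size(optimizer_state_values):
    """Sum the number of elements and the memory footprint (in MB)
    of every tensor held in the optimizer's per-parameter state dicts."""
    total_num_elements = 0
    total_size_bytes = 0
    for param_state in optimizer_state_values:       # one dict per parameter
        for value in param_state.values():            # e.g. the two moment tensors
            if torch.is_tensor(value):
                total_num_elements += value.numel()
                total_size_bytes += value.numel() * value.element_size()
    return total_size_bytes / (1024 ** 2), total_num_elements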

Let's practice!
