Mixed precision training with 8-bit Adam

Efficient AI Model Training with PyTorch

Dennis Lee

Data Engineer, Amazon

Optimizers for training efficiency

Diagram showing the tradeoffs between number of parameters and precision for AdamW, Adafactor, and 8-bit Adam.

How does 8-bit Adam work?

Diagram showing the steps of 8-bit Adam.

  • Store optimizer states in 8-bit; perform the optimization step in FP32
  • EMA: exponential moving average
  • Compute the EMA of the gradients and squared gradients (see the sketch below)
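
As a rough illustration, here is a minimal PyTorch sketch of the two EMAs (Adam's first and second moments) for a single update step; it ignores bias correction and the 8-bit quantization that bitsandbytes applies to the stored states:

import torch

beta1, beta2, lr, eps = 0.9, 0.999, 1e-3, 1e-8

param = torch.randn(10, requires_grad=True)
m = torch.zeros_like(param)  # EMA of gradients (1st moment)
v = torch.zeros_like(param)  # EMA of squared gradients (2nd moment)

loss = (param ** 2).sum()    # toy loss
loss.backward()
g = param.grad

m = beta1 * m + (1 - beta1) * g       # EMA of the gradients
v = beta2 * v + (1 - beta2) * g ** 2  # EMA of the squared gradients
with torch.no_grad():
    param -= lr * m / (v.sqrt() + eps)  # update performed in full precision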

How does 8-bit Adam save memory?

Diagram depicting parameter gradients, EMA of gradients, and EMA of squared gradients.

  • Each square is a parameter, and each color is a state
  • Memory per parameter = 2 bytes = 1 byte per state * 2 states
  • Total memory = Memory per parameter (2 bytes) * Number of parameters (checked in the sketch below)
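
Using the parameter count of the DistilBERT model from the example that follows, the per-parameter arithmetic can be checked directly; the AdamW line assumes two FP32 (4-byte) states per parameter:

num_parameters = 65_783_042                          # DistilBERT parameter count
memory_8bit = num_parameters * (1 * 2) / 1024 ** 2   # 1 byte per state * 2 states
memory_adamw = num_parameters * (4 * 2) / 1024 ** 2  # 4 bytes per state * 2 states
print(f"8-bit Adam: {memory_8bit:.0f} MB")   # ~125 MB
print(f"AdamW:      {memory_adamw:.0f} MB")  # ~502 MB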

Estimate memory usage of 8-bit Adam

from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-cased", return_dict=True)
num_parameters = sum(p.numel() for p in model.parameters())
print(f"Number of model parameters: {num_parameters:,}")
Number of model parameters: 65,783,042
estimated_memory = num_parameters * 2 / (1024 ** 2)  # 2 bytes per parameter: two 8-bit states
print(f"Estimated memory usage of 8-bit Adam: {estimated_memory:.0f} MB")
Estimated memory usage of 8-bit Adam: 125 MB

Set up the 8-bit Adam optimizer

import bitsandbytes as bnb
from torch import nn
from transformers import TrainingArguments
from transformers.trainer_pt_utils import get_parameter_names

args = TrainingArguments(output_dir="./results")
decay_parameters = get_parameter_names(model, [nn.LayerNorm])
decay_parameters = [name for name in decay_parameters if "bias" not in name]
  • Weight decay prevents overfitting
  • decay_parameters: specify parameters for applying weight decay
  • get_parameter_names: return parameter names; ignore nn.LayerNorm layers
  • Remove bias parameters from decay_parameters
  • Don't apply weight decay to normalization layers and biases

Set up the 8-bit Adam optimizer

optimizer_grouped_parameters = [
    {"params": [p for n, p in model.named_parameters() if n in decay_parameters],
     "weight_decay": args.weight_decay},
    {"params": [p for n, p in model.named_parameters() if n not in decay_parameters],
     "weight_decay": 0.0},
]

adam_bnb_optim = bnb.optim.Adam8bit(optimizer_grouped_parameters,
                                    betas=(args.adam_beta1, args.adam_beta2),
                                    eps=args.adam_epsilon,
                                    lr=args.learning_rate)
  • optimizer_grouped_parameters: One group applies weight decay; the other does not
  • beta1, beta2: Decay rates of 1st and 2nd moments; higher = stable, slow training

Implement 8-bit Adam with Trainer

from transformers import Trainer

trainer = Trainer(model=model,
                  args=args,
                  train_dataset=train_dataset,
                  eval_dataset=validation_dataset,
                  optimizers=(adam_bnb_optim, None),
                  compute_metrics=compute_metrics)

trainer.train()
{'epoch': 1.0, 'eval_loss': 0.63, 'eval_accuracy': 0.67, 'eval_f1': 0.62}
{'epoch': 2.0, 'eval_loss': 0.61, 'eval_accuracy': 0.71, 'eval_f1': 0.66}
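
The compute_metrics function passed to Trainer is assumed to be defined earlier in the course; one common sketch using the evaluate library (the metric choices here are illustrative, matching the accuracy and F1 values reported above) is:

import numpy as np
import evaluate

accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")

def compute_metrics(eval_pred):
    # Trainer passes (logits, labels); take the argmax over classes
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": accuracy_metric.compute(predictions=predictions,
                                                references=labels)["accuracy"],
            "f1": f1_metric.compute(predictions=predictions,
                                    references=labels)["f1"]}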

Implement 8-bit Adam with Accelerator

model, adam_bnb_optim, train_dataloader, lr_scheduler = \
    accelerator.prepare(model, adam_bnb_optim, train_dataloader, lr_scheduler)


for batch in train_dataloader:
    inputs, targets = batch["input_ids"], batch["labels"]
    outputs = model(inputs, labels=targets)
    loss = outputs.loss
    accelerator.backward(loss)
    adam_bnb_optim.step()
    lr_scheduler.step()
    adam_bnb_optim.zero_grad()
    print(f"Loss = {loss}")
Loss = 0.75
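
The accelerator, train_dataloader, and lr_scheduler objects passed to prepare are assumed to be created beforehand; a minimal sketch of one possible setup (the batch size and scheduler choice are illustrative, not from the course) is:

from accelerate import Accelerator
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup

accelerator = Accelerator()
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
lr_scheduler = get_linear_schedule_with_warmup(
    adam_bnb_optim, num_warmup_steps=0,
    num_training_steps=len(train_dataloader))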

Compute memory usage of 8-bit Adam

total_size_megabytes, total_num_elements = \
    compute_optimizer_size(trainer.optimizer.state.values())
print(f"Number of 8-bit Adam parameters: {total_num_elements:,}")
print(f"8-bit Adam size: {total_size_megabytes:.0f} MB")
Number of 8-bit Adam parameters: 131,566,188
8-bit Adam size: 128 MB
  • Compare with AdamW: 8-bit Adam uses 1/4 of the memory
Number of AdamW parameters: 131,566,188
AdamW size: 502 MB
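
compute_optimizer_size is a helper defined elsewhere in the course; a plausible sketch that totals the number of elements and byte size of every tensor held in the optimizer state is:

import torch

def compute_optimizer_size(state_values):
    total_num_elements, total_size_bytes = 0, 0
    for state in state_values:        # one state dict per parameter
        for value in state.values():  # e.g. the two (quantized) EMA tensors
            if torch.is_tensor(value):
                total_num_elements += value.numel()
                total_size_bytes += value.numel() * value.element_size()
    return total_size_bytes / (1024 ** 2), total_num_elements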

Let's practice!
