Memory-efficient training with Adafactor

Efficient AI Model Training with PyTorch

Dennis Lee

Data Engineer

Optimizers for training efficiency

 

 

Icons representing AdamW, Adafactor, and 8-bit Adam.

Optimizer tradeoffs

Diagram showing the tradeoffs between number of parameters and precision for AdamW, Adafactor, and 8-bit Adam.
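
To make the tradeoff concrete, the rough arithmetic can be sketched in Python for a single weight matrix. The sketch makes several assumptions. AdamW and Adafactor keep fp32 (4-byte) states. 8-bit Adam keeps two 1-byte quantized states, and its small per-block quantization constants are ignored. Adafactor is factored and has no first moment. The 768 x 3072 shape and all function names are hypothetical and were chosen only for illustration.

```python
# Back-of-the-envelope optimizer-state sizes for ONE 2D weight matrix.
# Assumptions (illustrative, not measured):
#   AdamW:      two full-size fp32 states (first and second moment), 4 bytes each
#   8-bit Adam: the same two states quantized to 1 byte each
#               (small per-block quantization constants ignored)
#   Adafactor:  one fp32 row vector + one fp32 column vector of second-moment
#               statistics, no first moment

BYTES_PER_ELEMENT = {"AdamW": 4, "8-bit Adam": 1, "Adafactor": 4}


def state_elements(rows: int, cols: int) -> dict:
    """Number of optimizer-state elements each optimizer keeps for a rows x cols matrix."""
    return {
        "AdamW": 2 * rows * cols,
        "8-bit Adam": 2 * rows * cols,
        "Adafactor": rows + cols,
    }


def state_bytes(rows: int, cols: int) -> dict:
    """Approximate optimizer-state size in bytes for a rows x cols matrix."""
    return {name: count * BYTES_PER_ELEMENT[name]
            for name, count in state_elements(rows, cols).items()}


if __name__ == "__main__":
    # Hypothetical 768 x 3072 weight matrix, chosen only for illustration
    for name, size in state_bytes(768, 3072).items():
        print(f"{name:>10}: {size / 1024**2:.3f} MB")
```

For that shape the sketch gives 18 MB for AdamW, 4.5 MB for 8-bit Adam, and about 0.015 MB for Adafactor. 8-bit Adam keeps every element but at lower precision, while Adafactor keeps full precision but far fewer elements.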

How does Adafactor work?

 

Diagram showing the steps of Adafactor.

  • EMA: exponential moving average
  • Second moment: EMA of the squared gradients
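
The two definitions above, and the way the second moment scales a step, can be sketched in plain Python. This is a simplified, unfactored illustration rather than the full algorithm. It stores the whole second-moment matrix and uses a fixed decay, whereas the PyTorch and Hugging Face Adafactor implementations raise the decay toward 1 as training progresses. It also omits the update clipping and parameter-scaled step size that Adafactor adds on top. Function names and default values are hypothetical.

```python
import math


def ema(previous: float, new_value: float, decay: float) -> float:
    """Exponential moving average: keep `decay` of the old value, blend in the new one."""
    return decay * previous + (1.0 - decay) * new_value


def update_second_moment(second_moment, grads, decay=0.999, eps=1e-30):
    """Second moment: element-wise EMA of the squared gradients.

    `eps` is added to each squared gradient so the estimate stays above zero.
    """
    return [[ema(v, g * g + eps, decay) for v, g in zip(v_row, g_row)]
            for v_row, g_row in zip(second_moment, grads)]


def scaled_update(grads, second_moment, lr=0.01):
    """Parameter change: each gradient divided by the square root of its second moment."""
    return [[-lr * g / math.sqrt(v) for g, v in zip(g_row, v_row)]
            for g_row, v_row in zip(grads, second_moment)]


if __name__ == "__main__":
    grads = [[0.5, -1.0], [2.0, 0.0]]
    v = [[0.0, 0.0], [0.0, 0.0]]           # second moment starts at zero
    v = update_second_moment(v, grads, decay=0.9)
    print(v)
    print(scaled_update(grads, v))
```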

How does Adafactor save memory?

 

Diagram depicting the second moment matrix, column sum, and row sum.

  • Save memory by not storing the full second moment matrix
  • Instead, store the row sums and column sums of the matrix
  • Estimate the full matrix from the outer product of the row sums and column sums, divided by the sum of all entries
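
The factorization can be sketched in plain Python with two hypothetical helpers. This follows the standard Adafactor estimate, which divides the outer product of the row sums and column sums by the sum of all entries so that the estimate has the right overall scale. The estimate is exact for a rank-1 matrix and only approximate otherwise.

```python
def factor(second_moment):
    """Keep only the row sums and column sums of the second-moment matrix."""
    row_sums = [sum(row) for row in second_moment]
    col_sums = [sum(col) for col in zip(*second_moment)]
    return row_sums, col_sums


def reconstruct(row_sums, col_sums):
    """Estimate the full matrix: outer(row_sums, col_sums) / (sum of all entries)."""
    total = sum(row_sums)  # equals the sum of every entry in the original matrix
    return [[r * c / total for c in col_sums] for r in row_sums]


if __name__ == "__main__":
    v = [[1.0, 2.0], [2.0, 4.0]]    # rank 1: recovered exactly
    print(reconstruct(*factor(v)))
    w = [[1.0, 0.0], [0.0, 1.0]]    # not rank 1: only approximated
    print(reconstruct(*factor(w)))
```

For an m x n matrix, the factored form stores m + n numbers instead of m * n. In the example, the rank-1 matrix comes back unchanged, while the 2 x 2 identity is smoothed to [[0.5, 0.5], [0.5, 0.5]]. This shows what the approximation gives up.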

Trainer and Accelerator implementation

Chart comparing ease of use vs. ability to customize for Accelerator and Trainer.

Implement Adafactor with Trainer

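from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(output_dir="./results",
                                  evaluation_strategy="epoch",  # renamed eval_strategy in recent transformers versions
                                  optim="adafactor")  # use Adafactor instead of the default AdamW
trainer = Trainer(model=model, args=training_args,
                  train_dataset=train_dataset,
                  eval_dataset=validation_dataset,
                  compute_metrics=compute_metrics)
trainer.train()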
{'epoch': 1.0, 'eval_accuracy': 0.6, 'eval_f1': 0.5}

Implement Adafactor with Accelerator

# Assumes PyTorch 2.5 or higher
from torch.optim import Adafactor

optimizer = Adafactor(params=model.parameters(), lr=lr)
for batch in train_dataloader:
    inputs, targets = batch["input_ids"], batch["labels"]
    outputs = model(inputs, labels=targets)
    loss = outputs.loss
    accelerator.backward(loss)  # Accelerator handles the backward pass
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()
    print(f"Loss = {loss}")
Loss = 0.71

Inspect the optimizer state

  • Access the optimizer state directly from the optimizer
optimizer_state = optimizer.state.values()
  • Or access it through the trainer
optimizer_state = trainer.optimizer.state.values()
print(optimizer_state)
dict_values([{'step': tensor(3.),
              'exp_avg_sq_row': tensor([1.0000e-30, 1.0000e-30, 1.0000e-30,  ...]), 
              'exp_avg_sq_col': tensor([4.6147e-11, 5.5115e-11, 1.6338e-10, ...])}, ...])
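
In this output, exp_avg_sq_row holds one value per row of a weight matrix and exp_avg_sq_col holds one value per column. The sketch below shows how such vectors could be updated from one gradient in plain Python. The function name is hypothetical, and the sketch simplifies the real optimizer (fixed decay, no parameter update). It keeps only the core idea: an EMA of the row means and column means of the squared gradient plus a small epsilon. Hugging Face's Adafactor uses 1e-30 as the default for that epsilon, so a row whose gradient is zero ends up holding just 1e-30. This is likely why the printed row values above are 1.0000e-30.

```python
def update_factored_state(exp_avg_sq_row, exp_avg_sq_col, grad, decay=0.999, eps=1e-30):
    """One EMA update of factored second-moment statistics for a 2D gradient.

    exp_avg_sq_row has one entry per row and exp_avg_sq_col one entry per column,
    so the state for an m x n matrix has m + n numbers instead of m * n.
    """
    squared = [[g * g + eps for g in row] for row in grad]
    n_rows, n_cols = len(squared), len(squared[0])
    row_means = [sum(row) / n_cols for row in squared]
    col_means = [sum(col) / n_rows for col in zip(*squared)]
    new_row = [decay * old + (1.0 - decay) * new
               for old, new in zip(exp_avg_sq_row, row_means)]
    new_col = [decay * old + (1.0 - decay) * new
               for old, new in zip(exp_avg_sq_col, col_means)]
    return new_row, new_col


if __name__ == "__main__":
    grad = [[0.0, 0.0, 0.0],    # a row that received no gradient
            [1.0, -2.0, 0.5]]
    row, col = update_factored_state([0.0, 0.0], [0.0, 0.0, 0.0], grad, decay=0.0)
    print(row)  # first entry is about 1e-30, i.e. only the epsilon
    print(col)
```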

Compute memory usage of Adafactor

total_size_megabytes, total_num_elements = \
    compute_optimizer_size(trainer.optimizer.state.values())
print(f"Number of Adafactor parameters: {total_num_elements:,}")
print(f"Adafactor size: {total_size_megabytes:.0f} MB")
Number of Adafactor parameters: 178,712
Adafactor size: 1 MB
  • Compare to AdamW: Adafactor's state holds roughly 736 times fewer elements, about 1 MB instead of 502 MB
Number of AdamW parameters: 131,566,188
AdamW size: 502 MB
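
The slide calls compute_optimizer_size without showing it. Below is a minimal sketch of what such a helper might look like. It is hypothetical and not the course's code. It assumes the helper counts every tensor in the optimizer state dicts (any object with numel() and element_size(), such as a torch.Tensor) and reports megabytes as bytes / (1024 * 1024). The figures above fit that convention if each element takes 4 bytes (fp32). For AdamW, 131,566,188 * 4 bytes / (1024 * 1024) is about 502. For Adafactor, 178,712 * 4 bytes / (1024 * 1024) is about 0.68, which rounds to the 1 MB shown.

```python
def compute_optimizer_size(state_values):
    """Hypothetical helper: return (size in MB, number of elements) of an optimizer state.

    state_values is an iterable of per-parameter state dicts, e.g.
    trainer.optimizer.state.values(). Only tensor-like values (objects with
    numel() and element_size(), such as torch.Tensor) are counted; plain Python
    numbers like an int step counter are skipped.
    """
    total_bytes = 0
    total_num_elements = 0
    for state in state_values:            # one dict per parameter
        for value in state.values():      # e.g. step, exp_avg_sq_row, exp_avg_sq_col
            if hasattr(value, "numel") and hasattr(value, "element_size"):
                num_elements = value.numel()
                total_num_elements += num_elements
                total_bytes += num_elements * value.element_size()
    total_size_megabytes = total_bytes / 1024**2  # 1 MB taken as 1024 * 1024 bytes
    return total_size_megabytes, total_num_elements
```

Called as compute_optimizer_size(trainer.optimizer.state.values()), it returns the (megabytes, element count) pair that the slide unpacks.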

Let's practice!
