Gradient checkpointing and local SGD

Efficient AI Model Training with PyTorch

Dennis Lee

Data Engineer

Improving training efficiency

Icons representing memory efficiency, communication efficiency, and computational efficiency.

Gradient checkpointing improves memory efficiency; local SGD addresses communication efficiency.

What is gradient checkpointing?

  • Gradient checkpointing: reduce memory by selecting which activations to save
  • Example: compute A + B = C
    • First compute A, B, then compute C
    • A, B not needed for rest of forward pass
  • Should we save or remove A and B?
    • No gradient checkpointing: save A, B
    • Gradient checkpointing: remove A, B
    • Recompute A, B during backward pass
    • If B is expensive to recompute, save it (see the sketch below)

Graph illustrating gradient checkpointing with nodes and edges
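
In PyTorch, this strategy is available through torch.utils.checkpoint. Below is a minimal sketch; the block, layer sizes, and input shape are hypothetical placeholders, with the block's intermediate activations playing the role of A and B.

import torch
from torch.utils.checkpoint import checkpoint

# Hypothetical block whose intermediate activations play the role of A and B
block = torch.nn.Sequential(
    torch.nn.Linear(512, 512),
    torch.nn.ReLU(),
    torch.nn.Linear(512, 512),
)
x = torch.randn(8, 512, requires_grad=True)

# Activations inside block are freed after the forward pass and
# recomputed when backward() needs them, trading compute for memory
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()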


Trainer and Accelerator

Chart comparing ease of use vs. ability to customize for Accelerator and Trainer.


Gradient checkpointing with Trainer

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(output_dir="./results",
                                  evaluation_strategy="epoch",
                                  gradient_accumulation_steps=4,
                                  gradient_checkpointing=True)

trainer = Trainer(model=model,
                  args=training_args,
                  train_dataset=dataset["train"],
                  eval_dataset=dataset["validation"],
                  compute_metrics=compute_metrics)
trainer.train()
{'epoch': 1.0, 'eval_loss': 0.73, 'eval_accuracy': 0.03, 'eval_f1': 0.05}
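
With gradient_checkpointing=True, Trainer asks the model to discard activations during the forward pass and recompute them in the backward pass (for Hugging Face models this goes through the model's gradient_checkpointing_enable() method), lowering peak memory at the cost of some extra compute.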

From Trainer to Accelerator

Chart comparing ease of use vs. ability to customize for Accelerator and Trainer.


Gradient checkpointing with Accelerator

from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=2)
# Drop intermediate activations in the forward pass and recompute them
# during backward to reduce peak memory
model.gradient_checkpointing_enable()

for index, batch in enumerate(dataloader):
    with accelerator.accumulate(model):
        inputs, targets = batch["input_ids"], batch["labels"]
        outputs = model(inputs, labels=targets)
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
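
Note that this example combines two independent techniques: gradient accumulation spreads the effective batch over several forward/backward passes, while gradient checkpointing recomputes activations within each pass; either can be used on its own.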

Local SGD improves communication efficiency

Icons representing memory efficiency, communication efficiency, and computational efficiency.


What is local SGD?

Diagram showing how local SGD works by synchronizing gradients after a certain number of steps.

  • Each device computes gradients in parallel
  • Gradient synchronization: the driver node updates the model parameters on each device so all copies stay identical
  • Local SGD: reduce the frequency of gradient synchronization by letting each device take several local steps between syncs (see the sketch below)
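
As a rough sketch of the idea (not how Accelerate implements it), the loop below takes sync_steps local optimizer steps between synchronizations, then averages the model parameters across devices. Here sync_steps is a placeholder, and model, optimizer, and dataloader are assumed to be set up as in the earlier examples, with a torch.distributed process group already initialized.

import torch
import torch.distributed as dist

sync_steps = 8  # placeholder: number of local steps between synchronizations

for step, batch in enumerate(dataloader):
    outputs = model(batch["input_ids"], labels=batch["labels"])
    outputs.loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Communicate only every sync_steps steps: average the model
    # parameters across all devices instead of syncing every batch
    if (step + 1) % sync_steps == 0:
        with torch.no_grad():
            for param in model.parameters():
                dist.all_reduce(param, op=dist.ReduceOp.SUM)
                param /= dist.get_world_size()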

Local SGD with Accelerator

from accelerate.local_sgd import LocalSGD

with LocalSGD(accelerator=accelerator, model=model, local_sgd_steps=8, 
              enabled=True) as local_sgd:
    for index, batch in enumerate(dataloader):
        with accelerator.accumulate(model):
            inputs, targets = batch["input_ids"], batch["labels"]
            outputs = model(inputs, labels=targets)
            loss = outputs.loss
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            local_sgd.step()  # triggers synchronization every local_sgd_steps steps
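
Raising local_sgd_steps cuts communication further, but the per-device model copies drift apart for longer between synchronizations, which can slow or destabilize convergence; the best value depends on the workload and cluster.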

Let's practice!
