Batch updates in policy gradient

Deep Reinforcement Learning in Python

Timothée Carayol

Principal Machine Learning Engineer, Komment

Stepwise vs batch gradient updates

A big box representing an episode.


In the large box, a smaller box appears, representing step 1. Within it, another box with the text 'select action.'


In the step 1 box, another small box appears with the text 'iterate environment'


Underneath the step 1 box, another box with the labels 'calculate loss' and 'gradient descent'


An identical pair of boxes appears for the second step, with the same content.


Step 3 and step 4 appear as well.
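The stepwise scheme in the diagram can be sketched as a toy loop in plain Python. Nothing here is real A2C code: `toy_env_step` is a hypothetical stand-in for the environment, and the counter just shows that gradient descent runs once per step.

```python
def toy_env_step(step):
    """Hypothetical environment transition; the episode ends after 4 steps."""
    reward = 1.0
    done = step >= 3
    return reward, done

updates = 0
step = 0
done = False
while not done:
    reward, done = toy_env_step(step)  # iterate environment
    loss = -reward                     # placeholder per-step loss
    updates += 1                       # gradient descent on this single step
    step += 1

print(updates)  # 4 steps -> 4 gradient updates
```

With one update per step, a 4-step episode triggers 4 separate gradient descent passes.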


Batching the A2C / PPO updates

A large episode box; taking up half its area, another box labelled 'rollout 1'; within it, two empty boxes labelled 'step 1' and 'step 2'.


In the step 1 box, the labels 'select action' and 'iterate environment' appear.


Same for step 2.


Underneath the step 1 and step 2 boxes, appears a single 'calculate loss' label, and a single 'gradient descent' label.


The remaining half of the episode area is now taken up by another identical rollout box with two steps, labelled 'rollout 2'.
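With batching, the per-step losses are accumulated and a single update is made once the rollout is full. A minimal sketch with placeholder losses (no real networks; rollout length 2 over a 4-step toy episode):

```python
rollout_length = 2
losses = []
updates = 0

for step in range(4):                  # one 4-step toy episode
    losses.append(float(-step))        # placeholder per-step loss
    if len(losses) >= rollout_length:  # rollout full
        batch_loss = sum(losses) / len(losses)  # batch average loss
        updates += 1                   # single gradient descent pass
        losses = []                    # reinitialize for the next rollout

print(updates)  # 4 steps / rollout of 2 -> 2 updates
```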


The A2C training loop with batch updates

 

# Set rollout length
rollout_length = 10

# Initialize loss batches
actor_losses = torch.tensor([])
critic_losses = torch.tensor([])
  • Initialize loss batches
  • Iterate through episodes and steps as usual

 

for episode in range(10):
  state, info = env.reset()
  done = False
  while not done:
    action, action_log_prob = select_action(actor, state)
    next_state, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated
    actor_loss, critic_loss = calculate_losses(
        critic, action_log_prob, 
        reward, state, next_state, done)
    ...

    ...
    actor_losses = torch.cat((actor_losses, actor_loss))
    critic_losses = torch.cat((critic_losses, critic_loss))

    # If rollout is full, update the networks
    if len(actor_losses) >= rollout_length:
      actor_loss_batch = actor_losses.mean()
      critic_loss_batch = critic_losses.mean()

      actor_optimizer.zero_grad()
      actor_loss_batch.backward()
      actor_optimizer.step()
      critic_optimizer.zero_grad()
      critic_loss_batch.backward()
      critic_optimizer.step()

      # Reinitialize the loss batches
      actor_losses = torch.tensor([])
      critic_losses = torch.tensor([])

    state = next_state

 

  • Append step loss to the loss batches
  • When rollout is full:
    • Take the batch average loss with .mean()
    • Perform gradient descent
    • Reinitialize the batch losses

A2C / PPO with multiple agents

 

Two horizontal strips represent agent 1 and agent 2, who experience 4 and 3 episodes respectively, of varying lengths. Within each episode, step boxes are visible as in the previous slides. Under the two strips, three rollout boxes are visible, each covering an interval of 8 steps. Within each rollout box, the labels 'calculate loss' and 'gradient descent' are visible. At the top of the graph, a legend indicates "rollout length: 8 steps; number of agents: 2".
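The multi-agent diagram can be sketched the same way: each agent contributes one transition per environment iteration, so a shared rollout fills up twice as fast with 2 agents. All names here are hypothetical placeholders:

```python
num_agents = 2
rollout_length = 8
rollout = []
updates = 0

for iteration in range(12):             # 12 parallel environment iterations
    for agent in range(num_agents):     # each agent adds one transition
        rollout.append((agent, iteration))
    if len(rollout) >= rollout_length:  # rollout full: one shared update
        updates += 1
        rollout = []

print(updates)  # 24 transitions / 8 per rollout -> 3 updates
```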


Rollouts and minibatches

Two agent strips identical to the previous slide. Underneath, 3 rollout boxes are again visible, but their content has changed. Each now has, on top, a long box labelled 'shuffle'. Under that, each is divided lengthwise into 4 boxes labelled 'minibatch'; within each minibatch are a 'calculate loss' and a 'gradient descent' box. On top of the drawing, a legend indicates: 'Rollout length: 8 steps; minibatch size: 4 (2x2); number of agents: 2'.
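The shuffle-then-split step in the diagram can be sketched with the standard library; the rollout contents here are placeholder integers rather than real transitions:

```python
import random

random.seed(0)
rollout = list(range(8))     # 8 collected transitions (placeholders)
minibatch_size = 4

random.shuffle(rollout)      # the 'shuffle' box in the diagram
minibatches = [rollout[i:i + minibatch_size]
               for i in range(0, len(rollout), minibatch_size)]

# one 'calculate loss' + 'gradient descent' pass per minibatch
print(len(minibatches))  # 8 transitions / minibatch of 4 -> 2 minibatches
```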


PPO with multiple epochs

A drawing very similar to the one before, except the rollout boxes are now also split vertically into 4 areas: the top one is a 'shuffle' label; the second is a large box labelled 'epoch 1' containing 4 minibatches spread lengthwise; the third is a 'reshuffle' label; the last is a large box labelled 'epoch 2', also containing 4 minibatches. The legend says: 'Rollout length: 8 steps; minibatch size: 4 (2x2); number of agents: 2; number of epochs: 2'.
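Multiple epochs reuse the same rollout: it is reshuffled and re-split into minibatches once per epoch, multiplying the number of gradient steps. A toy sketch with placeholder data:

```python
import random

random.seed(0)
rollout = list(range(8))          # 8 collected transitions (placeholders)
minibatch_size = 4
num_epochs = 2
gradient_steps = 0

for epoch in range(num_epochs):
    random.shuffle(rollout)       # shuffle, then reshuffle for epoch 2
    for i in range(0, len(rollout), minibatch_size):
        minibatch = rollout[i:i + minibatch_size]
        gradient_steps += 1       # one update per minibatch

print(gradient_steps)  # 2 epochs x 2 minibatches -> 4 gradient steps
```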


Let's practice!

