Advantage Actor Critic

Deep Reinforcement Learning in Python

Timothée Carayol

Principal Machine Learning Engineer, Komment

Why actor critic?

 

  • REINFORCE limitations:

    • High variance
    • Poor sample efficiency
  • Actor Critic methods introduce a critic network, enabling Temporal Difference learning
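
To see what Temporal Difference learning buys here, compare the update targets (standard notation, not taken verbatim from the slides): REINFORCE must wait for the full Monte Carlo return, while Actor Critic bootstraps from the critic's value estimate after a single step.

$$G_t = \sum_{k=0}^{T-t-1} \gamma^k r_{t+k+1} \qquad \text{(Monte Carlo return used by REINFORCE)}$$

$$r_t + \gamma V_{\theta_c}(s_{t+1}) \qquad \text{(TD target, available at every step)}$$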

A large rectangle labelled 'agent'; and two smaller rectangles inside it, labelled respectively 'actor' and 'critic'.

Deep Reinforcement Learning in Python

The intuition behind Actor Critic methods

Students talking around a table, with books and pens scattered around.

 

  • Actor network:

    • Makes decisions
    • Cannot evaluate them
  • Critic network:

    • Provides feedback to actor at every step
Deep Reinforcement Learning in Python

The Critic network

 

  • Critic approximates the state value function

A representation of the critic network, with the state as input and the Value function as output; as a result it only has one output node.

  • Judges action $a_t$ based on the advantage or TD-error

 

import torch
import torch.nn as nn

class Critic(nn.Module):
    def __init__(self, state_size):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_size, 64)
        self.fc2 = nn.Linear(64, 1)

    def forward(self, state):
        x = torch.relu(self.fc1(torch.tensor(state)))
        value = self.fc2(x)
        return value

critic_network = Critic(8)
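
As a quick sanity check, the critic maps a state to a single scalar value estimate; the 8-dimensional state below is made up for illustration (e.g. a LunarLander-style observation):

state = [0.1, 0.0, -0.2, 0.05, 0.3, 0.0, 1.0, 0.0]  # hypothetical 8-dimensional state
value = critic_network(state)
print(value.shape)  # torch.Size([1]): one output node, the state value estimate
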
Deep Reinforcement Learning in Python

The Actor Critic dynamics

 

  • At every step:
    • Actor chooses action (same as policy network in REINFORCE)

On top: A large rectangle labelled 'agent'; and two smaller rectangles inside it, labelled respectively 'actor' and 'critic'. At the bottom: a separate rectangle labelled 'environment'.

Deep Reinforcement Learning in Python

The Actor Critic dynamics

 

  • At every step:
    • Actor chooses action (same as policy network in REINFORCE)
    • Critic observes reward and state

A red arrow labelled 'action' goes from the actor to the environment.

Deep Reinforcement Learning in Python

The Actor Critic dynamics

 

  • At every step:
    • Actor chooses action (same as policy network in REINFORCE)
    • Critic observes reward and state
    • Critic evaluates TD Error
    • Actor and Critic use TD Error to update weights

Two red arrows, labelled respectively 'State' and 'Reward', go from the environment to the Critic.

Deep Reinforcement Learning in Python

The Actor Critic dynamics

 

  • At every step:
    • Actor chooses action (same as policy network in REINFORCE)
    • Critic observes reward and state
    • Critic evaluates TD Error
    • Actor and Critic use TD Error to update weights
    • Updated Actor observes new state

An arrow labelled 'TD error' goes from Critic to Actor.

Deep Reinforcement Learning in Python

The Actor Critic dynamics

 

  • At every step:
    • Actor chooses action (same as policy network in REINFORCE)
    • Critic observes reward and state
    • Critic evaluates TD Error
    • Actor and Critic use TD Error to update weights
    • Updated Actor observes new state
  • ... start over

The State arrow now goes to the Actor as well.

Deep Reinforcement Learning in Python

The A2C losses

 

Critic

The critic loss function: use the squared TD error for the critic:

$$L_c(\theta_c) = \big( r_t + \gamma V_{\theta_c}(s_{t+1}) - V_{\theta_c}(s_t) \big)^2$$

  • Critic loss: squared TD error

 

Actor

The actor loss function: it can be shown that, at every step $t$, we can use the following loss for the actor:

$$L(\theta) = -\log \pi_\theta(a_t \mid s_t) \, \delta_t$$

where $\delta_t = r_t + \gamma V_{\theta_c}(s_{t+1}) - V_{\theta_c}(s_t)$ is the TD error, or advantage.

  • TD error captures critic rating
  • Increase probability of actions with positive TD error
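
For intuition, plug in made-up numbers (not from the course): with $r_t = 1$, $\gamma = 0.99$, $V_{\theta_c}(s_t) = 1.5$, $V_{\theta_c}(s_{t+1}) = 2.0$ and $\log \pi_\theta(a_t \mid s_t) = -0.5$:

$$\delta_t = 1 + 0.99 \times 2.0 - 1.5 = 1.48, \qquad L_c(\theta_c) = 1.48^2 \approx 2.19, \qquad L(\theta) = -(-0.5) \times 1.48 = 0.74$$

Because $\delta_t > 0$, minimising $L(\theta)$ increases $\log \pi_\theta(a_t \mid s_t)$, making that action more likely next time.
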
Deep Reinforcement Learning in Python

Calculating the losses

 

def calculate_losses(critic_network, action_log_prob,
                     reward, state, next_state, done):
    # Critic provides the state value estimates
    value = critic_network(state)
    next_value = critic_network(next_state)
    td_target = reward + gamma * next_value * (1 - done)
    td_error = td_target - value
    # Apply formulas for actor and critic losses
    actor_loss = -action_log_prob * td_error.detach()
    critic_loss = td_error ** 2
    return actor_loss, critic_loss

 

 

  • Calculate TD-Error
  • Calculate actor loss
    • Use .detach() to stop gradient propagation to critic weights
  • Calculate critic loss
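
A toy illustration of what .detach() achieves (made-up scalar tensors, not part of the course code):

import torch

log_prob = torch.tensor(-0.5, requires_grad=True)       # stands in for the actor's output
critic_weight = torch.tensor(1.0, requires_grad=True)   # stands in for a critic parameter
td_error = 2.0 - critic_weight                           # depends on the critic

actor_loss = -log_prob * td_error.detach()
actor_loss.backward()
print(critic_weight.grad)  # None: the actor loss sends no gradient into the critic
print(log_prob.grad)       # tensor(-1.): gradient reaches the actor only
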
Deep Reinforcement Learning in Python

The Actor Critic training loop

for episode in range(10):
    state, info = env.reset()
    done = False
    while not done:
        # Select action
        action, action_log_prob = select_action(actor, state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        # Calculate losses
        actor_loss, critic_loss = calculate_losses(critic, action_log_prob,
                                                   reward, state, next_state, done)
        # Update actor
        actor_optimizer.zero_grad(); actor_loss.backward(); actor_optimizer.step()
        # Update critic
        critic_optimizer.zero_grad(); critic_loss.backward(); critic_optimizer.step()
        state = next_state
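
select_action is not defined on this slide; it is assumed to behave like the REINFORCE policy sampling step. A minimal sketch of a compatible actor network and select_action (layer sizes and names are assumptions, not the course's exact code):

import torch
import torch.nn as nn

class Actor(nn.Module):
    def __init__(self, state_size, action_size):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_size, 64)
        self.fc2 = nn.Linear(64, action_size)

    def forward(self, state):
        x = torch.relu(self.fc1(torch.tensor(state)))
        return torch.softmax(self.fc2(x), dim=-1)

def select_action(actor, state):
    # Sample an action from the current policy and keep its log probability for the actor loss
    action_probs = actor(state)
    dist = torch.distributions.Categorical(probs=action_probs)
    action = dist.sample()
    return action.item(), dist.log_prob(action)
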
Deep Reinforcement Learning in Python

Let's practice!

Deep Reinforcement Learning in Python
