Double DQN

Deep Reinforcement Learning in Python

Timothée Carayol

Principal Machine Learning Engineer, Komment

Double Q-learning

  • Q-learning overestimates Q-values, compromising learning efficiency
  • This overestimation is caused by maximization bias
  • Double Q-learning removes this bias by decoupling action selection from value estimation (see the numerical sketch below)

Two Q-tables (illustration from the course Reinforcement Learning with Gymnasium using Python); double Q-learning uses them in alternation
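A minimal numerical sketch (not from the course) of maximization bias and how a double estimator avoids it: if the true value of every action is 0 but our estimates are noisy, taking the max over one set of estimates is biased upward, while selecting the action with one set of estimates and evaluating it with an independent set is not. All names and numbers below are purely illustrative.

import numpy as np

rng = np.random.default_rng(0)
n_actions, n_trials = 10, 10_000

# Two independent noisy estimates of action values whose true value is 0
q_a = rng.normal(0.0, 1.0, size=(n_trials, n_actions))
q_b = rng.normal(0.0, 1.0, size=(n_trials, n_actions))

single_estimate = q_a.max(axis=1)                         # max over one estimate: biased upward
best_actions = q_a.argmax(axis=1)                         # select the action with estimate A...
double_estimate = q_b[np.arange(n_trials), best_actions]  # ...evaluate it with estimate B: unbiased

print(f"Single estimator mean: {single_estimate.mean():.3f}")  # roughly +1.5
print(f"Double estimator mean: {double_estimate.mean():.3f}")  # roughly 0.0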

Deep Reinforcement Learning in Python

The idea behind DDQN

  • Start from complete DQN (with fixed Q-targets)
  • In DQN TD target:
    • Action selection: target network
    • Value estimation: target network
  • In DDQN TD target:
    • Action selection: online network
    • Value estimation: target network
  • Not exactly double Q-learning (no alternating Q-networks)
  • Most of the benefit, with minimal change

Bellman error (DQN with fixed Q-targets): Q_online(s_t, a_t) - (r_{t+1} + gamma * max_a Q_target(s_{t+1}, a))

Bellman error (DDQN with fixed Q-targets): Q_online(s_t, a_t) - (r_{t+1} + gamma * Q_target(s_{t+1}, ã)), where ã = argmax_a Q_online(s_{t+1}, a)
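Writing y_t for the TD target, the same comparison in standard notation makes the single difference explicit, namely which network selects the greedy action:

\[
y_t^{\mathrm{DQN}} = r_{t+1} + \gamma \max_a Q_{\mathrm{target}}(s_{t+1}, a),
\qquad
y_t^{\mathrm{DDQN}} = r_{t+1} + \gamma \, Q_{\mathrm{target}}\!\bigl(s_{t+1}, \operatorname*{arg\,max}_a Q_{\mathrm{online}}(s_{t+1}, a)\bigr)
\]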

Deep Reinforcement Learning in Python

Double DQN implementation

DQN:

... # instantiate online and target networks
q_values = (online_network(states)
            .gather(1, actions).squeeze(1))

with torch.no_grad():
    # Both action selection and value estimation use the target network
    next_actions = (target_network(next_states)
                    .argmax(1).unsqueeze(1))
    next_q_values = (target_network(next_states)
                     .gather(1, next_actions).squeeze(1))
    target_q_values = (rewards
                       + gamma * next_q_values * (1 - dones))

loss = torch.nn.MSELoss()(q_values, target_q_values)
... # gradient descent
... # target network update
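Note that for DQN, selecting the greedy action and evaluating it with the same target network is equivalent to calling target_network(next_states).amax(1) directly; the two-step form above simply mirrors the DDQN code below, where only the selection step changes.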

DDQN:

... # instantiate online and target networks
q_values = (online_network(states)
            .gather(1, actions).squeeze(1))

with torch.no_grad():
    # Action selection: online network; value estimation: target network
    next_actions = (online_network(next_states)
                    .argmax(1).unsqueeze(1))
    next_q_values = (target_network(next_states)
                     .gather(1, next_actions).squeeze(1))
    target_q_values = (rewards
                       + gamma * next_q_values * (1 - dones))

loss = torch.nn.MSELoss()(q_values, target_q_values)
... # gradient descent
... # target network update
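For reference, here is a minimal, self-contained sketch of one full DDQN training step under illustrative assumptions: a small fully connected network, a random batch standing in for a replay-buffer sample, and a periodic hard target-network update. Names such as obs_dim, n_actions, and make_net are not from the course.

import torch
import torch.nn as nn

obs_dim, n_actions, batch_size, gamma = 4, 2, 32, 0.99

def make_net():
    return nn.Sequential(nn.Linear(obs_dim, 64), nn.ReLU(), nn.Linear(64, n_actions))

online_network = make_net()
target_network = make_net()
target_network.load_state_dict(online_network.state_dict())
optimizer = torch.optim.Adam(online_network.parameters(), lr=1e-3)

# Fake batch of transitions standing in for a replay buffer sample
states = torch.randn(batch_size, obs_dim)
actions = torch.randint(n_actions, (batch_size, 1))
rewards = torch.randn(batch_size)
next_states = torch.randn(batch_size, obs_dim)
dones = torch.zeros(batch_size)

q_values = online_network(states).gather(1, actions).squeeze(1)
with torch.no_grad():
    # Select the greedy action with the online network...
    next_actions = online_network(next_states).argmax(1, keepdim=True)
    # ...and evaluate it with the target network
    next_q_values = target_network(next_states).gather(1, next_actions).squeeze(1)
    target_q_values = rewards + gamma * next_q_values * (1 - dones)

loss = nn.MSELoss()(q_values, target_q_values)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Periodic hard update of the target network (one common choice)
target_network.load_state_dict(online_network.state_dict())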
Deep Reinforcement Learning in Python

DDQN performance

 

  • Comparison of DDQN, DQN, and human performance on Atari games
  • DDQN achieves higher scores than the original DQN
  • This advantage may not hold in every environment, so try both

Bar chart showing DQN almost matching human performance for the median game, and achieving superhuman scores on average; and DDQN beating both human performance and DQN on median and average.

1 https://arxiv.org/abs/2303.11634
Deep Reinforcement Learning in Python

Let's practice!

Deep Reinforcement Learning in Python
