Double DQN

Deep Reinforcement Learning in Python

Timothée Carayol

Principal Machine Learning Engineer, Komment

Double Q-learning

  • Q-learning overschat Q-waarden, wat leren minder efficiënt maakt
  • Oorzaak: maximisatiebias
  • Double Q-learning verwijdert bias door actiekeuze en waardeschatting te scheiden

Twee Q-tabellen (illustratie uit de cursus Reinforcement Learning with Gymnasium using Python); double Q-learning gebruikt ze afwisselend

Deep Reinforcement Learning in Python

Het idee achter DDQN

  • Begin met complete DQN (met vaste Q-targets)
  • In DQN TD-doel:
    • Actiekeuze: targetnetwerk
    • Waardeschatting: targetnetwerk
  • In DDQN TD-doel:
    • Actiekeuze: online netwerk
    • Waardeschatting: targetnetwerk
  • Niet precies double Q-learning (geen alternerende Q-netwerken)
  • Wel meeste voordeel, met minimale aanpassing

Bellman-fout (DQN met vaste Q-targets): Q_online(s_t, a_t) - (r_t+1 + gamma max(Q_target(s_t+1, a)))

Bellman-fout (DDQN met vaste Q-targets): Q_online(s_t, a_t) - (r_t+1 + gamma Q_target(s_t+1, tilde a)) met tilde a = argmax_a(Q_online(s_t+1, a))

Deep Reinforcement Learning in Python

Double DQN-implementatie

DQN:

... # instantiate online and target networks
q_values = (online_network(states)
            .gather(1, actions).squeeze(1))

with torch.no_grad():
# # next_q_values = (target_network(next_states) .amax(1))
target_q_values = (rewards + gamma * next_q_values * (1 - dones))
loss = torch.nn.MSELoss()(q_values, target_q_values) ... # gradient descent ... # target network update

DDQN:

... # instantiate online and target networks
q_values = (online_network(states)
            .gather(1, actions).squeeze(1))

with torch.no_grad():

target_q_values = (rewards + gamma * next_q_values * (1 - dones))
loss = torch.nn.MSELoss()(q_values, target_q_values) ... # gradient descent ... # target network update
Deep Reinforcement Learning in Python

Double DQN-implementatie

DQN:

... # instantiate online and target networks
q_values = (online_network(states)
            .gather(1, actions).squeeze(1))

with torch.no_grad():
next_actions = (target_network(next_states) .argmax(1).unsqueeze(1))
next_q_values = (target_network(next_states) .gather(1, next_actions).squeeze(1))
target_q_values = (rewards + gamma * next_q_values * (1 - dones))
loss = torch.nn.MSELoss()(q_values, target_q_values) ... # gradient descent ... # target network update

DDQN:

... # instantiate online and target networks
q_values = (online_network(states)
            .gather(1, actions).squeeze(1))

with torch.no_grad():
next_actions = (online_network(next_states) .argmax(1).unsqueeze(1))
next_q_values = (target_network(next_states) .gather(1, next_actions).squeeze(1))
target_q_values = (rewards + gamma * next_q_values * (1 - dones))
loss = torch.nn.MSELoss()(q_values, target_q_values) ... # gradient descent ... # target network update
Deep Reinforcement Learning in Python

DDQN-prestaties

 

  • Vergelijk DDQN, DQN en mensen op Atari-games
  • DDQN: hogere scores dan originele DQN
  • Niet altijd zo -> probeer beide

Staafdiagram dat laat zien dat DQN bijna menselijke prestaties evenaart voor het mediane spel en gemiddeld supermenselijke scores haalt; en DDQN dat zowel mensen als DQN verslaat op mediaan en gemiddeld.

1 https://arxiv.org/abs/2303.11634
Deep Reinforcement Learning in Python

Laten we oefenen!

Deep Reinforcement Learning in Python

Preparing Video For Download...