Double DQN

Deep Reinforcement Learning dengan Python

Timothée Carayol

Principal Machine Learning Engineer, Komment

Double Q-learning

  • Q-learning melebihkan Q-value, menurunkan efisiensi pembelajaran
  • Penyebabnya: bias maksimisasi
  • Double Q-Learning menghilangkan bias dengan memisahkan pemilihan aksi dan estimasi nilai

Dua tabel Q (ilustrasi dari kursus Reinforcement Learning with Gymnasium using Python); double Q-learning menggunakan keduanya secara bergantian

Deep Reinforcement Learning dengan Python

Gagasan di balik DDQN

  • Mulai dari DQN lengkap (dengan Q-target tetap)
  • Pada target TD DQN:
    • Pemilihan aksi: network target
    • Estimasi nilai: network target
  • Pada target TD DDQN:
    • Pemilihan aksi: network online
    • Estimasi nilai: network target
  • Tidak persis double Q-learning (tidak ada pergiliran Q-network)
  • Manfaat utama, perubahan minimal

Galat Bellman (DQN dengan Q-target tetap): Q_online(s_t, a_t) - (r_t+1 + gamma max(Q_target(s_t+1, a)))

Galat Bellman (DDQN dengan Q-target tetap): Q_online(s_t, a_t) - (r_t+1 + gamma Q_target(s_t+1, ∼a)) dengan ∼a = argmax_a(Q_online(s_t+1, a))

Deep Reinforcement Learning dengan Python

Implementasi Double DQN

DQN:

... # instantiate online and target networks
q_values = (online_network(states)
            .gather(1, actions).squeeze(1))

with torch.no_grad():
# # next_q_values = (target_network(next_states) .amax(1))
target_q_values = (rewards + gamma * next_q_values * (1 - dones))
loss = torch.nn.MSELoss()(q_values, target_q_values) ... # gradient descent ... # target network update

DDQN:

... # instantiate online and target networks
q_values = (online_network(states)
            .gather(1, actions).squeeze(1))

with torch.no_grad():

target_q_values = (rewards + gamma * next_q_values * (1 - dones))
loss = torch.nn.MSELoss()(q_values, target_q_values) ... # gradient descent ... # target network update
Deep Reinforcement Learning dengan Python

Implementasi Double DQN

DQN:

... # instantiate online and target networks
q_values = (online_network(states)
            .gather(1, actions).squeeze(1))

with torch.no_grad():
next_actions = (target_network(next_states) .argmax(1).unsqueeze(1))
next_q_values = (target_network(next_states) .gather(1, next_actions).squeeze(1))
target_q_values = (rewards + gamma * next_q_values * (1 - dones))
loss = torch.nn.MSELoss()(q_values, target_q_values) ... # gradient descent ... # target network update

DDQN:

... # instantiate online and target networks
q_values = (online_network(states)
            .gather(1, actions).squeeze(1))

with torch.no_grad():
next_actions = (online_network(next_states) .argmax(1).unsqueeze(1))
next_q_values = (target_network(next_states) .gather(1, next_actions).squeeze(1))
target_q_values = (rewards + gamma * next_q_values * (1 - dones))
loss = torch.nn.MSELoss()(q_values, target_q_values) ... # gradient descent ... # target network update
Deep Reinforcement Learning dengan Python

Kinerja DDQN

 

  • Bandingkan kinerja DDQN, DQN, dan pemain manusia di game Atari
  • DDQN: skor lebih tinggi daripada DQN asli
  • Tidak selalu benar -> coba keduanya

Bagan batang yang menunjukkan DQN hampir menyamai performa manusia pada gim median, dan mencapai skor supermanusia rata-rata; serta DDQN mengungguli manusia dan DQN pada median dan rata-rata.

1 https://arxiv.org/abs/2303.11634
Deep Reinforcement Learning dengan Python

Ayo berlatih!

Deep Reinforcement Learning dengan Python

Preparing Video For Download...