Çift DQN

Python ile Deep Reinforcement Learning

Timothée Carayol

Principal Machine Learning Engineer, Komment

Çift Q-öğrenme

  • Q-öğrenme Q-değerlerini fazla tahmin eder, öğrenme verimini düşürür
  • Nedeni: maksimize etme yanlılığı
  • Çift Q-öğrenme, eylem seçimi ile değer tahminini ayırarak yanlılığı giderir

İki Q-tablosu (Reinforcement Learning with Gymnasium using Python kursundan görsel); çift Q-öğrenme bunları dönüşümlü kullanır

Python ile Deep Reinforcement Learning

DDQN’in ardındaki fikir

  • Tüm DQN’den başlayın (sabit Q-hedefleriyle)
  • DQN’de TD hedefi:
    • Eylem seçimi: hedef ağ
    • Değer tahmini: hedef ağ
  • DDQN’de TD hedefi:
    • Eylem seçimi: çevrim içi ağ
    • Değer tahmini: hedef ağ
  • Tam olarak çift Q-öğrenme değil (sırayla Q-ağları yok)
  • Az değişiklikle büyük kazanım

Bellman Hatası (sabit Q-hedefli DQN): Q_online(s_t, a_t) - (r_t+1 + gamma max(Q_target(s_t+1, a)))

Bellman Hatası (sabit Q-hedefli DDQN): Q_online(s_t, a_t) - (r_t+1 + gamma Q_target(s_t+1, tilde a)) burada tilde a = argmax_a(Q_online(s_t+1, a))

Python ile Deep Reinforcement Learning

Çift DQN uygulaması

DQN:

... # instantiate online and target networks
q_values = (online_network(states)
            .gather(1, actions).squeeze(1))

with torch.no_grad():
# # next_q_values = (target_network(next_states) .amax(1))
target_q_values = (rewards + gamma * next_q_values * (1 - dones))
loss = torch.nn.MSELoss()(q_values, target_q_values) ... # gradient descent ... # target network update

DDQN:

... # instantiate online and target networks
q_values = (online_network(states)
            .gather(1, actions).squeeze(1))

with torch.no_grad():

target_q_values = (rewards + gamma * next_q_values * (1 - dones))
loss = torch.nn.MSELoss()(q_values, target_q_values) ... # gradient descent ... # target network update
Python ile Deep Reinforcement Learning

Çift DQN uygulaması

DQN:

... # instantiate online and target networks
q_values = (online_network(states)
            .gather(1, actions).squeeze(1))

with torch.no_grad():
next_actions = (target_network(next_states) .argmax(1).unsqueeze(1))
next_q_values = (target_network(next_states) .gather(1, next_actions).squeeze(1))
target_q_values = (rewards + gamma * next_q_values * (1 - dones))
loss = torch.nn.MSELoss()(q_values, target_q_values) ... # gradient descent ... # target network update

DDQN:

... # instantiate online and target networks
q_values = (online_network(states)
            .gather(1, actions).squeeze(1))

with torch.no_grad():
next_actions = (online_network(next_states) .argmax(1).unsqueeze(1))
next_q_values = (target_network(next_states) .gather(1, next_actions).squeeze(1))
target_q_values = (rewards + gamma * next_q_values * (1 - dones))
loss = torch.nn.MSELoss()(q_values, target_q_values) ... # gradient descent ... # target network update
Python ile Deep Reinforcement Learning

DDQN performansı

 

  • Atari oyunlarında DDQN, DQN ve insan performansını karşılaştırın
  • DDQN: özgün DQN’den daha yüksek skorlar
  • Her zaman böyle olmayabilir → ikisini de deneyin

Çubuk grafiği: DQN, medyan oyunda neredeyse insanı yakalıyor ve ortalamada insanüstü; DDQN ise medyan ve ortalamada hem insanı hem DQN’i geçiyor.

1 https://arxiv.org/abs/2303.11634
Python ile Deep Reinforcement Learning

Hadi pratik yapalım!

Python ile Deep Reinforcement Learning

Preparing Video For Download...