Algoritme DQN lengkap

Deep Reinforcement Learning dengan Python

Timothée Carayol

Principal Machine Learning Engineer, Komment

Algoritme DQN

 

 

  • Kita mempelajari DQN dengan Experience Replay
  • Hampir sama dengan DQN versi publikasi awal (2015)
  • Masih kurang dua komponen:
    • Epsilon-greediness -> lebih banyak eksplorasi
    • Fixed Q-targets -> pembelajaran lebih stabil

Seorang petualang berdiri, huruf Yunani epsilon di sisinya

Huruf kapital Q, membeku dalam balok es

Deep Reinforcement Learning dengan Python

Epsilon-greediness dalam algoritme DQN

  • Terapkan Decayed Epsilon-greediness di select_action()
def select_action(q_values, step, start, end, decay):

# Hitung nilai ambang untuk langkah ini epsilon = ( end + (start-end) * math.exp(-step / decay))
# Ambil angka acak antara 0 dan 1 sample = random.random()
if sample < epsilon: # Kembalikan indeks aksi acak return random.choice(range(len(q_values)))
# Kembalikan indeks aksi dengan Q-value tertinggi return torch.argmax(q_values).item()
  • $\varepsilon = end + (start-end) \cdot e^{-\frac{step}{decay}}$
  • Ambil aksi acak dengan probabilitas $\varepsilon$
  • Ambil aksi bernilai tertinggi dengan probabilitas $1 - \varepsilon$

Plot yang merepresentasikan jadwal peluruhan Epsilon untuk berbagai nilai parameter decay.

Deep Reinforcement Learning dengan Python

Fixed Q-targets

 

  • Dalam Bellman Error:
    • Q-Network dipakai untuk Q-value dan perhitungan TD-target
    • Tidak stabil karena target berubah-ubah

 

  • Perkenalkan target network untuk menstabilkan target

 

Bellman Error: (r_t+1 + gamma max(Q(s_t+1, a))) - Q(s_t, a_t)

(r_t+1 + gamma max(Q_target(s_t+1, a))) - Q_online(s_t, a_t)

Deep Reinforcement Learning dengan Python

Menerapkan fixed Q-targets

online_network = QNetwork(state_size, action_size)
target_network = QNetwork(state_size, action_size)

target_network.load_state_dict( online_network.state_dict())
def update_target_network( target_network, online_network, tau):
target_net_state_dict = target_network.state_dict() online_net_state_dict = online_network.state_dict() for key in online_net_state_dict:
target_net_state_dict[key] = ( online_net_state_dict[key] * tau + target_net_state_dict[key] * (1 - tau))
target_network.load_state_dict( target_net_state_dict)
return None
  • Awalnya Online Network = Target Network
  • State dict suatu network berisi semua bobot: representasi kamus status network, dengan entri fc1.weight, fc1.bias, dan fc2.weight; nilai tiap entri adalah tensor.
  • Tiap langkah, setiap bobot Target Network makin mendekati Online Network
Deep Reinforcement Learning dengan Python

Perhitungan loss dengan fixed Q-targets

# Di inner loop, setelah pemilihan aksi
if len(replay_buffer) >= batch_size:
  states, actions, rewards, next_states, dones = 
      replay_buffer.sample(64)

q_values = (online_network(states) .gather(1, actions).squeeze(1))
with torch.no_grad():
next_q_values = ( target_network(next_states).amax(1)) target_q_values = ( rewards + gamma * next_q_values * (1 - dones))
loss = torch.nn.MSELoss()(target_q_values, q_values) optimizer.zero_grad() loss.backward() optimizer.step()
update_target_network( target_network, online_network, tau)

 

  • Q-value pakai online_network
  • Target Q-value pakai target_network
  • Gunakan torch.no_grad() untuk menonaktifkan pelacakan gradien pada target Q-value
  • Tetap gunakan Mean Squared Bellman Error untuk menghitung loss
  • Gunakan update_target_network() untuk memperbarui target_network secara bertahap
Deep Reinforcement Learning dengan Python

Ayo berlatih!

Deep Reinforcement Learning dengan Python

Preparing Video For Download...