Het complete DQN-algoritme

Deep Reinforcement Learning in Python

Timothée Carayol

Principal Machine Learning Engineer, Komment

Het DQN-algoritme

 

 

  • We hebben DQN met Experience Replay behandeld
  • Dicht bij het oorspronkelijke DQN (2015)
  • Er missen nog twee onderdelen:
    • Epsilon-greediness -> meer exploratie
    • Fixed Q-targets -> stabieler leren

Een avonturier staat, de Griekse letter epsilon aan haar zijde

De hoofdletter Q, bevroren in een blok ijs

Deep Reinforcement Learning in Python

Epsilon-greediness in het DQN-algoritme

  • Implementeer afnemende epsilon-greediness in select_action()
def select_action(q_values, step, start, end, decay):

# Bereken de drempel voor deze stap epsilon = ( end + (start-end) * math.exp(-step / decay))
# Trek een willekeurig getal tussen 0 en 1 sample = random.random()
if sample < epsilon: # Geef een willekeurige actie-index return random.choice(range(len(q_values)))
# Geef de actie-index met hoogste Q-waarde return torch.argmax(q_values).item()
  • $\varepsilon = end + (start-end) \cdot e^{-\frac{step}{decay}}$
  • Neem een willekeurige actie met kans $\varepsilon$
  • Neem de hoogste waarde-actie met kans $1 - \varepsilon$

Een grafiek van het epsilon-vervalschema voor verschillende decay-waarden.

Deep Reinforcement Learning in Python

Fixed Q-targets

 

  • In de Bellman-error:
    • Q-netwerk in zowel Q-waarde- als TD-targetberekening
    • Instabiliteit door verschuivend target

 

  • Voeg een targetnetwerk toe om het target te stabiliseren

 

De Bellman-error: (r_t+1 + gamma max(Q(s_t+1, a))) - Q(s_t, a_t)

(r_t+1 + gamma max(Q_target(s_t+1, a))) - Q_online(s_t, a_t)

Deep Reinforcement Learning in Python

Fixed Q-targets implementeren

online_network = QNetwork(state_size, action_size)
target_network = QNetwork(state_size, action_size)

target_network.load_state_dict( online_network.state_dict())
def update_target_network( target_network, online_network, tau):
target_net_state_dict = target_network.state_dict() online_net_state_dict = online_network.state_dict() for key in online_net_state_dict:
target_net_state_dict[key] = ( online_net_state_dict[key] * tau + target_net_state_dict[key] * (1 - tau))
target_network.load_state_dict( target_net_state_dict)
return None
  • Initieel: Online Network = Target Network
  • De state dict van een netwerk bevat alle gewichten: de representatie van een state dictionary van een netwerk, met items voor fc1.weight, fc1.bias en fc2.weight; de waarde voor elk item is een tensor.
  • Elke stap komen alle gewichten van Target Network iets dichter bij Online Network
Deep Reinforcement Learning in Python

Loss-berekening met fixed Q-targets

# In de inner loop, na actieselectie
if len(replay_buffer) >= batch_size:
  states, actions, rewards, next_states, dones = 
      replay_buffer.sample(64)

q_values = (online_network(states) .gather(1, actions).squeeze(1))
with torch.no_grad():
next_q_values = ( target_network(next_states).amax(1)) target_q_values = ( rewards + gamma * next_q_values * (1 - dones))
loss = torch.nn.MSELoss()(target_q_values, q_values) optimizer.zero_grad() loss.backward() optimizer.step()
update_target_network( target_network, online_network, tau)

 

  • Q-waarden via online_network
  • Target Q-waarden via target_network
  • Gebruik torch.no_grad() om gradients uit te schakelen voor target Q-waarden
  • Nog steeds Mean Squared Bellman Error voor de loss
  • Gebruik update_target_network() om target_network langzaam bij te werken
Deep Reinforcement Learning in Python

Laten we oefenen!

Deep Reinforcement Learning in Python

Preparing Video For Download...