The complete DQN algorithm

Deep Reinforcement Learning in Python

Timothée Carayol

Principal Machine Learning Engineer, Komment

The DQN algorithm

 

 

  • We studied DQN with Experience Replay
  • This is close to DQN as first published (2015)
  • We are still missing two components:
    • Epsilon-greediness -> more exploration
    • Fixed Q-targets -> more stable learning

An adventurer stands, the greek letter epsilon at her side

The capital letter Q, frozen in a block of ice


Epsilon-greediness in the DQN algorithm

  • Implement decayed epsilon-greediness in select_action()

import math
import random

import torch

def select_action(q_values, step, start, end, decay):
    # Calculate the exploration threshold for this step
    epsilon = end + (start - end) * math.exp(-step / decay)
    # Draw a random number between 0 and 1
    sample = random.random()
    if sample < epsilon:
        # Explore: return a random action index
        return random.choice(range(len(q_values)))
    # Exploit: return the action index with the highest Q-value
    return torch.argmax(q_values).item()
  • $\varepsilon = end + (start-end) \cdot e^{-\frac{step}{decay}}$
  • Take random action with probability $\varepsilon$
  • Take highest value action with probability $1 - \varepsilon$

A plot representing the Epsilon decay schedule for different values of the decay parameter.
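To get a feel for the schedule the plot shows, here is a minimal sketch that prints $\varepsilon$ at a few steps for different decay values (all parameter values are illustrative assumptions, not the plot's exact settings):

import math

start, end = 1.0, 0.05                      # assumed exploration bounds
for decay in (100, 1000, 5000):             # assumed decay values
    epsilons = [end + (start - end) * math.exp(-step / decay)
                for step in (0, 1000, 5000)]
    print(decay, [round(eps, 3) for eps in epsilons])

A larger decay value keeps $\varepsilon$ (and hence exploration) high for longer.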


Fixed Q-targets

 

  • In the Bellman error:
    • The Q-network appears in both the Q-value and the TD-target calculations
    • This shifting target is a source of instability

 

  • Introduce a target network to stabilize the target

 

The Bellman error: $\left(r_{t+1} + \gamma \max_a Q(s_{t+1}, a)\right) - Q(s_t, a_t)$

With fixed Q-targets: $\left(r_{t+1} + \gamma \max_a Q_{\text{target}}(s_{t+1}, a)\right) - Q_{\text{online}}(s_t, a_t)$
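Squaring and averaging this error over sampled transitions gives the mean squared Bellman error loss minimized later in this section, with gradients flowing only through $Q_{\text{online}}$:

$$L = \mathbb{E}\left[\left(r_{t+1} + \gamma \max_a Q_{\text{target}}(s_{t+1}, a) - Q_{\text{online}}(s_t, a_t)\right)^2\right]$$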


Implementing fixed Q-targets

online_network = QNetwork(state_size, action_size)
target_network = QNetwork(state_size, action_size)
# Start with identical weights in both networks
target_network.load_state_dict(online_network.state_dict())

def update_target_network(target_network, online_network, tau):
    target_net_state_dict = target_network.state_dict()
    online_net_state_dict = online_network.state_dict()
    for key in online_net_state_dict:
        # Soft update: move each target weight a fraction tau
        # of the way toward the corresponding online weight
        target_net_state_dict[key] = (
            online_net_state_dict[key] * tau
            + target_net_state_dict[key] * (1 - tau))
    target_network.load_state_dict(target_net_state_dict)
  • Initially, online network = target network
  • A network's state_dict contains all of its weights: for example, entries for fc1.weight, fc1.bias, and fc2.weight, where each value is a tensor
  • Each step, every weight of the target network moves a little closer to the online network, as in the usage sketch below
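A minimal usage sketch (the tau value is an illustrative assumption; note that tau = 1.0 would reduce the soft update to a hard copy):

tau = 0.005  # assumption: a typical small value for slow tracking

# After each gradient step, nudge the target network toward the online one
update_target_network(target_network, online_network, tau)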

Loss calculation with fixed Q-targets

# In the inner loop, after action selection
if len(replay_buffer) >= batch_size:
    states, actions, rewards, next_states, dones = (
        replay_buffer.sample(batch_size))

    # Q-values for the actions actually taken, from the online network
    q_values = (online_network(states)
                .gather(1, actions).squeeze(1))
    # Target Q-values come from the target network, without gradient tracking
    with torch.no_grad():
        next_q_values = target_network(next_states).amax(1)
        target_q_values = (
            rewards + gamma * next_q_values * (1 - dones))

    loss = torch.nn.MSELoss()(q_values, target_q_values)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    update_target_network(target_network, online_network, tau)

 

  • Q-values use online_network
  • Target Q-values use target_network
  • Use torch.no_grad() to disable gradient tracking for target Q-values
  • Still use the mean squared Bellman error for the loss calculation
  • Use update_target_network() to slowly update target_network (see the end-to-end sketch below)
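Putting all the pieces together, here is a minimal end-to-end sketch of the training loop. It is not the course's exact code: the CartPole-v1 environment, the nn.Sequential network standing in for QNetwork, the deque standing in for the replay buffer, and all hyperparameter values are assumptions; select_action() and update_target_network() are the functions defined above.

import random
from collections import deque

import gymnasium as gym
import torch
import torch.nn as nn

env = gym.make("CartPole-v1")                       # assumed environment
state_size = env.observation_space.shape[0]
action_size = env.action_space.n

def make_network():                                  # stand-in for QNetwork
    return nn.Sequential(nn.Linear(state_size, 64), nn.ReLU(),
                         nn.Linear(64, action_size))

online_network = make_network()
target_network = make_network()
target_network.load_state_dict(online_network.state_dict())
optimizer = torch.optim.Adam(online_network.parameters(), lr=1e-3)

buffer = deque(maxlen=10_000)                        # stand-in replay buffer
gamma, tau, batch_size = 0.99, 0.005, 64             # assumed hyperparameters
start, end, decay = 1.0, 0.05, 1000                  # assumed epsilon schedule
step = 0

for episode in range(200):
    state, _ = env.reset()
    done = False
    while not done:
        # Epsilon-greedy action selection with the online network
        with torch.no_grad():
            q = online_network(torch.tensor(state, dtype=torch.float32))
        action = select_action(q, step, start, end, decay)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        buffer.append((state, action, reward, next_state, float(done)))
        state = next_state
        step += 1

        if len(buffer) >= batch_size:
            batch = random.sample(buffer, batch_size)
            s, a, r, s2, d = zip(*batch)
            s = torch.tensor(s, dtype=torch.float32)
            a = torch.tensor(a).unsqueeze(1)
            r = torch.tensor(r, dtype=torch.float32)
            s2 = torch.tensor(s2, dtype=torch.float32)
            d = torch.tensor(d, dtype=torch.float32)

            # Online network for predictions, target network for TD targets
            q_values = online_network(s).gather(1, a).squeeze(1)
            with torch.no_grad():
                target_q = r + gamma * target_network(s2).amax(1) * (1 - d)
            loss = nn.MSELoss()(q_values, target_q)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            update_target_network(target_network, online_network, tau)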

Let's practice!
