Deep Reinforcement Learning in Python
Timothée Carayol
Principal Machine Learning Engineer, Komment
# DQN:
# ... instantiate online and target networks
# Q(s, a) for the actions actually taken, from the ONLINE network
q_values = online_network(states).gather(1, actions).squeeze(1)
with torch.no_grad():
    # Greedy next-state value from the TARGET network (max over actions).
    # NOTE(review): this line was commented out in the original, leaving
    # next_q_values undefined below — restored from the complete slide.
    next_q_values = target_network(next_states).amax(1)
    # Bellman target; (1 - dones) zeroes the bootstrap term at episode ends
    target_q_values = rewards + gamma * next_q_values * (1 - dones)
loss = torch.nn.MSELoss()(q_values, target_q_values)
# ... gradient descent
# ... target network update
# DDQN:
# ... instantiate online and target networks
# Q(s, a) for the actions actually taken, from the ONLINE network
q_values = online_network(states).gather(1, actions).squeeze(1)
with torch.no_grad():
    # Double DQN: the ONLINE network SELECTS the next action ...
    # NOTE(review): these two lines were missing in the original, leaving
    # next_q_values undefined below — restored from the complete slide.
    next_actions = online_network(next_states).argmax(1).unsqueeze(1)
    # ... and the TARGET network EVALUATES it (decoupling selection from
    # evaluation reduces Q-value overestimation bias)
    next_q_values = target_network(next_states).gather(1, next_actions).squeeze(1)
    # Bellman target; (1 - dones) zeroes the bootstrap term at episode ends
    target_q_values = rewards + gamma * next_q_values * (1 - dones)
loss = torch.nn.MSELoss()(q_values, target_q_values)
# ... gradient descent
# ... target network update
# DQN:
# ... instantiate online and target networks
# Q(s, a) for the actions actually taken, from the ONLINE network
q_values = online_network(states).gather(1, actions).squeeze(1)
with torch.no_grad():
    # TARGET network both selects and evaluates the greedy next action;
    # argmax + gather here is equivalent to target_network(next_states).amax(1)
    next_actions = target_network(next_states).argmax(1).unsqueeze(1)
    next_q_values = target_network(next_states).gather(1, next_actions).squeeze(1)
    # Bellman target; (1 - dones) zeroes the bootstrap term at episode ends
    target_q_values = rewards + gamma * next_q_values * (1 - dones)
loss = torch.nn.MSELoss()(q_values, target_q_values)
# ... gradient descent
# ... target network update
# DDQN:
# ... instantiate online and target networks
# Q(s, a) for the actions actually taken, from the ONLINE network
q_values = online_network(states).gather(1, actions).squeeze(1)
with torch.no_grad():
    # Double DQN: the ONLINE network SELECTS the next action ...
    next_actions = online_network(next_states).argmax(1).unsqueeze(1)
    # ... and the TARGET network EVALUATES it (decoupling selection from
    # evaluation reduces Q-value overestimation bias)
    next_q_values = target_network(next_states).gather(1, next_actions).squeeze(1)
    # Bellman target; (1 - dones) zeroes the bootstrap term at episode ends
    target_q_values = rewards + gamma * next_q_values * (1 - dones)
loss = torch.nn.MSELoss()(q_values, target_q_values)
# ... gradient descent
# ... target network update
Deep Reinforcement Learning in Python