Deep Reinforcement Learning in Python
Timothée Carayol
Principal Machine Learning Engineer, Komment



DQN:
... # instantiate online and target networks
q_values = (online_network(states)
            .gather(1, actions).squeeze(1))
with torch.no_grad():
    next_q_values = (target_network(next_states)
                     .amax(1))
    target_q_values = (rewards +
                       gamma * next_q_values * (1 - dones))
loss = torch.nn.MSELoss()(q_values, target_q_values)
... # gradient descent
... # target network update
DDQN:
... # instantiate online and target networks
q_values = (online_network(states)
            .gather(1, actions).squeeze(1))
with torch.no_grad():
    # next_q_values: this is where DDQN differs from DQN (see below)
    target_q_values = (rewards +
                       gamma * next_q_values * (1 - dones))
loss = torch.nn.MSELoss()(q_values, target_q_values)
... # gradient descent
... # target network update
DQN:
... # instantiate online and target networks
q_values = (online_network(states)
            .gather(1, actions).squeeze(1))
with torch.no_grad():
    next_actions = (target_network(next_states)
                    .argmax(1).unsqueeze(1))
    next_q_values = (target_network(next_states)
                     .gather(1, next_actions).squeeze(1))
    target_q_values = (rewards +
                       gamma * next_q_values * (1 - dones))
loss = torch.nn.MSELoss()(q_values, target_q_values)
... # gradient descent
... # target network update
DDQN:
... # instantiate online and target networks
q_values = (online_network(states)
            .gather(1, actions).squeeze(1))
with torch.no_grad():
    next_actions = (online_network(next_states)
                    .argmax(1).unsqueeze(1))
    next_q_values = (target_network(next_states)
                     .gather(1, next_actions).squeeze(1))
    target_q_values = (rewards +
                       gamma * next_q_values * (1 - dones))
loss = torch.nn.MSELoss()(q_values, target_q_values)
... # gradient descent
... # target network update
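The only change from DQN to DDQN is how next_q_values is computed: the online network selects the greedy next action, and the target network evaluates it, which reduces the overestimation bias of taking a max over the same noisy Q-estimates. Below is a minimal, self-contained sketch of that target computation; the network architecture, dimensions, and dummy batch values are illustrative assumptions, not part of the original slides.

import torch
import torch.nn as nn

# Toy Q-networks (architecture and sizes are illustrative assumptions)
state_dim, n_actions, batch_size = 4, 2, 3
online_network = nn.Linear(state_dim, n_actions)
target_network = nn.Linear(state_dim, n_actions)
target_network.load_state_dict(online_network.state_dict())

# Dummy transition batch (placeholder values)
states = torch.randn(batch_size, state_dim)
actions = torch.randint(n_actions, (batch_size, 1))
rewards = torch.randn(batch_size)
next_states = torch.randn(batch_size, state_dim)
dones = torch.zeros(batch_size)
gamma = 0.99

q_values = online_network(states).gather(1, actions).squeeze(1)
with torch.no_grad():
    # DDQN: online network picks the next action, target network scores it
    next_actions = online_network(next_states).argmax(1).unsqueeze(1)
    next_q_values = target_network(next_states).gather(1, next_actions).squeeze(1)
    target_q_values = rewards + gamma * next_q_values * (1 - dones)
loss = nn.MSELoss()(q_values, target_q_values)
loss.backward()  # gradients flow only through the online network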
