Deep Reinforcement Learning in Python
Timothée Carayol
Principal Machine Learning Engineer, Komment
REINFORCE limitations: as a Monte Carlo method, it must wait until the end of each episode before updating, and its gradient estimates suffer from high variance.
Actor Critic methods introduce a critic network, enabling Temporal Difference (TD) learning: the networks can be updated at every step.
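Concretely, the critic's estimate of the state value V(s) gives a TD target and TD error, from which both losses follow; these formulas mirror the calculate_losses code later on this slide:

\delta = \underbrace{r + \gamma\, V(s')\,(1 - \text{done})}_{\text{TD target}} - V(s),
\qquad
L_{\text{actor}} = -\log \pi(a \mid s)\, \delta,
\qquad
L_{\text{critic}} = \delta^2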
Actor network: the policy network; it outputs action probabilities and selects actions, as in REINFORCE (a sketch follows the critic code below).
Critic network: estimates the state value V(s), outputting a single number per state.
import torch
import torch.nn as nn

class Critic(nn.Module):
    def __init__(self, state_size):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_size, 64)
        self.fc2 = nn.Linear(64, 1)

    def forward(self, state):
        x = torch.relu(self.fc1(torch.tensor(state)))
        value = self.fc2(x)
        return value

critic_network = Critic(8)  # 8-dimensional state space (e.g. Lunar Lander)
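The actor network itself is not shown on this slide; a minimal sketch, assuming it mirrors the critic's architecture with a softmax output over a discrete action space (the action_size parameter and network sizes are assumptions):

class Actor(nn.Module):
    def __init__(self, state_size, action_size):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_size, 64)
        self.fc2 = nn.Linear(64, action_size)

    def forward(self, state):
        x = torch.relu(self.fc1(torch.tensor(state)))
        # Softmax over the logits gives action probabilities
        return torch.softmax(self.fc2(x), dim=-1)

actor = Actor(8, 4)  # assumed: 8 state dimensions, 4 discrete actions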
def calculate_losses(critic_network, action_log_prob, reward, state, next_state, done):
    # Critic provides the state value estimates
    value = critic_network(state)
    next_value = critic_network(next_state)
    td_target = reward + gamma * next_value * (1 - done)
    td_error = td_target - value
    # Apply formulas for actor and critic losses
    actor_loss = -action_log_prob * td_error.detach()
    critic_loss = td_error ** 2
    return actor_loss, critic_loss
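The training loop below also relies on a select_action helper and some setup (environment, discount factor, optimizers) defined earlier in the course; a minimal sketch, in which the Lunar Lander environment and the learning rates are assumptions:

import gymnasium as gym
import torch.optim as optim
from torch.distributions import Categorical

def select_action(actor, state):
    # Sample from the actor's action distribution and keep
    # the log-probability for the actor loss
    probs = actor(state)
    dist = Categorical(probs)
    action = dist.sample()
    return action.item(), dist.log_prob(action)

env = gym.make("LunarLander-v2")  # assumed environment
gamma = 0.99                      # assumed discount factor
actor_optimizer = optim.Adam(actor.parameters(), lr=0.001)            # assumed lr
critic_optimizer = optim.Adam(critic_network.parameters(), lr=0.001)  # assumed lr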
We call .detach() on the TD error in the actor loss to stop gradient propagation to the critic weights: only the critic loss should update the critic.

for episode in range(10):
    state, info = env.reset()
    done = False
    while not done:
        # Select action
        action, action_log_prob = select_action(actor, state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        # Calculate losses
        actor_loss, critic_loss = calculate_losses(
            critic_network, action_log_prob, reward, state, next_state, done)
        # Update actor
        actor_optimizer.zero_grad()
        actor_loss.backward()
        actor_optimizer.step()
        # Update critic
        critic_optimizer.zero_grad()
        critic_loss.backward()
        critic_optimizer.step()
        state = next_state
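Note that, unlike REINFORCE, both networks are updated at every step within the episode rather than only at its end: this is the practical benefit of the Temporal Difference learning enabled by the critic.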