Deep Reinforcement Learning in Python
Timothée Carayol
Principal Machine Learning Engineer, Komment
ratio = action_log_prob.exp() / action_log_prob_old.exp().detach()
# Or equivalently: ratio = torch.exp(action_log_prob - action_log_prob_old.detach())
detach() the denominator to prevent gradient propagation
clipped_ratio = torch.clamp(ratio, 1-epsilon, 1+epsilon)
def calculate_ratios(action_log_prob, action_log_prob_old, epsilon):
    prob = action_log_prob.exp()
    prob_old = action_log_prob_old.exp()
    # Detach the old probability so no gradient flows through it
    prob_old_detached = prob_old.detach()
    ratio = prob / prob_old_detached
    # Clip the ratio to the interval [1-epsilon, 1+epsilon]
    clipped_ratio = torch.clamp(ratio, 1-epsilon, 1+epsilon)
    return (ratio, clipped_ratio)
Example with epsilon = 0.2:
Ratio: tensor(1.25)
Clipped ratio: tensor(1.20)
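For instance, the output above can be reproduced with hypothetical log-probabilities whose underlying probabilities are 0.25 under the current policy and 0.2 under the old policy (these values are chosen purely for illustration):

import torch

action_log_prob = torch.log(torch.tensor(0.25))      # current policy
action_log_prob_old = torch.log(torch.tensor(0.20))  # old policy
ratio, clipped_ratio = calculate_ratios(action_log_prob, action_log_prob_old, epsilon=0.2)
print("Ratio:", ratio)                   # tensor(1.2500)
print("Clipped ratio:", clipped_ratio)   # tensor(1.2000)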
surr1 = ratio * td_error.detach()
surr2 = clipped_ratio * td_error.detach()
objective = torch.min(surr1, surr2)
$$L^{\mathrm{CLIP}}(\theta) = \min\!\left(r_t(\theta)\,\hat{A}_t,\ \mathrm{clip}\big(r_t(\theta),\,1-\varepsilon,\,1+\varepsilon\big)\,\hat{A}_t\right)$$
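As a quick illustrative check (values chosen for this sketch only): with a positive TD error of 1.0 and the ratio of 1.25 from the earlier example, the clipped term is the smaller of the two, so clipping caps how much the objective can reward moving away from the old policy:

import torch

td_error = torch.tensor(1.0)                # positive advantage estimate (illustrative)
ratio = torch.tensor(1.25)
clipped_ratio = torch.tensor(1.20)
surr1 = ratio * td_error.detach()           # tensor(1.2500)
surr2 = clipped_ratio * td_error.detach()   # tensor(1.2000)
objective = torch.min(surr1, surr2)         # tensor(1.2000): the clipped term wins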
def calculate_losses(critic_network,
                     action_log_prob, action_log_prob_old,
                     reward, state, next_state, done):
    # calculate TD error (same as A2C)
    value = critic_network(state)
    next_value = critic_network(next_state)
    td_target = reward + gamma * next_value * (1-done)
    td_error = td_target - value
    # probability ratio and clipped ratio
    ratio, clipped_ratio = calculate_ratios(action_log_prob, action_log_prob_old, epsilon)
    # clipped surrogate objective
    surr1 = ratio * td_error.detach()
    surr2 = clipped_ratio * td_error.detach()
    objective = torch.min(surr1, surr2)
    actor_loss = -objective
    critic_loss = td_error ** 2
    return actor_loss, critic_loss
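As a rough sketch of how these losses could drive a training step (the optimizer names and the .mean() reduction are assumptions for illustration, not part of the exercise):

# actor_optimizer and critic_optimizer are assumed to be torch.optim optimizers
# over the actor and critic parameters, respectively
actor_loss, critic_loss = calculate_losses(
    critic_network, action_log_prob, action_log_prob_old,
    reward, state, next_state, done)

actor_optimizer.zero_grad()
actor_loss.mean().backward()
actor_optimizer.step()

critic_optimizer.zero_grad()
critic_loss.mean().backward()
critic_optimizer.step()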