Deep Reinforcement Learning in Python
Timothée Carayol
Principal Machine Learning Engineer, Komment
# Batched actor-critic training loop.
#
# Per-step actor and critic losses are accumulated across environment steps
# (and across episode boundaries). Once at least `rollout_length` losses have
# been collected, their means are back-propagated, both optimizers take one
# step, and the batches are emptied.
#
# Relies on names defined elsewhere: env, actor, critic, select_action,
# calculate_losses, actor_optimizer, critic_optimizer.

# Set rollout length
rollout_length = 10

# Initiate loss batches
actor_losses = torch.tensor([])
critic_losses = torch.tensor([])

for episode in range(10):
    state, info = env.reset()
    done = False

    while not done:
        action, action_log_prob = select_action(actor, state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated

        actor_loss, critic_loss = calculate_losses(
            critic, action_log_prob, reward, state, next_state, done
        )

        # NOTE(review): the source shows "..." at this point (it looks like a
        # slide panel break); no statements appear to be omitted — confirm.

        # Append this step's losses to the running batches.
        # Assumes calculate_losses returns tensors that torch.cat can append
        # to a 1-D batch (e.g. shape (1,)) — TODO confirm.
        actor_losses = torch.cat((actor_losses, actor_loss))
        critic_losses = torch.cat((critic_losses, critic_loss))

        # If rollout is full, update the networks
        if len(actor_losses) >= rollout_length:
            # Reduce each batch of per-step losses to a single scalar.
            actor_loss_batch = actor_losses.mean()
            critic_loss_batch = critic_losses.mean()

            actor_optimizer.zero_grad()
            actor_loss_batch.backward()
            actor_optimizer.step()

            critic_optimizer.zero_grad()
            critic_loss_batch.backward()
            critic_optimizer.step()

            # Start a fresh rollout.
            actor_losses = torch.tensor([])
            critic_losses = torch.tensor([])

        # Advance to the next step of the episode.
        state = next_state
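Note: `.mean()` averages the accumulated per-step losses into a single batch loss before the backward pass.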
Deep Reinforcement Learning in Python