Deep Reinforcement Learning in Python
Timothée Carayol
Principal Machine Learning Engineer, Komment











# Set rollout length: number of accumulated per-step losses that triggers
# a network update in the training loop.
rollout_length = 10

# Initiate loss batches as empty tensors; per-step losses are concatenated
# onto these until a rollout is full.
actor_losses = torch.tensor([])
critic_losses = torch.tensor([])
# Training loop: run 10 episodes, accumulating per-step actor and critic
# losses and updating both networks each time a rollout's worth of losses
# has been collected.
for episode in range(10):
    state, info = env.reset()
    done = False
    while not done:
        action, action_log_prob = select_action(actor, state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        # The episode ends on either termination or truncation.
        done = terminated or truncated
        actor_loss, critic_loss = calculate_losses(
            critic, action_log_prob,
            reward, state, next_state, done)

        # Append this step's losses to the current rollout.
        # NOTE(review): torch.cat rejects zero-dimensional tensors; this
        # assumes calculate_losses returns tensors with at least one
        # dimension -- TODO confirm.
        actor_losses = torch.cat((actor_losses, actor_loss))
        critic_losses = torch.cat((critic_losses, critic_loss))

        # If rollout is full, update the networks
        if len(actor_losses) >= rollout_length:
            # Average the accumulated per-step losses into one batch loss
            # per network.
            actor_loss_batch = actor_losses.mean()
            critic_loss_batch = critic_losses.mean()

            actor_optimizer.zero_grad()
            actor_loss_batch.backward()
            actor_optimizer.step()

            critic_optimizer.zero_grad()
            critic_loss_batch.backward()
            critic_optimizer.step()

            # Start a fresh rollout. The buffers are only cleared here, not
            # at episode boundaries, so a rollout may span several episodes.
            actor_losses = torch.tensor([])
            critic_losses = torch.tensor([])

        # Advance to the next state on every step, whether or not an
        # update happened.
        state = next_state
.mean()



Deep Reinforcement Learning in Python