Pembaruan batch pada policy gradient

Deep Reinforcement Learning dengan Python

Timothée Carayol

Principal Machine Learning Engineer, Komment

Pembaruan gradien per langkah vs batch

Sebuah kotak besar mewakili satu episode.

Deep Reinforcement Learning dengan Python

Pembaruan gradien per langkah vs batch

Di kotak besar, muncul kotak lebih kecil mewakili step 1. Di dalamnya, kotak lain dengan teks 'select action.'

Deep Reinforcement Learning dengan Python

Pembaruan gradien per langkah vs batch

Di kotak step 1, muncul kotak kecil lain dengan teks 'iterate environment'

Deep Reinforcement Learning dengan Python

Pembaruan gradien per langkah vs batch

Di bawah kotak step 1, kotak lain dengan label 'hitung loss' dan 'gradient descent'

Deep Reinforcement Learning dengan Python

Pembaruan gradien per langkah vs batch

Sepasang kotak identik muncul untuk langkah kedua, dengan konten yang sama

Deep Reinforcement Learning dengan Python

Pembaruan gradien per langkah vs batch

Langkah 3 dan langkah 4 juga muncul.

Deep Reinforcement Learning dengan Python

Batching pembaruan A2C / PPO

Sebuah kotak episode besar; di setengah areanya, kotak lain berlabel 'rollout 1'; di dalamnya, dua kotak kosong berlabel 'step 1' dan 'step 2'

Deep Reinforcement Learning dengan Python

Batching pembaruan A2C / PPO

Di kotak step 1, muncul label 'pilih aksi' dan 'iterasi environment'.

Deep Reinforcement Learning dengan Python

Batching pembaruan A2C / PPO

Sama untuk step 2.

Deep Reinforcement Learning dengan Python

Batching pembaruan A2C / PPO

Di bawah kotak step 1 dan step 2, muncul satu label 'hitung loss' dan satu label 'gradient descent'.

Deep Reinforcement Learning dengan Python

Batching pembaruan A2C / PPO

Setengah area episode yang tersisa kini diisi kotak rollout lain identik dengan dua langkah, berlabel 'rollout 2'.

Deep Reinforcement Learning dengan Python

Loop pelatihan A2C dengan pembaruan batch

 

# Set rollout length
rollout_length = 10

# Initiate loss batches
actor_losses = torch.tensor([]) critic_losses = torch.tensor([])
  • Inisialisasi batch loss
  • Iterasi episode dan langkah seperti biasa

 

for episode in range(10):
  state, info = env.reset()
  done = False
  while not done:
    action, action_log_prob = select_action(actor, 
                                            state)                
    next_state, reward, terminated, truncated, _ = (
                                   env.step(action))
    done = terminated or truncated    
    actor_loss, critic_loss = calculate_losses(
        critic, action_log_prob, 
        reward, state, next_state, done)
    ...
Deep Reinforcement Learning dengan Python

Loop pelatihan A2C dengan pembaruan batch

  ...
  actor_losses = torch.cat((actor_losses, actor_loss))
  critic_losses = torch.cat((critic_losses, critic_loss))

# If rollout is full, update the networks if len(actor_losses) >= rollout_length:
actor_loss_batch = actor_losses.mean() critic_loss_batch = critic_losses.mean()
actor_optimizer.zero_grad() actor_loss_batch.backward() actor_optimizer.step() critic_optimizer.zero_grad() critic_loss_batch.backward() critic_optimizer.step()
actor_losses = torch.tensor([]) critic_losses = torch.tensor([])
state = next_state

 

  • Tambahkan loss langkah ke batch loss
  • Saat rollout penuh:
    • Ambil rata-rata batch dengan .mean()
    • Lakukan gradient descent
    • Inisialisasi ulang batch loss
Deep Reinforcement Learning dengan Python

A2C / PPO dengan banyak agen

 

Dua pita horizontal mewakili agen 1 dan agen 2. Masing-masing mengalami 4 dan 3 episode dengan panjang bervariasi. Di tiap episode, kotak langkah terlihat seperti slide sebelumnya. Di bawah dua pita, terlihat tiga kotak rollout, masing-masing mencakup interval 8 langkah. Di setiap kotak rollout ada label 'hitung loss' dan 'gradient descent'. Di atas grafik, legenda menunjukkan "panjang rollout: 8 langkah; jumlah agen: 2"

Deep Reinforcement Learning dengan Python

Rollout dan minibatch

Dua pita agen identik dengan slide sebelumnya. Di bawahnya, 3 kotak rollout masih terlihat, namun isinya berubah. Kini di bagian atas ada kotak panjang berlabel 'shuffle'. Di bawahnya, masing-masing dibagi memanjang menjadi 4 kotak berlabel 'minibatch'; dalam tiap minibatch ada 'hitung loss' dan 'gradient descent'. Di atas gambar, legenda: 'Panjang rollout: 8 langkah; ukuran minibatch: 4 (2x2); jumlah agen: 2'

Deep Reinforcement Learning dengan Python

PPO dengan banyak epoch

Gambar mirip sebelumnya, kecuali batch rollout kini juga terbagi vertikal menjadi 4 area: paling atas label 'shuffle'; kedua kotak besar berlabel 'epoch 1' berisi 4 minibatch memanjang; ketiga label 'reshuffle'; terakhir kotak besar berlabel 'epoch 2' juga berisi 4 minibatch. Legenda: 'Panjang rollout: 8 langkah; ukuran minibatch: 4 (2x2); jumlah agen: 2; jumlah epoch: 2'.

Deep Reinforcement Learning dengan Python

Ayo berlatih!

Deep Reinforcement Learning dengan Python

Preparing Video For Download...