Prioritized experience replay

Deep Reinforcement Learning dengan Python

Timothée Carayol

Principal Machine Learning Engineer, Komment

Tidak semua pengalaman bernilai sama

 

  • Experience Replay:
    • Pengambilan sampel uniform bisa melewatkan memori penting
  • Prioritized Experience Replay:
    • Beri prioritas tiap pengalaman, berdasar TD error
    • Fokus pada pengalaman dengan potensi belajar tinggi

 

Siswa belajar di perpustakaan

Deep Reinforcement Learning dengan Python

Prioritized Experience Replay (PER)

 

for step = 1 to T do:
    # Ambil aksi optimal sesuai fungsi nilai
    # Amati state berikutnya dan reward
    # Tambahkan transisi ke replay buffer

# Beri prioritas tertinggi (1)
# Ambil sampel batch transisi masa lalu
# Berdasarkan prioritas (2)
# Hitung TD error untuk batch
# Hitung loss dan perbarui Q Network
# Gunakan bobot importance sampling (4)
# Perbarui prioritas transisi yang diambil sampelnya (3)
# Tingkatkan importance sampling seiring waktu. (5)

(1) Transisi baru ditambahkan dengan prioritas tertinggi $p_i = \max_k(p_k)$

(2) Ambil sampel transisi $i$ dengan probabilitas $$P(i) = p_i^{\alpha} / \sum_k p_k^{\alpha}\ \ \ \ \ \ \ \ (0<\alpha<1)$$

(3) Prioritas transisi yang terambil diperbarui ke TD error-nya: $p_i = |\delta_i| + \varepsilon$

(4) Gunakan bobot importance sampling $$w_i = \left( \frac{1}{N} \cdot \frac{1}{P(i)} \right)^\beta\ \ \ \ \ \ \ \ (0<\beta<1)$$

(5) Tingkatkan $\beta$ secara bertahap menuju 1

Deep Reinforcement Learning dengan Python

Mengimplementasikan PER

def __init__(self, capacity, alpha=0.6, beta=0.4, beta_increment=0.001, epsilon=0.001):
    # Inisialisasi memory buffer
    self.memory = deque(maxlen=capacity)

# Simpan parameter dan inisialisasi prioritas self.alpha, self.beta, self.beta_increment, self.epsilon = (alpha, beta, beta_increment, epsilon) self.priorities = deque(maxlen=capacity)
...
Deep Reinforcement Learning dengan Python

Mengimplementasikan PER

...

def push(self, state, action, reward, next_state, done):
    # Tambahkan pengalaman ke memory buffer
    experience_tuple = (state, action, reward, next_state, done)
    self.memory.append(experience_tuple)

# Set prioritas transisi baru ke prioritas maksimum max_priority = max(self.priorities) if self.memory else 1.0 self.priorities.append(max_priority)
...
Deep Reinforcement Learning dengan Python

Mengimplementasikan PER

def sample(self, batch_size):
    priorities = np.array(self.priorities)
    # Hitung probabilitas pengambilan sampel
    probabilities = priorities**self.alpha / np.sum(priorities**self.alpha)

# Pilih indeks secara acak indices = np.random.choice(len(self.memory), batch_size, p=probabilities)
# Hitung bobot weights = (1 / (len(self.memory) * probabilities)) ** self.beta weights /= np.max(weights) states, actions, rewards, next_states, dones = zip(*[self.memory[idx] for idx in indices]) weights = [weights[idx] for idx in indices] states, actions, rewards, next_states, dones = (zip(*[self.memory[idx] for idx in indices]))
# Kembalikan tensor states = torch.tensor(states, dtype=torch.float32) ... # Ulangi untuk rewards, next_states, dones, weights actions = torch.tensor(actions, dtype=torch.long).unsqueeze(1) return (states, actions, rewards, next_states, dones, indices, weights)
Deep Reinforcement Learning dengan Python

Mengimplementasikan PER

...

def update_priorities(self, indices, td_errors: torch.Tensor):
    # Perbarui prioritas untuk transisi yang diambil sampelnya
    for idx, td_error in zip(indices, td_errors.abs()):
        self.priorities[idx] = abs(td_error.item()) + self.epsilon

def increase_beta(self): # Naikkan beta menuju 1 self.beta = min(1.0, self.beta + self.beta_increment)
Deep Reinforcement Learning dengan Python

PER dalam loop pelatihan DQN

 

  1. Dalam kode pra-loop:

    buffer = PrioritizedReplayBuffer(capacity)
    
  2. Di awal tiap episode:

    buffer.increase_beta()
    

3. Di setiap langkah:

# Setelah memilih aksi
buffer.push(state, action, reward, 
            next_state, done)
...

# Sebelum menghitung TD error: replay_buffer.sample(batch_size) ...
# Setelah menghitung TD error buffer.update_priorities(indices, td_errors)
loss = torch.sum(weights * (td_errors ** 2))
Deep Reinforcement Learning dengan Python

PER dalam Aksi: Cartpole

100 kali pelatihan di lingkungan Cartpole:

  1. dengan Prioritized Experience Replay
  2. dengan Uniform Experience Replay
  • Pembelajaran lebih cepat dan kinerja lebih baik dengan PER dibanding uniform experience replay

Kurva pembelajaran menunjukkan PER belajar lebih cepat

 

Setelah 100 epoch: Cartpole, tidak stabil setelah 100 epoch

 

Setelah 400 epoch: Cartpole, stabil setelah 400 epoch

Deep Reinforcement Learning dengan Python

PER dalam Aksi: Lingkungan Atari

 

  • Peningkatan kinerja besar dengan PER di lingkungan Atari

Diagram batang membandingkan kinerja manusia, DQN, DDQN, Dueling DDQN, Prioritized DDQN, dan Prioritized Dueling DQN. Empat pertama sama dengan diagram pada pelajaran Dueling DQN sebelumnya. Yang terakhir menunjukkan bahwa penambahan Prioritized Experience Replay meningkatkan kinerja DDQN.

1 https://arxiv.org/abs/2303.11634
Deep Reinforcement Learning dengan Python

Ayo berlatih!

Deep Reinforcement Learning dengan Python

Preparing Video For Download...