Prioritized experience replay

Deep Reinforcement Learning in Python

Timothée Carayol

Principal Machine Learning Engineer, Komment

Not all experiences are created equal

 

  • Experience Replay:
    • Uniform sampling of experiences may overlook important memories
  • Prioritized Experience Replay:
    • Assign a priority to each experience, based on its TD error
    • Focus on experiences with high learning potential (see the toy sketch below)
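
For intuition, here is a toy sketch (hypothetical TD errors, NumPy only) contrasting uniform sampling with sampling proportional to TD error:

import numpy as np

# Hypothetical absolute TD errors; transition 2 is far more "surprising"
td_errors = np.array([0.1, 0.1, 5.0, 0.1])
uniform = np.ones(4) / 4                   # uniform replay: each index 25%
prioritized = td_errors / td_errors.sum()  # prioritized: index 2 ~94% of draws
print(np.random.choice(4, p=uniform), np.random.choice(4, p=prioritized))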

 

Students are studying in a library


Prioritized Experience Replay (PER)

 

for step = 1 to T do:
    # Take optimal action according to value function
    # Observe next state and reward
    # Append transition to replay buffer,
    #   giving it highest priority (1)
    # Sample a batch of past transitions,
    #   based on priority (2)
    # Calculate TD errors for the batch
    # Calculate the loss and update the Q-Network,
    #   using importance sampling weights (4)
    # Update priority of sampled transitions (3)
    # Increase importance sampling over time (5)

(1) New transitions are appended with highest priority $p_i = \max_k(p_k)$

(2) Sample transition $i$ with probability $$P(i) = p_i^{\alpha} / \sum_k p_k^{\alpha}\ \ \ \ \ \ \ \ (0<\alpha<1)$$

(3) Sampled transitions have their priority updated to the magnitude of their TD error, plus a small constant $\varepsilon$ that keeps every priority positive: $p_i = |\delta_i| + \varepsilon$

(4) Use importance sampling weights $$w_i = \left( \frac{1}{N} \cdot \frac{1}{P(i)} \right)^\beta\ \ \ \ \ \ \ \ (0<\beta<1)$$

(5) Progressively increase $\beta$ towards 1
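
As a quick numeric check of (2) and (4), here is a minimal sketch with illustrative priority values:

import numpy as np

priorities = np.array([0.5, 2.0, 0.1, 1.0])  # illustrative priorities p_k
alpha, beta = 0.6, 0.4
N = len(priorities)

# (2) Sampling probabilities P(i)
probabilities = priorities**alpha / np.sum(priorities**alpha)

# (4) Importance sampling weights w_i, here also scaled by their maximum
weights = (1 / (N * probabilities)) ** beta
weights /= np.max(weights)

print(probabilities.round(3))  # the highest-priority transition is most likely...
print(weights.round(3))        # ...and receives the smallest weight

Scaling by the maximum weight (also used in the implementation below) means the weights only ever shrink an update, never amplify it.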


Implementing PER

# Imports used across these snippets
from collections import deque
import numpy as np
import torch

def __init__(self, capacity, alpha=0.6, beta=0.4, beta_increment=0.001, epsilon=0.001):
    # Initialize memory buffer
    self.memory = deque(maxlen=capacity)
    # Store parameters and initialize priorities
    self.alpha, self.beta, self.beta_increment, self.epsilon = (
        alpha, beta, beta_increment, epsilon)
    self.priorities = deque(maxlen=capacity)
...

Implementing PER

...

def push(self, state, action, reward, next_state, done):
    # Append experience to memory buffer
    experience_tuple = (state, action, reward, next_state, done)
    self.memory.append(experience_tuple)

    # Set priority of new transition to maximum priority
    # (default to 1.0 on the very first push, when no priorities exist yet)
    max_priority = max(self.priorities) if self.priorities else 1.0
    self.priorities.append(max_priority)
...
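
A quick usage sketch with illustrative values (the class name PrioritizedReplayBuffer matches the training loop later on):

buffer = PrioritizedReplayBuffer(capacity=1000)
buffer.push(state=[0.0, 0.1], action=1, reward=1.0,
            next_state=[0.1, 0.2], done=False)
print(buffer.priorities[-1])  # 1.0: the first transition gets the default max priority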

Implementing PER

def sample(self, batch_size):
    priorities = np.array(self.priorities)
    # Calculate sampling probabilities
    probabilities = priorities**self.alpha / np.sum(priorities**self.alpha)

    # Randomly select sampled indices
    indices = np.random.choice(len(self.memory), batch_size, p=probabilities)

    # Calculate importance sampling weights, normalized by their maximum
    weights = (1 / (len(self.memory) * probabilities)) ** self.beta
    weights /= np.max(weights)
    weights = weights[indices]

    # Gather the sampled transitions
    states, actions, rewards, next_states, dones = zip(
        *[self.memory[idx] for idx in indices])

    # Return tensors
    states = torch.tensor(states, dtype=torch.float32)
    actions = torch.tensor(actions, dtype=torch.long).unsqueeze(1)
    ...  # Repeat for rewards, next_states, dones, weights
    return (states, actions, rewards, next_states, dones, indices, weights)
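
A usage sketch, assuming the buffer already holds at least 32 transitions (and that weights is converted to a tensor in the elided code). Note the design choice: dividing by np.max(weights) means importance sampling can only scale an update down, never up, which keeps training stable.

states, actions, rewards, next_states, dones, indices, weights = (
    buffer.sample(batch_size=32))
print(weights.max())  # at most 1.0: weights are scaled by the buffer-wide maximum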

Implementing PER

...

def update_priorities(self, indices, td_errors: torch.Tensor):
    # Update priorities for sampled transitions
    for idx, td_error in zip(indices, td_errors.abs()):
        self.priorities[idx] = td_error.item() + self.epsilon  # already absolute

def increase_beta(self):
    # Increment beta towards 1
    self.beta = min(1.0, self.beta + self.beta_increment)
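
With the default hyperparameters above (beta=0.4, beta_increment=0.001), calling increase_beta() once per episode anneals beta linearly towards 1:

buffer = PrioritizedReplayBuffer(capacity=100)
for _ in range(100):
    buffer.increase_beta()
print(buffer.beta)  # ~0.5; beta reaches its cap of 1.0 after 600 episodes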

PER in the DQN training loop

 

  1. In pre-loop code:

    buffer = PrioritizedReplayBuffer(capacity)
    
  2. At the start of each episode:

    buffer.increase_beta()
    

  3. At every step:

    # After selecting an action
    buffer.push(state, action, reward,
                next_state, done)
    ...

    # Before calculating the TD errors
    states, actions, rewards, next_states, dones, indices, weights = (
        buffer.sample(batch_size))
    ...

    # After calculating the TD errors
    buffer.update_priorities(indices, td_errors)
    loss = torch.sum(weights * (td_errors ** 2))
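
Putting the pieces together, here is a sketch of one full training step. The names q_network, target_network, optimizer, gamma and batch_size are assumptions (they are not shown on these slides), and dones is assumed to be a float tensor:

states, actions, rewards, next_states, dones, indices, weights = (
    buffer.sample(batch_size))

# Current Q-values for the actions actually taken
q_values = q_network(states).gather(1, actions).squeeze(1)

# Bootstrapped targets from the target network
with torch.no_grad():
    next_q_values = target_network(next_states).amax(dim=1)
    targets = rewards + gamma * next_q_values * (1 - dones)

td_errors = targets - q_values
buffer.update_priorities(indices, td_errors.detach())

# Importance-sampling-weighted loss, as above
loss = torch.sum(weights * td_errors ** 2)
optimizer.zero_grad()
loss.backward()
optimizer.step()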

PER In Action: Cartpole

100 training runs in the Cartpole environment:

  1. with Prioritized Experience Replay
  2. with Uniform Experience Replay
  • Faster learning and better performance with PER than uniform experience replay

Learning curves show PER learning faster

 

After 100 epochs: the Cartpole is still unstable

 

After 400 epochs: the Cartpole is stable


PER In Action: Atari environments

 

  • Substantial performance boost with PER in Atari environments

Bar chart comparing the performance of humans, DQN, DDQN, Dueling DDQN, Prioritized DDQN and Prioritized Dueling DQN. The first four are identical to the bar chart of the previous lesson on Dueling DQN. The last two bars show that adding Prioritized Experience Replay improves the performance of both DDQN and Dueling DDQN.

1 https://arxiv.org/abs/2303.11634

Let's practice!
