Deep Reinforcement Learning in Python
Timothée Carayol
Principal Machine Learning Engineer, Komment
from collections import deque
# Instantiate with limited capacity
buffer = deque([1, 2, 3, 4], maxlen=7)

# Extend to the right side
buffer.extend([5, 6, 7, 8])
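Because maxlen is 7, appending beyond capacity silently drops the oldest elements from the left; for example:

print(buffer)
# deque([2, 3, 4, 5, 6, 7, 8], maxlen=7)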
import random
import torch
import torch.nn as nn

class ReplayBuffer:
    def __init__(self, capacity):
        # A deque with limited capacity: once full, the oldest experiences are discarded
        self.memory = deque([], maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        experience_tuple = (state, action, reward, next_state, done)
        self.memory.append(experience_tuple)

    def __len__(self):
        return len(self.memory)
    def sample(self, batch_size):
        batch = random.sample(self.memory, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        states_tensor = torch.tensor(states, dtype=torch.float32)
        # Repeat identically for rewards, next_states and dones
        rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
        next_states_tensor = torch.tensor(next_states, dtype=torch.float32)
        dones_tensor = torch.tensor(dones, dtype=torch.float32)
        # Actions need dtype long and an extra dimension for use with .gather()
        actions_tensor = torch.tensor(actions, dtype=torch.long).unsqueeze(1)
        return states_tensor, actions_tensor, rewards_tensor, next_states_tensor, dones_tensor
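zip(*batch) transposes the batch: a list of (state, action, reward, next_state, done) tuples becomes one tuple per field. A quick illustration with two toy transitions:

batch = [([0.0, 1.0], 0, 1.0, [1.0, 2.0], False),
         ([1.0, 2.0], 1, 0.0, [2.0, 3.0], True)]
states, actions, rewards, next_states, dones = zip(*batch)
print(actions)  # (0, 1)
print(dones)    # (False, True)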
Before the training loop, instantiate the buffer with a capacity of 10,000 transitions:

replay_buffer = ReplayBuffer(10000)
In the training loop, after action selection:

replay_buffer.push(state, action, reward, next_state, done)

if len(replay_buffer) >= batch_size:
    states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)
    # Q-values of the actions actually taken
    q_values = q_network(states).gather(1, actions).squeeze(1)
    # Greedy Q-values for next states; detach so no gradient flows through the target
    next_states_q_values = q_network(next_states).amax(1).detach()
    # Bellman target; (1 - dones) zeroes the bootstrap term for terminal states
    target_q_values = rewards + gamma * next_states_q_values * (1 - dones)
    loss = nn.MSELoss()(q_values, target_q_values)
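The slide stops at the loss; a full training step would typically follow with a gradient update. A minimal sketch, assuming an optimizer (e.g. torch.optim.Adam over q_network.parameters()) was created beforehand:

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()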