Training with PPO

Reinforcement Learning from Human Feedback (RLHF)

Mina Parham

AI Engineer

Fine-Tuning with Reinforcement Learning

The initial LLM and reward model in the RLHF process.


Fine-Tuning with Reinforcement Learning

The full RLHF process.


Fine-Tuning a Language Model with PPO

 

A diagram of a query given to an LLM, which generates a continuation based on it.


Fine-Tuning a Language Model with PPO

 

A diagram of a query given to an LLM, which completes it: 'we're half way there, oh livin' on a prayer'.


Fine-Tuning a Language Model with PPO

 

A diagram of a query given to an LLM, which completes it: 'we're half way there, oh livin' on a prayer', with another LLM evaluating the completion.


Fine-Tuning a Language Model with PPO

  • PPO adjusts the model gradually, in small steps
  • Avoids overfitting to the reward feedback

A robot and a snail representing the algorithm's slow, steady improvement.
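The "gradual adjustment" comes from PPO's clipped surrogate objective: the ratio between the new and old policy's probability for an action is clipped to a small interval, so no single update can move the model too far from its previous behavior. A minimal single-action sketch (the function name and scalar inputs are illustrative, not part of TRL):

```python
import math

def ppo_clipped_objective(log_prob_new, log_prob_old, advantage, epsilon=0.2):
    """PPO clipped surrogate objective for one action (illustrative).

    The probability ratio new/old is clipped to [1 - epsilon, 1 + epsilon],
    and the pessimistic (smaller) of the clipped and unclipped terms is
    kept, which bounds how much one update can change the policy.
    """
    ratio = math.exp(log_prob_new - log_prob_old)
    clipped_ratio = max(1.0 - epsilon, min(1.0 + epsilon, ratio))
    return min(ratio * advantage, clipped_ratio * advantage)
```

With equal log-probabilities the ratio is 1 and the objective is just the advantage; a large jump in probability gets capped at (1 + epsilon) times the advantage, which is what keeps the improvement slow and steady.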


Implementing PPOTrainer with TRL

from trl import PPOConfig
config = PPOConfig(model_name="gpt2", learning_rate=1.4e-5)

from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead
model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)

from trl import PPOTrainer
# dataset: tokenized queries prepared beforehand
ppo_trainer = PPOTrainer(model=model, config=config, dataset=dataset,
                         tokenizer=tokenizer)
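The dataset passed to PPOTrainer is assumed to be prepared beforehand: one record per query, carrying the tokenized "input_ids" the trainer will feed to generation. A minimal stand-in showing that shape (the character-level "tokenizer" here is a fake, for illustration only; a real setup would use the `datasets` library and the model's own tokenizer):

```python
def build_dataset(queries, tokenize):
    """Build one record per query with its pre-tokenized ids (a sketch)."""
    return [{"query": q, "input_ids": tokenize(q)} for q in queries]

queries = ["we're half way there", "oh livin' on a prayer"]
# Stand-in character-level "tokenizer", for illustration only.
dataset = build_dataset(queries, lambda q: [ord(c) for c in q])
```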

Starting the training loop

from tqdm import tqdm

for epoch in tqdm(range(10), "epoch: "):
    for batch in tqdm(ppo_trainer.dataloader):
        query_tensors = batch["input_ids"]

        # Get responses from the model being fine-tuned
        response_tensors = ppo_trainer.generate(query_tensors)
        batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]

        # Compute a reward score for each query + response pair
        texts = [q + r for q, r in zip(batch["query"], batch["response"])]
        rewards = reward_model(texts)

        # Run a PPO optimization step and log the statistics
        stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
        ppo_trainer.log_stats(stats, batch, rewards)
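The reward_model in the loop is whatever scoring function was trained in the earlier reward-modeling step; the only interface PPO needs is one scalar reward per combined query + response text. A toy keyword-based stand-in (purely hypothetical) makes that interface concrete:

```python
def toy_reward_model(texts):
    """Hypothetical stand-in for a learned reward model.

    Returns one scalar reward per combined query + response string.
    A real RLHF pipeline would score texts with a model trained on
    human preference data rather than a keyword check.
    """
    return [1.0 if "prayer" in text else -1.0 for text in texts]
```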

Let's practice!

