Trainen met PPO

Reinforcement Learning from Human Feedback (RLHF)

Mina Parham

AI Engineer

Fine-tunen met reinforcement learning

Het initiële LLM en beloningsmodel in het RLHF-proces.

Reinforcement Learning from Human Feedback (RLHF)

Fine-tunen met reinforcement learning

Het volledige RLHF-proces.

Reinforcement Learning from Human Feedback (RLHF)

Een taalmodel fine-tunen met PPO

 

Een schema van een query aan een LLM die daarop een vervolg genereert.

Reinforcement Learning from Human Feedback (RLHF)

Een taalmodel fine-tunen met PPO

 

Een schema van een query aan een LLM die de query aanvult: 'we're half way there, oh livin' on a prayer'.

Reinforcement Learning from Human Feedback (RLHF)

Een taalmodel fine-tunen met PPO

 

Een schema van een query aan een LLM die de query aanvult: 'we're half way there, oh livin' on a prayer', met een ander LLM dat de aanvulling beoordeelt.

Reinforcement Learning from Human Feedback (RLHF)

Een taalmodel fine-tunen met PPO

  • PPO: geleidelijke aanpassing van het model
  • Voorkomt overfitting op feedback

Een robot en een slak die trage verbetering van het algoritme voorstellen.

Reinforcement Learning from Human Feedback (RLHF)

PPOTrainer implementeren met TRL

from trl import PPOConfig
config = PPOConfig(model_name="gpt2",learning_rate=1.4e-5)
from trl import AutoModelForCausalLMWithValueHead
model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
from trl import PPOTrainer
ppo_trainer = PPOTrainer(model=model,config=config,dataset=dataset,
                         tokenizer=tokenizer)
Reinforcement Learning from Human Feedback (RLHF)

De trainingslus starten

for epoch in tqdm(range(10), "epoch: "):


for batch in tqdm(ppo_trainer.dataloader):
# Get responses response_tensors = ppo_trainer.generate(batch["input_ids"])
batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]
# Compute reward score texts = [q + r for q, r in zip(batch["query"], batch["response"])]
rewards = reward_model(texts)
stats = ppo_trainer.step(query_tensors, response_tensors, rewards) ppo_trainer.log_stats(stats, batch, rewards)
Reinforcement Learning from Human Feedback (RLHF)

Laten we oefenen!

Reinforcement Learning from Human Feedback (RLHF)

Preparing Video For Download...