LLM 課程文件
在 TRL 中實現 GRPO
並獲得增強的文件體驗
開始使用
在 TRL 中實現 GRPO
在本頁中,我們將學習如何使用 Transformer Reinforcement Learning (TRL) 庫實現組相對策略最佳化 (GRPO)。我們將專注於實際實現,使用最少的程式碼。
我們將使用 TRL 官方文件中的程式碼片段作為指導,探索 GRPO 在 TRL 的 GRPOTrainer 中體現的核心概念。
本章面向 TRL 初學者。如果您已經熟悉 TRL,您可能還想檢視 GRPO 的 Open R1 實現。
首先,讓我們回顧一下 GRPO 演算法的一些重要概念
- 組形成:模型為每個提示生成多個補全。
- 偏好學習:模型從獎勵函式中學習,該獎勵函式比較補全組。
- 訓練配置:模型使用配置來控制訓練過程。
我們需要做什麼來實施 GRPO?
- 定義提示資料集。
- 定義一個獎勵函式,該函式接受補全列表並返回獎勵列表。
- 使用 GRPOConfig 配置訓練過程。
- 使用 GRPOTrainer 訓練模型。
這是一個開始 GRPO 訓練的最小示例
from trl import GRPOTrainer, GRPOConfig
from datasets import load_dataset
# 1. Load your dataset
dataset = load_dataset("your_dataset", split="train")
# 2. Define a simple reward function
def reward_func(completions, **kwargs):
"""Example: Reward longer completions"""
return [float(len(completion)) for completion in completions]
# 3. Configure training
training_args = GRPOConfig(
output_dir="output",
num_train_epochs=3,
per_device_train_batch_size=4,
gradient_accumulation_steps=2,
logging_steps=10,
)
# 4. Initialize and train
trainer = GRPOTrainer(
model="your_model", # e.g. "Qwen/Qwen2-0.5B-Instruct"
args=training_args,
train_dataset=dataset,
reward_funcs=reward_func,
)
trainer.train()
關鍵元件
1. 資料集格式
您的資料集應包含模型將響應的提示。GRPO 訓練器將為每個提示生成多個補全,並使用獎勵函式對它們進行比較。
2. 獎勵函式
獎勵函式至關重要——它決定了模型如何學習。這裡有兩個實用示例
# Example 1: Reward based on completion length
def reward_length(completions, **kwargs):
return [float(len(completion)) for completion in completions]
# Example 2: Reward based on matching a pattern
import re
def reward_format(completions, **kwargs):
pattern = r"^<think>.*?</think><answer>.*?</answer>$"
return [1.0 if re.match(pattern, c) else 0.0 for c in completions]
3. 訓練配置
GRPOConfig
中要考慮的關鍵引數
training_args = GRPOConfig(
# Essential parameters
output_dir="output",
num_train_epochs=3,
num_generation=4, # Number of completions to generate for each prompt
per_device_train_batch_size=4, # We want to get all generations in one device batch
# Optional but useful
gradient_accumulation_steps=2,
learning_rate=1e-5,
logging_steps=10,
# GRPO specific (optional)
use_vllm=True, # Speed up generation
)
num_generation
引數對 GRPO 尤其重要,因為它定義了組大小——模型將為每個提示生成多少個不同的補全。這是與其他 RL 方法的關鍵區別。
- 過小(例如 2-3):可能無法提供足夠的 Diversity 進行有意義的比較
- 推薦(4-16):在多樣性和計算效率之間取得了良好的平衡
- 更大的值:可能會改善學習,但會顯著增加計算成本
應根據您的計算資源和任務的複雜性選擇組大小。對於簡單任務,較小的組 (4-8) 可能足夠,而更復雜的推理任務可能受益於較大的組 (8-16)。
成功秘訣
- 記憶體管理:根據您的 GPU 記憶體調整
per_device_train_batch_size
和gradient_accumulation_steps
。 - 速度:如果您的模型支援,請啟用
use_vllm=True
以加快生成速度。 - 監控:在訓練期間觀察記錄的指標
reward
:補全的平均獎勵reward_std
:獎勵組內的標準差kl
:與參考模型的 KL 散度
獎勵函式設計
DeepSeek R1 論文展示了幾種有效的獎勵函式設計方法,您可以將其調整用於您自己的 GRPO 實現
1. 基於長度的獎勵
最容易實現的獎勵之一是基於長度的獎勵。您可以獎勵更長的補全
def reward_len(completions, **kwargs):
ideal_length = 20
return [-abs(ideal_length - len(completion)) for completion in completions]
此獎勵函式會懲罰過短或過長的補全,鼓勵模型生成接近理想長度 20 個 token 的補全。
2. 可驗證任務的基於規則的獎勵
對於具有客觀正確答案的任務(例如數學或編碼),您可以實現基於規則的獎勵函式
def problem_reward(completions, answers, **kwargs):
"""Reward function for math problems with verifiable answers
completions: list of completions to evaluate
answers: list of answers to the problems from the dataset
"""
rewards = []
for completion, correct_answer in zip(completions, answers):
# Extract the answer from the completion
try:
# This is a simplified example - you'd need proper parsing
answer = extract_final_answer(completion)
# Binary reward: 1 for correct, 0 for incorrect
reward = 1.0 if answer == correct_answer else 0.0
rewards.append(reward)
except:
# If we can't parse an answer, give a low reward
rewards.append(0.0)
return rewards
3. 基於格式的獎勵
您還可以獎勵正確的格式,這在 DeepSeek R1 訓練中很重要
def format_reward(completions, **kwargs):
"""Reward completions that follow the desired format"""
# Example: Check if the completion follows a think-then-answer format
pattern = r"<think>(.*?)</think>\s*<answer>(.*?)</answer>"
rewards = []
for completion in completions:
match = re.search(pattern, completion, re.DOTALL)
if match:
# Check if there's substantial content in both sections
think_content = match.group(1).strip()
answer_content = match.group(2).strip()
if len(think_content) > 20 and len(answer_content) > 0:
rewards.append(1.0)
else:
rewards.append(
0.5
) # Partial reward for correct format but limited content
else:
rewards.append(0.0) # No reward for incorrect format
return rewards
這些示例展示瞭如何實現受 DeepSeek R1 訓練過程啟發,專注於正確性、格式和組合訊號的獎勵函式。
就是這樣!
在下一節中,您將進行一項練習,在 TRL 中實現 GRPO。
< > 在 GitHub 上更新