LLM 課程文件

在 TRL 中實現 GRPO

Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

在 TRL 中實現 GRPO

在本頁中,我們將學習如何使用 Transformer Reinforcement Learning (TRL) 庫實現組相對策略最佳化 (GRPO)。我們將專注於實際實現,使用最少的程式碼。

我們將使用 TRL 官方文件中的程式碼片段作為指導,探索 GRPO 在 TRL 的 GRPOTrainer 中體現的核心概念。

本章面向 TRL 初學者。如果您已經熟悉 TRL,您可能還想檢視 GRPO 的 Open R1 實現

首先,讓我們回顧一下 GRPO 演算法的一些重要概念

  • 組形成:模型為每個提示生成多個補全。
  • 偏好學習:模型從獎勵函式中學習,該獎勵函式比較補全組。
  • 訓練配置:模型使用配置來控制訓練過程。

我們需要做什麼來實施 GRPO?

  • 定義提示資料集。
  • 定義一個獎勵函式,該函式接受補全列表並返回獎勵列表。
  • 使用 GRPOConfig 配置訓練過程。
  • 使用 GRPOTrainer 訓練模型。

這是一個開始 GRPO 訓練的最小示例

from trl import GRPOTrainer, GRPOConfig
from datasets import load_dataset

# 1. Load your dataset
dataset = load_dataset("your_dataset", split="train")


# 2. Define a simple reward function
def reward_func(completions, **kwargs):
    """Example: Reward longer completions"""
    return [float(len(completion)) for completion in completions]


# 3. Configure training
training_args = GRPOConfig(
    output_dir="output",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    logging_steps=10,
)

# 4. Initialize and train
trainer = GRPOTrainer(
    model="your_model",  # e.g. "Qwen/Qwen2-0.5B-Instruct"
    args=training_args,
    train_dataset=dataset,
    reward_funcs=reward_func,
)
trainer.train()

關鍵元件

1. 資料集格式

您的資料集應包含模型將響應的提示。GRPO 訓練器將為每個提示生成多個補全,並使用獎勵函式對它們進行比較。

2. 獎勵函式

獎勵函式至關重要——它決定了模型如何學習。這裡有兩個實用示例

# Example 1: Reward based on completion length
def reward_length(completions, **kwargs):
    return [float(len(completion)) for completion in completions]


# Example 2: Reward based on matching a pattern
import re


def reward_format(completions, **kwargs):
    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
    return [1.0 if re.match(pattern, c) else 0.0 for c in completions]

3. 訓練配置

GRPOConfig 中要考慮的關鍵引數

training_args = GRPOConfig(
    # Essential parameters
    output_dir="output",
    num_train_epochs=3,
    num_generation=4,  # Number of completions to generate for each prompt
    per_device_train_batch_size=4,  # We want to get all generations in one device batch
    # Optional but useful
    gradient_accumulation_steps=2,
    learning_rate=1e-5,
    logging_steps=10,
    # GRPO specific (optional)
    use_vllm=True,  # Speed up generation
)

num_generation 引數對 GRPO 尤其重要,因為它定義了組大小——模型將為每個提示生成多少個不同的補全。這是與其他 RL 方法的關鍵區別。

  • 過小(例如 2-3):可能無法提供足夠的 Diversity 進行有意義的比較
  • 推薦(4-16):在多樣性和計算效率之間取得了良好的平衡
  • 更大的值:可能會改善學習,但會顯著增加計算成本

應根據您的計算資源和任務的複雜性選擇組大小。對於簡單任務,較小的組 (4-8) 可能足夠,而更復雜的推理任務可能受益於較大的組 (8-16)。

成功秘訣

  1. 記憶體管理:根據您的 GPU 記憶體調整 per_device_train_batch_sizegradient_accumulation_steps
  2. 速度:如果您的模型支援,請啟用 use_vllm=True 以加快生成速度。
  3. 監控:在訓練期間觀察記錄的指標
    • reward:補全的平均獎勵
    • reward_std:獎勵組內的標準差
    • kl:與參考模型的 KL 散度

獎勵函式設計

DeepSeek R1 論文展示了幾種有效的獎勵函式設計方法,您可以將其調整用於您自己的 GRPO 實現

1. 基於長度的獎勵

最容易實現的獎勵之一是基於長度的獎勵。您可以獎勵更長的補全

def reward_len(completions, **kwargs):
    ideal_length = 20
    return [-abs(ideal_length - len(completion)) for completion in completions]

此獎勵函式會懲罰過短或過長的補全,鼓勵模型生成接近理想長度 20 個 token 的補全。

2. 可驗證任務的基於規則的獎勵

對於具有客觀正確答案的任務(例如數學或編碼),您可以實現基於規則的獎勵函式

def problem_reward(completions, answers, **kwargs):
    """Reward function for math problems with verifiable answers
    completions: list of completions to evaluate
    answers: list of answers to the problems from the dataset
    """

    rewards = []
    for completion, correct_answer in zip(completions, answers):
        # Extract the answer from the completion
        try:
            # This is a simplified example - you'd need proper parsing
            answer = extract_final_answer(completion)
            # Binary reward: 1 for correct, 0 for incorrect
            reward = 1.0 if answer == correct_answer else 0.0
            rewards.append(reward)
        except:
            # If we can't parse an answer, give a low reward
            rewards.append(0.0)

    return rewards

3. 基於格式的獎勵

您還可以獎勵正確的格式,這在 DeepSeek R1 訓練中很重要

def format_reward(completions, **kwargs):
    """Reward completions that follow the desired format"""
    # Example: Check if the completion follows a think-then-answer format
    pattern = r"<think>(.*?)</think>\s*<answer>(.*?)</answer>"

    rewards = []
    for completion in completions:
        match = re.search(pattern, completion, re.DOTALL)
        if match:
            # Check if there's substantial content in both sections
            think_content = match.group(1).strip()
            answer_content = match.group(2).strip()

            if len(think_content) > 20 and len(answer_content) > 0:
                rewards.append(1.0)
            else:
                rewards.append(
                    0.5
                )  # Partial reward for correct format but limited content
        else:
            rewards.append(0.0)  # No reward for incorrect format

    return rewards

這些示例展示瞭如何實現受 DeepSeek R1 訓練過程啟發,專注於正確性、格式和組合訊號的獎勵函式。

就是這樣!

在下一節中,您將進行一項練習,在 TRL 中實現 GRPO。

< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.