TRL 文件

GRPO 訓練器

Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

GRPO 訓練器

概述

TRL 支援 GRPO 訓練器用於訓練語言模型,如論文 DeepSeekMath: 推動開放語言模型數學推理的極限 所述,作者為 Zhihong Shao, Peiyi Wang, Qihao Zhu, Runxin Xu, Junxiao Song, Mingchuan Zhang, Y. K. Li, Y. Wu, Daya Guo

論文摘要如下:

數學推理因其複雜和結構化的性質,對語言模型構成了重大挑戰。在本文中,我們介紹了 DeepSeekMath 7B,它在 DeepSeek-Coder-Base-v1.5 7B 的基礎上繼續預訓練,使用了來自 Common Crawl 的 1200 億個數學相關標記,以及自然語言和程式碼資料。DeepSeekMath 7B 在競賽級別的 MATH 基準測試中取得了令人印象深刻的 51.7% 的分數,而無需依賴外部工具包和投票技術,其效能水平接近 Gemini-Ultra 和 GPT-4。DeepSeekMath 7B 在 64 個樣本上進行自洽性測試,在 MATH 上達到了 60.9%。DeepSeekMath 的數學推理能力歸因於兩個關鍵因素:首先,我們透過精心設計的資料選擇管道,充分利用了公開網路資料的巨大潛力。其次,我們引入了組相對策略最佳化(GRPO),這是近端策略最佳化(PPO)的一個變體,它在增強數學推理能力的同時,優化了 PPO 的記憶體使用。

此後訓練方法由 Quentin Gallouédec 貢獻。

快速入門

此示例演示瞭如何使用 GRPO 方法訓練模型。我們使用 TLDR 資料集 中的提示(忽略完成列!)訓練了一個 Qwen 0.5B Instruct 模型。你可以在此處檢視資料集中的資料

以下是訓練模型的指令碼。

# train_grpo.py
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

# Define the reward function, which rewards completions that are close to 20 characters
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO")
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()

使用以下命令執行指令碼

accelerate launch train_grpo.py

在 8 個 GPU 上分散式訓練大約需要 1 天。

深入瞭解 GRPO 方法

GRPO 是一種線上學習演算法,這意味著它透過使用訓練好的模型在訓練期間生成的資料進行迭代改進。GRPO 目標的直覺是最大化生成完成的優勢,同時確保模型保持接近參考策略。要理解 GRPO 的工作原理,可以將其分解為四個主要步驟:**生成完成**、**計算優勢**、**估計 KL 散度**和**計算損失**。

生成完成

在每個訓練步驟中,我們取樣一批提示並生成一組G G 每個提示的完成(表示為oi o_i ).

計算優勢

對於每個G G 序列,我們使用獎勵模型計算獎勵。為了與獎勵模型的比較性質保持一致——通常在相同問題的輸出之間進行比較的資料集上進行訓練——優勢的計算方式反映了這些相對比較。其標準化如下:A^i,t=rimean(r)std(r)\hat{A}_{i,t} = \frac{r_i - \text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})}

這種方法因此得名:**群組相對策略最佳化 (GRPO)**。

論文 理解 R1-Zero 類似訓練:批判性視角 表明,按std(r) \text{std}(\mathbf{r}) 縮放可能會導致問題級別難度偏差。你可以透過在 GRPOConfig 中設定 scale_rewards=False 來停用此縮放。

估計 KL 散度

KL 散度使用 Schulman et al. (2020) 引入的近似器進行估計。近似器定義如下:DKL[πθπref]=πref(oi,tq,oi,<t)πθ(oi,tq,oi,<t)logπref(oi,tq,oi,<t)πθ(oi,tq,oi,<t)1,\mathbb{D}_{\text{KL}}\left[\pi_\theta \|\pi_{\text{ref}}\right] = \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - \log \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - 1,

計算損失

目標是最大化優勢,同時確保模型保持接近參考策略。因此,損失定義如下:LGRPO(θ)=1i=1Goii=1Gt=1oi[πθ(oi,tq,oi,<t)[πθ(oi,tq,oi,<t)]no gradA^i,tβDKL[πθπref]], \mathcal{L}_{\text{GRPO}}(\theta) = -\frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} \left[ \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,< t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],

其中第一項代表縮放後的優勢,第二項透過 KL 散度懲罰與參考策略的偏差。

注意,與 DeepSeekMath: 推動開放語言模型數學推理的極限 中原始公式相比,我們沒有按1oi \frac{1}{|o_i|} 縮放,因為論文 理解 R1-Zero 類似訓練:批判性視角 表明這會引入響應級別的長度偏差。更多詳情請參見損失型別

注意,與 DeepSeekMath: 推動開放語言模型數學推理的極限 中的原始公式相比,我們預設使用β=0.0 \beta = 0.0 ,這意味著不使用 KL 散度項。此選擇受到幾項近期研究的啟發(例如,Open-Reasoner-Zero: 一種在基礎模型上擴充套件強化學習的開源方法),這些研究表明 KL 散度項對於 GRPO 訓練並非必不可少。因此,將其排除已成為常見做法(例如 理解 R1-Zero 類似訓練:批判性視角DAPO: 一種大規模開源 LLM 強化學習系統)。如果你希望包含 KL 散度項,可以在 GRPOConfig 中將 beta 設定為非零值。

在原始論文中,此公式被泛化以考慮每次生成後的多次更新(表示為μ \mu ,可在 GRPOConfig 中使用 num_iterations 設定),透過利用裁剪替代目標LGRPO(θ)=1i=1Goii=1Gt=1oi[min(πθ(oi,tq,oi,<t)πθold(oi,tq,oi,<t)A^i,t,clip(πθ(oi,tq,oi,<t)πθold(oi,tq,oi,<t),1ϵ,1+ϵ)A^i,t)βDKL[πθπref]], \mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} \left[ \min \left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})} \hat{A}_{i,t}, \, \text{clip}\left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}, 1 - \epsilon, 1 + \epsilon \right) \hat{A}_{i,t} \right) - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],

其中clip(,1ϵ,1+ϵ)\text{clip}(\cdot, 1 - \epsilon, 1 + \epsilon) 確保更新不會過度偏離參考策略,透過限制策略比率在以下範圍之間:1ϵ 1 - \epsilon 1+ϵ 1 + \epsilon 。當μ=1 \mu = 1 (TRL中的預設值)時,裁剪的替代目標簡化為原始目標。

損失型別

文獻中提出了幾種目標函式形式。最初,GRPO的目標函式定義如下:LGRPO(θ)=1Gi=1G1oit=1oili,t, \mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} l_{i,t},

其中li,t=πθ(oi,tq,oi,<t)[πθ(oi,tq,oi,<t)]no gradA^i,tβDKL[πθπref]. l_{i,t} = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,< t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right].

DAPO論文強調了GRPO演算法在長CoT場景中樣本級損失的侷限性,即較長的響應受到懲罰不足,導致輸出質量較差。提出的解決方案是token級歸一化,它透過為單個token分配更平衡的獎勵,更好地處理較長序列,而不管響應長度如何:LDAPO(θ)=1i=1Goii=1Gt=1oili,t, \mathcal{L}_{\text{DAPO}}(\theta) = - \frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} l_{i,t},

此外,在理解R1-Zero類訓練:一個批判性視角論文中,作者指出原始GRPO公式引入了響應長度偏差。他們表明,雖然DAPO公式減少了這種偏差,但並未完全消除。為了完全消除這種偏差,他們提出用一個常數而不是序列長度進行除法,從而得到以下公式:LDr. GRPO(θ)=1LGi=1Gt=1oili,t, \mathcal{L}_{\text{Dr. GRPO}}(\theta) = - \frac{1}{LG} \sum_{i=1}^G \sum_{t=1}^{|o_i|} l_{i,t},

這個常數建議設定為最大完成長度。要使用此公式,請在GRPOConfig中將loss_type設定為"dr_grpo"

記錄指標

  • num_tokens:迄今為止處理的令牌總數,包括提示和完成。
  • completions/mean_length:生成的完成的平均長度。
  • completions/min_length:生成的完成的最小長度。
  • completions/max_length:生成的完成的最大長度。
  • completions/mean_terminated_length:以EOS終止的生成的完成的平均長度。
  • completions/min_terminated_length:以EOS終止的生成的完成的最小長度。
  • completions/max_terminated_length:以EOS終止的生成的完成的最大長度。
  • completions/clipped_ratio:截斷(裁剪)完成的比例。
  • reward/{reward_func_name}/mean:特定獎勵函式的平均獎勵。
  • reward/{reward_func_name}/std:特定獎勵函式的獎勵標準差。
  • reward:應用獎勵權重後的總平均獎勵。
  • reward_std:應用獎勵權重後,每個批次內總獎勵的標準差。
  • frac_reward_zero_std:生成批次中獎勵標準差為零的樣本比例,這意味著該提示的多樣性很小(所有答案都正確或不正確)。
  • entropy:生成的完成中token預測的平均熵。(如果`mask_truncated_completions=True`,則排除被掩碼的序列token。)
  • kl:模型與參考模型之間的平均KL散度,在生成的完成上計算。僅當`beta`不為零時記錄。
  • clip_ratio/region_mean:GRPO目標被裁剪以保持在信任區域內的token(或序列,如果importance_sampling_level="sequence")機率的平均比率clip(ri,t(θ),1ϵlow,1+ϵhigh),ri,t(θ)=πθ(oi,tq,oi,<t)πθold(oi,tq,oi,<t). \text{clip}\left( r_{i,t}(\theta), 1 - \epsilon_\mathrm{low}, 1 + \epsilon_\mathrm{high} \right)\,, \qquad r_{i,t}(\theta) = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}\,. 值越高意味著裁剪的token越多,這限制了策略$\pi_\theta$可以改變的幅度。
  • clip_ratio/low_mean:在信任區域下限被裁剪的token(或序列,如果importance_sampling_level="sequence")機率的平均比率ri,t(θ)<1ϵlowr_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}
  • clip_ratio/low_min:在信任區域下限被裁剪的token(或序列,如果importance_sampling_level="sequence")機率的最小比率ri,t(θ)<1ϵlowr_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}
  • clip_ratio/high_mean:在信任區域上限被裁剪的token(或序列,如果importance_sampling_level="sequence")機率的平均比率ri,t(θ)>1+ϵhighr_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}
  • clip_ratio/high_max:在信任區域上限被裁剪的token(或序列,如果importance_sampling_level="sequence")機率的最大比率ri,t(θ)>1+ϵhighr_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}.

定製化

透過vLLM加速訓練中的生成過程

在使用線上方法進行訓練時,生成通常是主要的瓶頸。為了加速生成,您可以使用vLLM,一個用於LLM的高吞吐量、低延遲推理引擎。要啟用它,首先透過以下方式安裝軟體包:

pip install trl[vllm]

我們支援兩種在訓練期間使用 vLLM 的方式:**伺服器模式**和**共置模式**。

🔌 選項 1:伺服器模式

在此模式下,vLLM 在單獨的程序中(並使用單獨的 GPU)執行,並透過 HTTP 與訓練器通訊。如果您有專用的 GPU 用於推理,此模式是理想選擇。

  1. 啟動 vLLM 伺服器:

    trl vllm-serve --model <model_name>
  2. 在您的訓練指令碼中啟用伺服器模式:

    from trl import GRPOConfig
    
    training_args = GRPOConfig(
        ...,
        use_vllm=True,
        vllm_mode="server",  # default value, can be omitted
    )

請確保伺服器使用的 GPU 與訓練器不同,否則可能會遇到 NCCL 錯誤。您可以透過 `CUDA_VISIBLE_DEVICES` 環境變數指定要使用的 GPU。

🧩 選項 2:並置模式

在此模式下,vLLM 在訓練器程序內執行,並與訓練模型共享 GPU 記憶體。這避免了啟動單獨的伺服器,可以提高 GPU 利用率,但也可能導致訓練 GPU 上的記憶體爭用。

from trl import GRPOConfig

training_args = GRPOConfig(
    ...,
    use_vllm=True,
    vllm_mode="colocate",
)

根據模型大小和訓練的整體 GPU 記憶體要求,您可能需要調整 GRPOConfig 中的 vllm_gpu_memory_utilization 引數,以避免 GPU 利用率不足或記憶體不足錯誤。

我們提供了一個 HF Space 來幫助您根據模型配置和實驗設定估算推薦的 GPU 記憶體利用率。只需按如下方式使用即可獲得 vllm_gpu_memory_utilization 推薦

如果推薦值在您的環境中不起作用,我們建議在推薦值的基礎上增加一個小的緩衝區(例如,+0.05 或 +0.1)以確保穩定性。

預設情況下,GRPO 對 vLLM 使用 MASTER_ADDR=localhostMASTER_PORT=12345,但您可以透過相應地設定環境變數來覆蓋這些值。

有關更多資訊,請參閱 使用 vLLM 加速訓練

大規模 GRPO:在多個節點上訓練 70B+ 模型

當訓練像 Qwen2.5-72B 這樣的大模型時,您需要一些關鍵的最佳化來使其在多個 GPU 和節點上高效且可擴充套件。這些最佳化包括

  • DeepSpeed ZeRO Stage 3:ZeRO 利用資料並行來將模型狀態(權重、梯度、最佳化器狀態)分佈到多個 GPU 和 CPU 上,從而減少每個裝置的記憶體和計算要求。由於大模型無法在單個 GPU 上執行,因此訓練此類模型需要使用 ZeRO Stage 3。有關更多詳細資訊,請參閱 DeepSpeed 整合
  • Accelerate:Accelerate 是一個簡化跨多個 GPU 和節點分散式訓練的庫。它提供了一個簡單的 API 來啟動分散式訓練,並處理分散式訓練的複雜性,例如資料並行、梯度累積和分散式資料載入。有關更多詳細資訊,請參閱 分散式訓練
  • vLLM:請參閱上一節,瞭解如何使用 vLLM 加速生成。

以下是在多個節點上使用 GRPO 訓練 70B 模型的 SLURM 指令碼示例。此指令碼在 4 個節點上訓練模型,並使用第 5 個節點進行 vLLM 驅動的生成。

#!/bin/bash
#SBATCH --nodes=5
#SBATCH --gres=gpu:8

# Get the list of allocated nodes
NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))

# Assign the first 4 nodes for training and the 5th node for vLLM
TRAIN_NODES="${NODELIST[@]:0:4}"  # Nodes 0, 1, 2, 3 for training
VLLM_NODE="${NODELIST[4]}"  # Node 4 for vLLM

# Run training on the first 4 nodes (Group 1)
srun --nodes=4 --ntasks=4 --nodelist="${NODELIST[@]:0:4}" accelerate launch \
     --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
     --num_processes 32 \
     --num_machines 4 \
     --main_process_ip ${NODELIST[0]} \
     --machine_rank $SLURM_PROCID \
     --rdzv_backend c10d \
     train_grpo.py \
     --server_ip $VLLM_NODE &

# Run vLLM server on the 5th node (Group 2)
srun --nodes=1 --ntasks=1 --nodelist="${NODELIST[4]}" trl vllm-serve --model Qwen/Qwen2.5-72B --tensor_parallel_size 8 &

wait
import argparse

from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--vllm_server_host", type=str, default="", help="The server IP")
    args = parser.parse_args()

    # Example dataset from TLDR
    dataset = load_dataset("trl-lib/tldr", split="train")

    # Dummy reward function: count the number of unique characters in the completions
    def reward_num_unique_chars(completions, **kwargs):
        return [len(set(c)) for c in completions]

    training_args = GRPOConfig(
        output_dir="Qwen2.5-72B-GRPO",
        per_device_train_batch_size=4,
        bf16=True,
        gradient_checkpointing=True,
        use_vllm=True,
        vllm_server_host=args.vllm_server_host.replace("ip-", "").replace("-", "."),  # from ip-X-X-X-X to X.X.X.X
    )

    trainer = GRPOTrainer(model="Qwen/Qwen2.5-72B", args=training_args, reward_funcs=reward_num_unique_chars, train_dataset=dataset)
    trainer.train()

if __name__=="__main__":
    main()

使用自定義獎勵函式

GRPOTrainer 支援使用自定義獎勵函式而不是密集獎勵模型。為確保相容性,您的獎勵函式必須滿足以下要求

  1. 輸入引數:

    • 函式必須接受以下作為關鍵字引數

      • prompts(包含提示),
      • completions(包含生成的補全),
      • completions_ids(包含標記化的補全),
      • trainer_stateTrainerState):訓練器的當前狀態。這可用於實現動態獎勵函式,例如課程學習,其中獎勵根據訓練進度進行調整。
      • 資料集可能具有的所有列名(prompt 除外)。例如,如果資料集包含名為 ground_truth 的列,則函式將以 ground_truth 作為關鍵字引數呼叫。

      滿足此要求的最簡單方法是在函式簽名中使用 **kwargs

    • 根據資料集格式,輸入將有所不同

      • 對於 標準格式promptscompletions 將是字串列表。
      • 對於 對話格式promptscompletions 將是訊息字典列表。
  2. 返回值:函式必須返回一個浮點數列表。每個浮點數代表與單個補全對應的獎勵。

示例 1:獎勵更長的補全

下面是一個標準格式的獎勵函式示例,它獎勵更長的補全

def reward_func(completions_ids, **kwargs):
    """Reward function that assigns higher scores to longer completions (in terms of token count)."""
    return [float(len(ids)) for ids in completions_ids]

您可以按如下方式測試它

>>> prompts = ["The sky is", "The sun is"]  # not used in the reward function, but the trainer will pass it
>>> completions = [" blue.", " in the sky."]  # not used in the reward function, but the trainer will pass it
>>> completions_ids = [[6303, 13], [304, 279, 12884, 13]]
>>> reward_func(prompts=prompts, completions=completions, completions_ids=completions_ids)
[2.0, 4.0]

示例 1.1:獎勵更長的補全(基於字元數)

與上一個示例相同,但這次獎勵函式基於字元數而不是標記。

def reward_func(completions, **kwargs):
    """Reward function that assigns higher scores to longer completions (in terms of character count)."""
    return [float(len(completion)) for completion in completions]

您可以按如下方式測試它

>>> prompts = ["The sky is", "The sun is"]
>>> completions = [" blue.", " in the sky."]
>>> completions_ids = [[6303, 13], [304, 279, 12884, 13]]  # not used in the reward function, but the trainer will pass it
>>> reward_func(prompts=prompts, completions=completions, completions_ids=completions_ids)
[6.0, 12.0]

示例 2:獎勵具有特定格式的補全

下面是一個獎勵函式示例,它檢查補全是否具有特定格式。此示例的靈感來自論文 DeepSeek-R1:透過強化學習激勵 LLM 的推理能力 中使用的*格式獎勵*函式。它專為對話格式設計,其中提示和補全由結構化訊息組成。

import re

def format_reward_func(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]

您可以按如下方式測試此功能

>>> prompts = [
...     [{"role": "assistant", "content": "What is the result of (1 + 2) * 4?"}],
...     [{"role": "assistant", "content": "What is the result of (3 + 1) * 2?"}],
... ]
>>> completions = [
...     [{"role": "assistant", "content": "<think>The sum of 1 and 2 is 3, which we multiply by 4 to get 12.</think><answer>(1 + 2) * 4 = 12</answer>"}],
...     [{"role": "assistant", "content": "The sum of 3 and 1 is 4, which we multiply by 2 to get 8. So (3 + 1) * 2 = 8."}],
... ]
>>> format_reward_func(prompts=prompts, completions=completions)
[1.0, 0.0]

示例 3:根據參考獎勵補全

以下是一個獎勵函式示例,它檢查補全是否正確。此示例的靈感來自論文 DeepSeek-R1:透過強化學習激勵 LLM 的推理能力 中使用的*準確性獎勵*函式。此示例專為 標準格式 設計,其中資料集包含名為 ground_truth 的列。

import re

def reward_func(completions, ground_truth, **kwargs):
    # Regular expression to capture content inside \boxed{}
    matches = [re.search(r"\\boxed\{(.*?)\}", completion) for completion in completions]
    contents = [match.group(1) if match else "" for match in matches]
    # Reward 1 if the content is the same as the ground truth, 0 otherwise
    return [1.0 if c == gt else 0.0 for c, gt in zip(contents, ground_truth)]

您可以按如下方式測試此功能

>>> prompts = ["Problem: Solve the equation $2x + 3 = 7$. Solution:", "Problem: Solve the equation $3x - 5 = 10$."]
>>> completions = [r" The solution is \boxed{2}.", r" The solution is \boxed{6}."]
>>> ground_truth = ["2", "5"]
>>> reward_func(prompts=prompts, completions=completions, ground_truth=ground_truth)
[1.0, 0.0]

示例 4:多工獎勵函式

以下是在 GRPOTrainer 中使用多個獎勵函式的示例。在此示例中,我們定義了兩個特定於任務的獎勵函式:math_reward_funccoding_reward_funcmath_reward_func 根據正確性獎勵數學問題,而 coding_reward_func 根據解決方案是否有效獎勵編碼問題。

from datasets import Dataset
from trl import GRPOTrainer

# Define a dataset that contains both math and coding problems
dataset = Dataset.from_list(
    [
        {"prompt": "What is 2+2?", "task": "math"},
        {"prompt": "Write a function that returns the sum of two numbers.", "task": "code"},
        {"prompt": "What is 3*4?", "task": "math"},
        {"prompt": "Write a function that returns the product of two numbers.", "task": "code"},
    ]
)

# Math-specific reward function
def math_reward_func(prompts, completions, task, **kwargs):
    rewards = []
    for prompt, completion, t in zip(prompts, completions, task):
        if t == "math":
            # Calculate math-specific reward
            correct = check_math_solution(prompt, completion)
            reward = 1.0 if correct else -1.0
            rewards.append(reward)
        else:
            # Return None for non-math tasks
            rewards.append(None)
    return rewards

# Coding-specific reward function
def coding_reward_func(prompts, completions, task, **kwargs):
    rewards = []
    for prompt, completion, t in zip(prompts, completions, task):
        if t == "coding":
            # Calculate coding-specific reward
            works = test_code_solution(prompt, completion)
            reward = 1.0 if works else -1.0
            rewards.append(reward)
        else:
            # Return None for non-coding tasks
            rewards.append(None)
    return rewards

# Use both task-specific reward functions
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=[math_reward_func, coding_reward_func],
    train_dataset=dataset,
)

trainer.train()

在此示例中,math_reward_funccoding_reward_func 旨在與包含數學和編碼問題的混合資料集一起使用。資料集中的 task 列用於確定將哪個獎勵函式應用於每個問題。如果資料集中沒有與樣本相關的獎勵函式,則獎勵函式將返回 NoneGRPOTrainer 將繼續使用有效的函式和任務。這允許 GRPOTrainer 處理具有不同適用性的多個獎勵函式。

請注意,GRPOTrainer 將忽略獎勵函式返回的 None 獎勵,只考慮相關函式返回的獎勵。這確保模型在相關任務上進行訓練,並忽略沒有相關獎勵函式的任務。

將獎勵函式傳遞給訓練器

要使用您的自定義獎勵函式,請按如下方式將其傳遞給 GRPOTrainer

from trl import GRPOTrainer

trainer = GRPOTrainer(
    reward_funcs=reward_func,
    ...,
)

如果您有多個獎勵函式,可以將其作為列表傳遞

from trl import GRPOTrainer

trainer = GRPOTrainer(
    reward_funcs=[reward_func1, reward_func2],
    ...,
)

獎勵將計算為每個函式的獎勵之和,如果配置中提供了 reward_weights,則為加權和。

請注意,GRPOTrainer 支援不同型別的多個獎勵函式。有關更多詳細資訊,請參閱引數文件。

視覺語言模型 (VLM) 訓練

GRPO 支援在包含文字和影像的多模態資料集上訓練視覺語言模型 (VLM)。

支援的模型

已測試的型號:

  • Gemma3 — 例如,google/gemma-3-4b-it
  • LLaVA-NeXT — 例如,llava-hf/llava-v1.6-mistral-7b-hf
  • Qwen2-VL — 例如,Qwen/Qwen2-VL-2B-Instruct
  • Qwen2.5-VL — 例如,Qwen/Qwen2.5-VL-3B-Instruct
  • SmolVLM2 — 例如,HuggingFaceTB/SmolVLM2-2.2B-Instruct
不保證與所有 VLM 相容。如果您認為某個模型應該得到支援,請隨時在 GitHub 上提出問題,或者更好的是,提交包含所需更改的拉取請求。

快速入門

使用 grpo_vlm.py 對 VLM 進行微調。在 lmms-lab/multimodal-open-r1-8k-verified 上訓練的示例命令

accelerate launch \
  --config_file=examples/accelerate_configs/deepspeed_zero3.yaml \
  examples/scripts/grpo_vlm.py \
  --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
  --output_dir grpo-Qwen2.5-VL-3B-Instruct \
  --learning_rate 1e-5 \
  --gradient_checkpointing \
  --torch_dtype bfloat16 \
  --max_prompt_length 2048 \
  --max_completion_length 1024 \
  --use_vllm \
  --vllm_mode colocate \
  --use_peft \
  --lora_target_modules "q_proj", "v_proj" \
  --log_completions

配置提示

如果影像標記被截斷,VLM 訓練可能會失敗。強烈建議透過將 `max_prompt_length` 設定為 `None` 來停用截斷。
  • 在視覺-語言投影層上使用 LoRA
  • 啟用 4 位量化以減少記憶體使用
  • VLM 是記憶體密集型的——從較小的批次大小開始
  • 大多數模型與 vLLM 相容(servercolocate 模式)

資料集格式

每個訓練樣本應包括

  • prompt:透過處理器聊天模板格式化的文字
  • image:單個影像(PIL 或 NumPy 陣列)

訓練器透過模型的影像處理器自動處理影像到張量的轉換。

GRPOTrainer

class trl.GRPOTrainer

< >

( model: typing.Union[str, transformers.modeling_utils.PreTrainedModel] reward_funcs: typing.Union[str, transformers.modeling_utils.PreTrainedModel, typing.Callable[[list, list], list[float]], list[typing.Union[str, transformers.modeling_utils.PreTrainedModel, typing.Callable[[list, list], list[float]]]]] args: typing.Optional[trl.trainer.grpo_config.GRPOConfig] = None train_dataset: typing.Union[datasets.arrow_dataset.Dataset, datasets.iterable_dataset.IterableDataset, NoneType] = None eval_dataset: typing.Union[datasets.arrow_dataset.Dataset, datasets.iterable_dataset.IterableDataset, dict[str, typing.Union[datasets.arrow_dataset.Dataset, datasets.iterable_dataset.IterableDataset]], NoneType] = None processing_class: typing.Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, transformers.processing_utils.ProcessorMixin, NoneType] = None reward_processing_classes: typing.Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, list[transformers.tokenization_utils_base.PreTrainedTokenizerBase], NoneType] = None callbacks: typing.Optional[list[transformers.trainer_callback.TrainerCallback]] = None optimizers: tuple = (None, None) peft_config: typing.Optional[ForwardRef('PeftConfig')] = None )

引數

  • model (Union[str, PreTrainedModel]) — 要訓練的模型。可以是以下任一型別:

    • 字串:Hugging Face 模型庫中預訓練模型的*模型 ID*,或包含使用 save_pretrained 儲存的模型權重的*目錄*路徑,例如 './my_model_directory/'。模型使用 from_pretrainedargs.model_init_kwargs 中的關鍵字引數載入。
    • PreTrainedModel 物件。僅支援因果語言模型。
  • reward_funcs (Union[RewardFunc, list[RewardFunc]]) — 用於計算獎勵的獎勵函式。為了計算獎勵,我們將所有獎勵函式與提示和補全一起呼叫並求和。可以是以下任一型別:

    • 單個獎勵函式,例如:

      • 字串:Hugging Face 模型庫中預訓練模型的*模型 ID*,或包含使用 save_pretrained 儲存的模型權重的*目錄*路徑,例如 './my_model_directory/'。模型使用 from_pretrainednum_labels=1 以及 args.model_init_kwargs 中的關鍵字引數載入。

      • PreTrainedModel 物件:僅支援序列分類模型。

      • 自定義獎勵函式:該函式提供提示和生成的補全,以及資料集中的任何附加列。它應該返回一個獎勵列表。當獎勵不適用於這些樣本時,自定義獎勵函式也可以返回 None。這對於多工訓練非常有用,其中不同的獎勵函式適用於不同型別的樣本。當獎勵函式為樣本返回 None 時,該獎勵函式將從該樣本的獎勵計算中排除。有關更多詳細資訊,請參閱 使用自定義獎勵函式

        訓練器的狀態也傳遞給獎勵函式。訓練器的狀態是 TrainerState 的例項,可以透過訪問獎勵函式簽名的 trainer_state 引數來訪問。

    • 獎勵函式列表,其中每個項都可以獨立地是上述任何型別。允許列表中混合不同型別(例如,字串模型 ID 和自定義獎勵函式)。

  • args (GRPOConfig可選,預設為 None) — 此訓練器的配置。如果為 None,則使用預設配置。
  • train_dataset (DatasetIterableDataset) — 用於訓練的資料集。它必須包含一個 "prompt" 列。資料集中任何附加列都將被忽略。樣本的格式可以是:

    • 標準格式:每個樣本包含純文字。
    • 對話格式:每個樣本包含結構化訊息(例如,角色和內容)。
  • eval_dataset (Dataset, IterableDatasetdict[str, Union[Dataset, IterableDataset]]) — 用於評估的資料集。它必須滿足與 train_dataset 相同的要求。
  • processing_class (PreTrainedTokenizerBaseProcessorMixin可選,預設為 None) — 用於處理資料的處理類。填充側必須設定為“左”。如果為 None,則從模型的名稱中使用 from_pretrained 載入處理類。必須設定填充標記 tokenizer.pad_token。如果處理類未設定填充標記,則 tokenizer.eos_token 將用作預設值。
  • reward_processing_classes (Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]可選,預設為 None) — 與 reward_funcs 中指定的獎勵函式對應的處理類。可以是以下任一型別:

    • 單個處理類:當 reward_funcs 只包含一個獎勵函式時使用。
    • 處理類列表:必須與 reward_funcs 中獎勵函式的順序和長度匹配。如果設定為 None,或者列表中與 PreTrainedModel 對應的元素為 None,則模型的 tokenizer 會自動使用 from_pretrained 載入。對於 reward_funcs 中是自定義獎勵函式(而不是 PreTrainedModel)的元素,reward_processing_classes 中對應的條目將被忽略。
  • callbacks (TrainerCallback 列表,可選,預設為 None) — 自定義訓練迴圈的回撥列表。這將新增到 此處 詳述的預設回撥列表中。

    如果要刪除使用的預設回撥之一,請使用 remove_callback 方法。

  • optimizers (tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]可選,預設為 (None, None)) — 包含要使用的最佳化器和排程器的元組。預設為模型上的 AdamW 例項和由 args 控制的 get_linear_schedule_with_warmup 提供的排程器。
  • peft_config (~peft.PeftConfig可選,預設為 None) — 用於包裝模型的 PEFT 配置。如果為 None,則不包裝模型。

用於分組相對策略最佳化(GRPO)方法的訓練器。該演算法最初是在論文 DeepSeekMath:在開放語言模型中推動數學推理能力的極限 中提出的。

示例

from datasets import load_dataset
from trl import GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")


def reward_func(completions, **kwargs):
    # Dummy reward function that rewards completions with more unique letters.
    return [float(len(set(completion))) for completion in completions]


trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_func,
    train_dataset=dataset,
)

trainer.train()

train

< >

( resume_from_checkpoint: typing.Union[str, bool, NoneType] = None trial: typing.Union[ForwardRef('optuna.Trial'), dict[str, typing.Any], NoneType] = None ignore_keys_for_eval: typing.Optional[list[str]] = None **kwargs )

引數

  • resume_from_checkpoint (strbool可選) — 如果是 str,則為先前 Trainer 例項儲存的檢查點的本地路徑。如果為 bool 且等於 True,則載入先前 Trainer 例項在 args.output_dir 中儲存的最新檢查點。如果存在,訓練將從此處載入的模型/最佳化器/排程器狀態恢復。
  • trial (optuna.Trialdict[str, Any]可選) — 用於超引數搜尋的試驗執行或超引數字典。
  • ignore_keys_for_eval (list[str]可選) — 模型輸出(如果是字典)中應在訓練期間收集預測以進行評估時忽略的鍵列表。
  • kwargs (dict[str, Any]可選) — 用於隱藏已棄用引數的附加關鍵字引數

主訓練入口點。

save_model

< >

( output_dir: typing.Optional[str] = None _internal_call: bool = False )

將儲存模型,以便您可以使用 `from_pretrained()` 重新載入它。

僅從主程序儲存。

push_to_hub

< >

( commit_message: typing.Optional[str] = 'End of training' blocking: bool = True token: typing.Optional[str] = None revision: typing.Optional[str] = None **kwargs )

引數

  • commit_message (str可選,預設為 "End of training") — 推送時提交的訊息。
  • blocking (bool可選,預設為 True) — 函式是否應僅在 git push 完成後返回。
  • token (str可選,預設為 None) — 具有寫入許可權的令牌,用於覆蓋 Trainer 的原始 args。
  • revision (str可選) — 要提交的 Git 修訂版本。預設為“main”分支的頭部。
  • kwargs (dict[str, Any]可選) — 傳遞給 ~Trainer.create_model_card 的附加關鍵字引數。

將 `self.model` 和 `self.processing_class` 上傳到 🤗 模型中心的 `self.args.hub_model_id` 儲存庫。

GRPOConfig

class trl.GRPOConfig

< >

( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 1e-06 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: str = 'passive' log_level_replica: str = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 10 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: typing.Optional[bool] = None fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, list[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = False label_names: typing.Optional[list[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[list[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict[str, typing.Any], str, NoneType] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, list[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False hub_revision: typing.Optional[str] = None gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: list = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: int = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, list[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False liger_kernel_config: typing.Optional[dict[str, bool]] = None eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: typing.Optional[bool] = True model_init_kwargs: typing.Union[dict, str, NoneType] = None disable_dropout: bool = False max_prompt_length: typing.Optional[int] = 512 num_generations: typing.Optional[int] = 8 max_completion_length: typing.Optional[int] = 256 ds3_gather_for_generation: bool = True shuffle_dataset: typing.Optional[bool] = True generation_batch_size: typing.Optional[int] = None steps_per_generation: typing.Optional[int] = None temperature: float = 1.0 top_p: float = 1.0 top_k: typing.Optional[int] = None min_p: typing.Optional[float] = None generation_kwargs: typing.Optional[dict] = None repetition_penalty: float = 1.0 use_transformers_paged: bool = False cache_implementation: typing.Optional[str] = None use_vllm: bool = False vllm_server_base_url: typing.Optional[str] = None vllm_mode: str = 'server' vllm_model_impl: str = 'vllm' vllm_guided_decoding_regex: typing.Optional[str] = None vllm_server_host: str = '0.0.0.0' vllm_server_port: int = 8000 vllm_server_timeout: float = 240.0 vllm_gpu_memory_utilization: float = 0.3 vllm_tensor_parallel_size: int = 1 beta: float = 0.0 num_iterations: int = 1 epsilon: float = 0.2 delta: typing.Optional[float] = None epsilon_high: typing.Optional[float] = None importance_sampling_level: str = 'token' reward_weights: typing.Optional[list[float]] = None scale_rewards: bool = True loss_type: str = 'bnpo' mask_truncated_completions: bool = False sync_ref_model: bool = False ref_model_mixup_alpha: float = 0.6 ref_model_sync_steps: int = 512 top_entropy_quantile: float = 1.0 use_liger_loss: bool = False log_completions: bool = False num_completions_to_print: typing.Optional[int] = None wandb_log_unique_prompts: typing.Optional[bool] = False )

控制模型和參考模型的引數

  • model_init_kwargs (str, dict[str, Any]None, 可選, 預設為 None) — from_pretrained 的關鍵字引數,當 GRPOTrainermodel 引數以字串形式提供時使用。
  • disable_dropout (bool, 可選, 預設為 False) — 是否在模型中停用 dropout。這對於使用參考模型進行訓練很有用,因為它可以防止模型為相同輸入生成不同的 logprobs。

控制資料預處理的引數

  • remove_unused_columns (bool, 可選, 預設為 False) — 是否僅保留資料集中 "prompt" 列。如果您使用自定義獎勵函式,並且該函式除了 "prompts""completions" 之外還需要其他列,則應將其設定為 False
  • max_prompt_length (intNone, 可選, 預設為 512) — 提示的最大長度。如果提示長度超過此值,將從左側截斷。
  • num_generations (intNone, 可選, 預設為 8) — 每個提示的生成樣本數。有效批處理大小(num_processes * per_device_batch_size * gradient_accumulation_steps)必須能被此值整除。
  • max_completion_length (intNone, 可選, 預設為 256) — 生成完成的最大長度。
  • ds3_gather_for_generation (bool, 可選, 預設為 True) — 此設定適用於 DeepSpeed ZeRO-3。如果啟用,將收集策略模型權重以進行生成,從而提高生成速度。但是,停用此選項可以訓練超出單個 GPU 視訊記憶體容量的模型,儘管代價是生成速度較慢。停用此選項與 vLLM 生成不相容。
  • shuffle_dataset (bool, 可選, 預設為 True) — 是否打亂訓練資料集。

控制生成的引數

  • generation_batch_size — (intNone, 可選, 預設為 None): 用於生成的批處理大小。如果為 None,則預設為有效訓練批處理大小:per_device_train_batch_size * num_processes * steps_per_generation。換句話說,每個最佳化步驟處理一個生成批次。與 steps_per_generation 互斥。
  • steps_per_generation — (intNone, 可選, 預設為 None): 每次生成步數。如果為 None,則預設為 gradient_accumulation_steps。與 generation_batch_size 互斥。
  • temperature (float, 預設為 1.0) — 取樣的溫度。溫度越高,完成度越隨機。
  • top_p (float, 可選, 預設為 1.0) — 控制要考慮的最高機率標記的累積機率的浮點數。必須在 (0, 1] 範圍內。設定為 1.0 以考慮所有標記。
  • top_k (intNone, 可選, 預設為 None) — 保留用於 top-k 過濾的最高機率詞彙標記數量。如果為 None,則停用 top-k 過濾,並考慮所有標記。
  • min_p (floatNone, 可選, 預設為 None) — 最小標記機率,將按最可能標記的機率進行縮放。它必須是 0.01.0 之間的值。典型值在 0.01-0.2 範圍內。
  • repetition_penalty (float, 可選, 預設為 1.0) — 根據新標記是否出現在提示和已生成文字中來懲罰新標記的浮點數。值 > 1.0 鼓勵模型使用新標記,而值 < 1.0 鼓勵模型重複標記。
  • use_transformers_paged (bool, 可選, 預設為 False) — 是否使用 transformers 的分頁實現進行生成。如果設定為 True,將使用 transformers 的分頁實現進行生成,而不是預設的填充實現。此引數僅在 use_vllm 設定為 False 時有效。
  • cache_implementation (strNone, 可選, 預設為 None) — 當 use_vllm 設定為 False 時,用於更快生成快取方法的實現。
  • generation_kwargs (dict[str, Any]None, 可選, 預設為 None) — 取樣完成時傳遞給 GenerationConfig(如果使用 transformers)或 SamplingParams(如果使用 vLLM)的附加關鍵字引數。這可用於進一步自定義生成行為,例如設定 supress_tokensnum_beams 等。如果它包含與其它生成引數(如 min_ptop_p 等)衝突的鍵,它們將覆蓋這些引數。

vLLM 支援的生成加速控制引數

  • use_vllm (bool, 可選, 預設為 False) — 是否使用 vLLM 生成完成。如果設定為 True,訓練器將使用 vLLM 進行生成,而不是預設的 model.generate()。需要安裝 vllm
  • vllm_mode (str, 可選, 預設為 "server") — 當 use_vllm 設定為 True 時,用於 vLLM 整合的模式。必須是 "server""colocate" 之一。

    • "server": 訓練器將生成請求傳送到單獨的 vLLM 伺服器。請確保 TRL vLLM 伺服器正在執行(使用 trl vllm-serve 啟動)。
    • "colocate": vLLM 將在同一程序中執行並共享訓練 GPU。這避免了對單獨伺服器的需求,但可能會導致與訓練的資源爭用。
  • vllm_guided_decoding_regex (strNone, 可選, 預設為 None) — vLLM 引導式解碼的正則表示式。如果為 None(預設),則停用引導式解碼。

控制 vLLM 伺服器的引數(僅當 `vllm_mode` 為 `"server"` 時使用)

  • vllm_server_base_url (strNone, 可選, 預設為 None) — vLLM 伺服器的基本 URL(例如,"https://:8000")。如果提供此引數,則 vllm_server_hostvllm_server_port 將被忽略。
  • vllm_server_host (str, 可選, 預設為 "0.0.0.0") — 要連線的 vLLM 伺服器主機。如果提供了 vllm_server_base_url,則忽略此引數。
  • vllm_server_port (int, 可選, 預設為 8000) — 要連線的 vLLM 伺服器埠。如果提供了 vllm_server_base_url,則忽略此引數。
  • vllm_server_timeout (float, 可選, 預設為 240.0) — 等待 vLLM 伺服器啟動的總超時時長(秒)。如果在超時後伺服器仍未啟動,則會引發 ConnectionError

控制共置 vLLM 執行的引數(僅當 `vllm_mode` 為 `"colocate"` 時使用)

  • vllm_gpu_memory_utilization (float, 可選, 預設為 0.3) — 控制 vLLM 的 GPU 記憶體利用率。此設定僅在 vllm_mode 設定為 "colocate" 時適用。如果您使用 vllm_mode="server",則必須在透過 --vllm_gpu_memory_utilization 標誌啟動 vLLM 伺服器時單獨傳遞此引數。
  • vllm_tensor_parallel_size (int, 可選, 預設為 1) — 控制 vLLM 的張量並行大小。此設定僅在 vllm_mode 設定為 "colocate" 時適用。如果您使用 vllm_mode="server",則必須在透過 --vllm_tensor_parallel_size 標誌啟動 vLLM 伺服器時單獨傳遞此引數。
  • vllm_model_impl (str, 可選, 預設為 "vllm") — 用於 vLLM 的模型實現。必須是 "transformers""vllm" 之一。"transformers":使用 transformers 後端進行模型實現。"vllm":使用 vllm 庫進行模型實現。

控制訓練的引數

  • beta (float, 可選, 預設為 0.0) — KL 係數。如果為 0.0(預設),則不載入參考模型,從而減少記憶體使用並提高訓練速度。
  • num_iterations (int, 可選, 預設為 1) — 每批次的迭代次數(在演算法中表示為 μ)。
  • epsilon (float, 可選, 預設為 0.2) — 用於裁剪的 Epsilon 值。
  • delta — (floatNone, 可選, 預設為 None): 當設定為浮點數時,啟用兩邊 GRPO 損失中的上限裁剪。如果為 None(預設),則使用標準 GRPO 裁剪。建議在啟用時大於 1 + ε。此方法在INTELLECT-2 技術報告中引入。
  • epsilon_high (floatNone, 可選, 預設為 None) — 裁剪的上限 epsilon 值。如果未指定,則預設為與引數 epsilon 中指定的下限相同的值。DAPO 論文推薦使用 0.28
  • importance_sampling_level (str, 可選, 預設為 "token") — 控制重要性取樣比率是在 "token" 級別還是 "sequence" 級別計算。"token" 保留原始的每令牌對數機率比率(每個令牌一個權重)。"sequence" 對有效令牌的對數機率比率進行平均,為每個序列生成一個比率。GSPO 論文表明,序列級取樣通常會帶來更穩定的訓練和更好的與序列級獎勵對齊。
  • reward_weights (list[float]None, 可選, 預設為 None) — 每個獎勵函式的權重。必須與獎勵函式的數量匹配。如果為 None,所有獎勵都以 1.0 的權重平均加權。
  • scale_rewards (bool, 可選, 預設為 True) — 是否透過將獎勵除以其標準差來縮放獎勵。如果為 True(預設),獎勵將按標準差歸一化,確保它們具有單位方差。如果為 False,則不應用縮放。Dr. GRPO 論文建議不縮放獎勵,因為按標準差縮放會引入問題級難度偏差。
  • loss_type (str, 可選, 預設為 "bnpo") — 指定要使用的損失公式。支援的值有:

    • "grpo":透過對序列長度進行歸一化來聚合令牌級損失。不推薦,因為它會產生長度偏差——這種方法傾向於在具有正優勢時偏好較短的補全,而在具有負優勢時偏好較長的補全。
    • "bnpo":透過對本地批次中活躍令牌的數量進行歸一化來聚合令牌級損失。請注意,歸一化僅在本地批次上執行,因此結果可能會因本地批次大小而略有不同,儘管有效批次大小是恆定的。當使用 per_device_train_batch_size==1 時,損失等效於 GRPO 損失。
    • "dr_grpo":透過全域性常數進行歸一化來聚合令牌級損失。此方法在Dr. GRPO 論文中引入,以消除長度偏差。常數的值對應於 max_completion_length
  • mask_truncated_completions (bool, 可選, 預設為 False) — 啟用後,截斷的補全將從損失計算中排除,防止它們被錯誤地懲罰並在訓練期間引入噪聲。根據DAPO 論文,這是訓練穩定性的良好實踐。
  • sync_ref_model (bool, 可選, 預設為 False) — 是否每 ref_model_sync_steps 步使用 ref_model_mixup_alpha 引數將參考模型與活躍模型同步。此同步源自TR-DPO 論文
  • ref_model_mixup_alpha (float, 可選, 預設為 0.6) — 來自TR-DPO 論文的 α 引數,它控制當前策略和先前參考策略在更新期間的混合。參考策略根據以下公式更新:π_ref = α * π_θ + (1 - α) * π_ref_prev。要使用此引數,必須設定 sync_ref_model=True
  • ref_model_sync_steps (int, 可選, 預設為 512) — 來自TR-DPO 論文的 τ 引數,它決定了當前策略與參考策略同步的頻率。要使用此引數,必須設定 sync_ref_model=True
  • top_entropy_quantile (float, 可選, 預設為 1.0) — 來自Beyond the 80/20 Rule的 ρ 引數。只保留每個序列位置上機率分佈熵的最高 ρ 分位數令牌在策略損失項中,從而改善結果。範圍:[0.0-1.0]。值為 0.0 遮蔽除最高熵令牌外的所有令牌;1.0 保留所有令牌。論文推薦值為 0.2。如果與 mask_truncated_completions=True 一起使用,則只考慮非截斷補全中的令牌。
  • use_liger_loss (bool, 可選, 預設為 False) — 是否使用 Liger GRPO 損失。

控制日誌記錄的引數

  • log_completions (bool, 可選, 預設為 False) — 是否每 logging_steps 步記錄一組 (提示, 補全) 對。如果安裝了 rich,它會列印該樣本。如果啟用了 wandb 日誌記錄,它會將其記錄到 wandb
  • num_completions_to_print (intNone, 可選, 預設為 None) — 要使用 rich 列印的補全數量。如果為 None,則記錄所有補全。
  • wandb_log_unique_prompts (bool, 可選, 預設為 False) — 是否在 wandb 中記錄唯一提示。如果為 True,則只記錄唯一提示。如果為 False,則記錄所有提示。

GRPOTrainer 的配置類。

此類僅包含 GRPO 訓練特有的引數。有關訓練引數的完整列表,請參閱 TrainingArguments 文件。請注意,此類的預設值可能與 TrainingArguments 中的預設值不同。

使用 HfArgumentParser,我們可以將此類別轉換為可在命令列中指定的 argparse 引數。

< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.