TRL 文件
GRPO 訓練器
並獲得增強的文件體驗
開始使用
GRPO 訓練器
概述
TRL 支援 GRPO 訓練器用於訓練語言模型,如論文 DeepSeekMath: 推動開放語言模型數學推理的極限 所述,作者為 Zhihong Shao, Peiyi Wang, Qihao Zhu, Runxin Xu, Junxiao Song, Mingchuan Zhang, Y. K. Li, Y. Wu, Daya Guo。
論文摘要如下:
數學推理因其複雜和結構化的性質,對語言模型構成了重大挑戰。在本文中,我們介紹了 DeepSeekMath 7B,它在 DeepSeek-Coder-Base-v1.5 7B 的基礎上繼續預訓練,使用了來自 Common Crawl 的 1200 億個數學相關標記,以及自然語言和程式碼資料。DeepSeekMath 7B 在競賽級別的 MATH 基準測試中取得了令人印象深刻的 51.7% 的分數,而無需依賴外部工具包和投票技術,其效能水平接近 Gemini-Ultra 和 GPT-4。DeepSeekMath 7B 在 64 個樣本上進行自洽性測試,在 MATH 上達到了 60.9%。DeepSeekMath 的數學推理能力歸因於兩個關鍵因素:首先,我們透過精心設計的資料選擇管道,充分利用了公開網路資料的巨大潛力。其次,我們引入了組相對策略最佳化(GRPO),這是近端策略最佳化(PPO)的一個變體,它在增強數學推理能力的同時,優化了 PPO 的記憶體使用。
此後訓練方法由 Quentin Gallouédec 貢獻。
快速入門
此示例演示瞭如何使用 GRPO 方法訓練模型。我們使用 TLDR 資料集 中的提示(忽略完成列!)訓練了一個 Qwen 0.5B Instruct 模型。你可以在此處檢視資料集中的資料
以下是訓練模型的指令碼。
# train_grpo.py
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer
dataset = load_dataset("trl-lib/tldr", split="train")
# Define the reward function, which rewards completions that are close to 20 characters
def reward_len(completions, **kwargs):
return [-abs(20 - len(completion)) for completion in completions]
training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO")
trainer = GRPOTrainer(
model="Qwen/Qwen2-0.5B-Instruct",
reward_funcs=reward_len,
args=training_args,
train_dataset=dataset,
)
trainer.train()
使用以下命令執行指令碼
accelerate launch train_grpo.py
在 8 個 GPU 上分散式訓練大約需要 1 天。
深入瞭解 GRPO 方法
GRPO 是一種線上學習演算法,這意味著它透過使用訓練好的模型在訓練期間生成的資料進行迭代改進。GRPO 目標的直覺是最大化生成完成的優勢,同時確保模型保持接近參考策略。要理解 GRPO 的工作原理,可以將其分解為四個主要步驟:**生成完成**、**計算優勢**、**估計 KL 散度**和**計算損失**。
生成完成
在每個訓練步驟中,我們取樣一批提示並生成一組每個提示的完成(表示為).
計算優勢
對於每個序列,我們使用獎勵模型計算獎勵。為了與獎勵模型的比較性質保持一致——通常在相同問題的輸出之間進行比較的資料集上進行訓練——優勢的計算方式反映了這些相對比較。其標準化如下:
這種方法因此得名:**群組相對策略最佳化 (GRPO)**。
論文 理解 R1-Zero 類似訓練:批判性視角 表明,按縮放可能會導致問題級別難度偏差。你可以透過在 GRPOConfig 中設定 scale_rewards=False
來停用此縮放。
估計 KL 散度
KL 散度使用 Schulman et al. (2020) 引入的近似器進行估計。近似器定義如下:
計算損失
目標是最大化優勢,同時確保模型保持接近參考策略。因此,損失定義如下:
其中第一項代表縮放後的優勢,第二項透過 KL 散度懲罰與參考策略的偏差。
注意,與 DeepSeekMath: 推動開放語言模型數學推理的極限 中原始公式相比,我們沒有按縮放,因為論文 理解 R1-Zero 類似訓練:批判性視角 表明這會引入響應級別的長度偏差。更多詳情請參見損失型別。
注意,與 DeepSeekMath: 推動開放語言模型數學推理的極限 中的原始公式相比,我們預設使用,這意味著不使用 KL 散度項。此選擇受到幾項近期研究的啟發(例如,Open-Reasoner-Zero: 一種在基礎模型上擴充套件強化學習的開源方法),這些研究表明 KL 散度項對於 GRPO 訓練並非必不可少。因此,將其排除已成為常見做法(例如 理解 R1-Zero 類似訓練:批判性視角、DAPO: 一種大規模開源 LLM 強化學習系統)。如果你希望包含 KL 散度項,可以在 GRPOConfig 中將 beta
設定為非零值。
在原始論文中,此公式被泛化以考慮每次生成後的多次更新(表示為,可在 GRPOConfig 中使用 num_iterations
設定),透過利用裁剪替代目標
其中確保更新不會過度偏離參考策略,透過限制策略比率在以下範圍之間:和。當(TRL中的預設值)時,裁剪的替代目標簡化為原始目標。
損失型別
文獻中提出了幾種目標函式形式。最初,GRPO的目標函式定義如下:
其中
DAPO論文強調了GRPO演算法在長CoT場景中樣本級損失的侷限性,即較長的響應受到懲罰不足,導致輸出質量較差。提出的解決方案是token級歸一化,它透過為單個token分配更平衡的獎勵,更好地處理較長序列,而不管響應長度如何:
此外,在理解R1-Zero類訓練:一個批判性視角論文中,作者指出原始GRPO公式引入了響應長度偏差。他們表明,雖然DAPO公式減少了這種偏差,但並未完全消除。為了完全消除這種偏差,他們提出用一個常數而不是序列長度進行除法,從而得到以下公式:
這個常數建議設定為最大完成長度。要使用此公式,請在GRPOConfig中將loss_type
設定為"dr_grpo"
。
記錄指標
num_tokens
:迄今為止處理的令牌總數,包括提示和完成。completions/mean_length
:生成的完成的平均長度。completions/min_length
:生成的完成的最小長度。completions/max_length
:生成的完成的最大長度。completions/mean_terminated_length
:以EOS終止的生成的完成的平均長度。completions/min_terminated_length
:以EOS終止的生成的完成的最小長度。completions/max_terminated_length
:以EOS終止的生成的完成的最大長度。completions/clipped_ratio
:截斷(裁剪)完成的比例。reward/{reward_func_name}/mean
:特定獎勵函式的平均獎勵。reward/{reward_func_name}/std
:特定獎勵函式的獎勵標準差。reward
:應用獎勵權重後的總平均獎勵。reward_std
:應用獎勵權重後,每個批次內總獎勵的標準差。frac_reward_zero_std
:生成批次中獎勵標準差為零的樣本比例,這意味著該提示的多樣性很小(所有答案都正確或不正確)。entropy
:生成的完成中token預測的平均熵。(如果`mask_truncated_completions=True`,則排除被掩碼的序列token。)kl
:模型與參考模型之間的平均KL散度,在生成的完成上計算。僅當`beta`不為零時記錄。clip_ratio/region_mean
:GRPO目標被裁剪以保持在信任區域內的token(或序列,如果importance_sampling_level="sequence"
)機率的平均比率值越高意味著裁剪的token越多,這限制了策略$\pi_\theta$可以改變的幅度。clip_ratio/low_mean
:在信任區域下限被裁剪的token(或序列,如果importance_sampling_level="sequence"
)機率的平均比率clip_ratio/low_min
:在信任區域下限被裁剪的token(或序列,如果importance_sampling_level="sequence"
)機率的最小比率clip_ratio/high_mean
:在信任區域上限被裁剪的token(或序列,如果importance_sampling_level="sequence"
)機率的平均比率clip_ratio/high_max
:在信任區域上限被裁剪的token(或序列,如果importance_sampling_level="sequence"
)機率的最大比率.
定製化
透過vLLM加速訓練中的生成過程
在使用線上方法進行訓練時,生成通常是主要的瓶頸。為了加速生成,您可以使用vLLM,一個用於LLM的高吞吐量、低延遲推理引擎。要啟用它,首先透過以下方式安裝軟體包:
pip install trl[vllm]
我們支援兩種在訓練期間使用 vLLM 的方式:**伺服器模式**和**共置模式**。
🔌 選項 1:伺服器模式
在此模式下,vLLM 在單獨的程序中(並使用單獨的 GPU)執行,並透過 HTTP 與訓練器通訊。如果您有專用的 GPU 用於推理,此模式是理想選擇。
啟動 vLLM 伺服器:
trl vllm-serve --model <model_name>
在您的訓練指令碼中啟用伺服器模式:
from trl import GRPOConfig training_args = GRPOConfig( ..., use_vllm=True, vllm_mode="server", # default value, can be omitted )
請確保伺服器使用的 GPU 與訓練器不同,否則可能會遇到 NCCL 錯誤。您可以透過 `CUDA_VISIBLE_DEVICES` 環境變數指定要使用的 GPU。
🧩 選項 2:並置模式
在此模式下,vLLM 在訓練器程序內執行,並與訓練模型共享 GPU 記憶體。這避免了啟動單獨的伺服器,可以提高 GPU 利用率,但也可能導致訓練 GPU 上的記憶體爭用。
from trl import GRPOConfig
training_args = GRPOConfig(
...,
use_vllm=True,
vllm_mode="colocate",
)
根據模型大小和訓練的整體 GPU 記憶體要求,您可能需要調整 GRPOConfig 中的 vllm_gpu_memory_utilization
引數,以避免 GPU 利用率不足或記憶體不足錯誤。
我們提供了一個 HF Space 來幫助您根據模型配置和實驗設定估算推薦的 GPU 記憶體利用率。只需按如下方式使用即可獲得 vllm_gpu_memory_utilization
推薦
如果推薦值在您的環境中不起作用,我們建議在推薦值的基礎上增加一個小的緩衝區(例如,+0.05 或 +0.1)以確保穩定性。
預設情況下,GRPO 對 vLLM 使用 MASTER_ADDR=localhost
和 MASTER_PORT=12345
,但您可以透過相應地設定環境變數來覆蓋這些值。
有關更多資訊,請參閱 使用 vLLM 加速訓練。
大規模 GRPO:在多個節點上訓練 70B+ 模型
當訓練像 Qwen2.5-72B 這樣的大模型時,您需要一些關鍵的最佳化來使其在多個 GPU 和節點上高效且可擴充套件。這些最佳化包括
- DeepSpeed ZeRO Stage 3:ZeRO 利用資料並行來將模型狀態(權重、梯度、最佳化器狀態)分佈到多個 GPU 和 CPU 上,從而減少每個裝置的記憶體和計算要求。由於大模型無法在單個 GPU 上執行,因此訓練此類模型需要使用 ZeRO Stage 3。有關更多詳細資訊,請參閱 DeepSpeed 整合。
- Accelerate:Accelerate 是一個簡化跨多個 GPU 和節點分散式訓練的庫。它提供了一個簡單的 API 來啟動分散式訓練,並處理分散式訓練的複雜性,例如資料並行、梯度累積和分散式資料載入。有關更多詳細資訊,請參閱 分散式訓練。
- vLLM:請參閱上一節,瞭解如何使用 vLLM 加速生成。
以下是在多個節點上使用 GRPO 訓練 70B 模型的 SLURM 指令碼示例。此指令碼在 4 個節點上訓練模型,並使用第 5 個節點進行 vLLM 驅動的生成。
#!/bin/bash
#SBATCH --nodes=5
#SBATCH --gres=gpu:8
# Get the list of allocated nodes
NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
# Assign the first 4 nodes for training and the 5th node for vLLM
TRAIN_NODES="${NODELIST[@]:0:4}" # Nodes 0, 1, 2, 3 for training
VLLM_NODE="${NODELIST[4]}" # Node 4 for vLLM
# Run training on the first 4 nodes (Group 1)
srun --nodes=4 --ntasks=4 --nodelist="${NODELIST[@]:0:4}" accelerate launch \
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
--num_processes 32 \
--num_machines 4 \
--main_process_ip ${NODELIST[0]} \
--machine_rank $SLURM_PROCID \
--rdzv_backend c10d \
train_grpo.py \
--server_ip $VLLM_NODE &
# Run vLLM server on the 5th node (Group 2)
srun --nodes=1 --ntasks=1 --nodelist="${NODELIST[4]}" trl vllm-serve --model Qwen/Qwen2.5-72B --tensor_parallel_size 8 &
wait
import argparse
from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--vllm_server_host", type=str, default="", help="The server IP")
args = parser.parse_args()
# Example dataset from TLDR
dataset = load_dataset("trl-lib/tldr", split="train")
# Dummy reward function: count the number of unique characters in the completions
def reward_num_unique_chars(completions, **kwargs):
return [len(set(c)) for c in completions]
training_args = GRPOConfig(
output_dir="Qwen2.5-72B-GRPO",
per_device_train_batch_size=4,
bf16=True,
gradient_checkpointing=True,
use_vllm=True,
vllm_server_host=args.vllm_server_host.replace("ip-", "").replace("-", "."), # from ip-X-X-X-X to X.X.X.X
)
trainer = GRPOTrainer(model="Qwen/Qwen2.5-72B", args=training_args, reward_funcs=reward_num_unique_chars, train_dataset=dataset)
trainer.train()
if __name__=="__main__":
main()
使用自定義獎勵函式
GRPOTrainer 支援使用自定義獎勵函式而不是密集獎勵模型。為確保相容性,您的獎勵函式必須滿足以下要求
輸入引數:
函式必須接受以下作為關鍵字引數
prompts
(包含提示),completions
(包含生成的補全),completions_ids
(包含標記化的補全),trainer_state
(TrainerState
):訓練器的當前狀態。這可用於實現動態獎勵函式,例如課程學習,其中獎勵根據訓練進度進行調整。- 資料集可能具有的所有列名(
prompt
除外)。例如,如果資料集包含名為ground_truth
的列,則函式將以ground_truth
作為關鍵字引數呼叫。
滿足此要求的最簡單方法是在函式簽名中使用
**kwargs
。根據資料集格式,輸入將有所不同
返回值:函式必須返回一個浮點數列表。每個浮點數代表與單個補全對應的獎勵。
示例 1:獎勵更長的補全
下面是一個標準格式的獎勵函式示例,它獎勵更長的補全
def reward_func(completions_ids, **kwargs):
"""Reward function that assigns higher scores to longer completions (in terms of token count)."""
return [float(len(ids)) for ids in completions_ids]
您可以按如下方式測試它
>>> prompts = ["The sky is", "The sun is"] # not used in the reward function, but the trainer will pass it
>>> completions = [" blue.", " in the sky."] # not used in the reward function, but the trainer will pass it
>>> completions_ids = [[6303, 13], [304, 279, 12884, 13]]
>>> reward_func(prompts=prompts, completions=completions, completions_ids=completions_ids)
[2.0, 4.0]
示例 1.1:獎勵更長的補全(基於字元數)
與上一個示例相同,但這次獎勵函式基於字元數而不是標記。
def reward_func(completions, **kwargs):
"""Reward function that assigns higher scores to longer completions (in terms of character count)."""
return [float(len(completion)) for completion in completions]
您可以按如下方式測試它
>>> prompts = ["The sky is", "The sun is"]
>>> completions = [" blue.", " in the sky."]
>>> completions_ids = [[6303, 13], [304, 279, 12884, 13]] # not used in the reward function, but the trainer will pass it
>>> reward_func(prompts=prompts, completions=completions, completions_ids=completions_ids)
[6.0, 12.0]
示例 2:獎勵具有特定格式的補全
下面是一個獎勵函式示例,它檢查補全是否具有特定格式。此示例的靈感來自論文 DeepSeek-R1:透過強化學習激勵 LLM 的推理能力 中使用的*格式獎勵*函式。它專為對話格式設計,其中提示和補全由結構化訊息組成。
import re
def format_reward_func(completions, **kwargs):
"""Reward function that checks if the completion has a specific format."""
pattern = r"^<think>.*?</think><answer>.*?</answer>$"
completion_contents = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, content) for content in completion_contents]
return [1.0 if match else 0.0 for match in matches]
您可以按如下方式測試此功能
>>> prompts = [
... [{"role": "assistant", "content": "What is the result of (1 + 2) * 4?"}],
... [{"role": "assistant", "content": "What is the result of (3 + 1) * 2?"}],
... ]
>>> completions = [
... [{"role": "assistant", "content": "<think>The sum of 1 and 2 is 3, which we multiply by 4 to get 12.</think><answer>(1 + 2) * 4 = 12</answer>"}],
... [{"role": "assistant", "content": "The sum of 3 and 1 is 4, which we multiply by 2 to get 8. So (3 + 1) * 2 = 8."}],
... ]
>>> format_reward_func(prompts=prompts, completions=completions)
[1.0, 0.0]
示例 3:根據參考獎勵補全
以下是一個獎勵函式示例,它檢查補全是否正確。此示例的靈感來自論文 DeepSeek-R1:透過強化學習激勵 LLM 的推理能力 中使用的*準確性獎勵*函式。此示例專為 標準格式 設計,其中資料集包含名為 ground_truth
的列。
import re
def reward_func(completions, ground_truth, **kwargs):
# Regular expression to capture content inside \boxed{}
matches = [re.search(r"\\boxed\{(.*?)\}", completion) for completion in completions]
contents = [match.group(1) if match else "" for match in matches]
# Reward 1 if the content is the same as the ground truth, 0 otherwise
return [1.0 if c == gt else 0.0 for c, gt in zip(contents, ground_truth)]
您可以按如下方式測試此功能
>>> prompts = ["Problem: Solve the equation $2x + 3 = 7$. Solution:", "Problem: Solve the equation $3x - 5 = 10$."]
>>> completions = [r" The solution is \boxed{2}.", r" The solution is \boxed{6}."]
>>> ground_truth = ["2", "5"]
>>> reward_func(prompts=prompts, completions=completions, ground_truth=ground_truth)
[1.0, 0.0]
示例 4:多工獎勵函式
以下是在 GRPOTrainer 中使用多個獎勵函式的示例。在此示例中,我們定義了兩個特定於任務的獎勵函式:math_reward_func
和 coding_reward_func
。math_reward_func
根據正確性獎勵數學問題,而 coding_reward_func
根據解決方案是否有效獎勵編碼問題。
from datasets import Dataset
from trl import GRPOTrainer
# Define a dataset that contains both math and coding problems
dataset = Dataset.from_list(
[
{"prompt": "What is 2+2?", "task": "math"},
{"prompt": "Write a function that returns the sum of two numbers.", "task": "code"},
{"prompt": "What is 3*4?", "task": "math"},
{"prompt": "Write a function that returns the product of two numbers.", "task": "code"},
]
)
# Math-specific reward function
def math_reward_func(prompts, completions, task, **kwargs):
rewards = []
for prompt, completion, t in zip(prompts, completions, task):
if t == "math":
# Calculate math-specific reward
correct = check_math_solution(prompt, completion)
reward = 1.0 if correct else -1.0
rewards.append(reward)
else:
# Return None for non-math tasks
rewards.append(None)
return rewards
# Coding-specific reward function
def coding_reward_func(prompts, completions, task, **kwargs):
rewards = []
for prompt, completion, t in zip(prompts, completions, task):
if t == "coding":
# Calculate coding-specific reward
works = test_code_solution(prompt, completion)
reward = 1.0 if works else -1.0
rewards.append(reward)
else:
# Return None for non-coding tasks
rewards.append(None)
return rewards
# Use both task-specific reward functions
trainer = GRPOTrainer(
model="Qwen/Qwen2-0.5B-Instruct",
reward_funcs=[math_reward_func, coding_reward_func],
train_dataset=dataset,
)
trainer.train()
在此示例中,math_reward_func
和 coding_reward_func
旨在與包含數學和編碼問題的混合資料集一起使用。資料集中的 task
列用於確定將哪個獎勵函式應用於每個問題。如果資料集中沒有與樣本相關的獎勵函式,則獎勵函式將返回 None
,GRPOTrainer 將繼續使用有效的函式和任務。這允許 GRPOTrainer 處理具有不同適用性的多個獎勵函式。
請注意,GRPOTrainer 將忽略獎勵函式返回的 None
獎勵,只考慮相關函式返回的獎勵。這確保模型在相關任務上進行訓練,並忽略沒有相關獎勵函式的任務。
將獎勵函式傳遞給訓練器
要使用您的自定義獎勵函式,請按如下方式將其傳遞給 GRPOTrainer
from trl import GRPOTrainer
trainer = GRPOTrainer(
reward_funcs=reward_func,
...,
)
如果您有多個獎勵函式,可以將其作為列表傳遞
from trl import GRPOTrainer
trainer = GRPOTrainer(
reward_funcs=[reward_func1, reward_func2],
...,
)
獎勵將計算為每個函式的獎勵之和,如果配置中提供了 reward_weights
,則為加權和。
請注意,GRPOTrainer 支援不同型別的多個獎勵函式。有關更多詳細資訊,請參閱引數文件。
視覺語言模型 (VLM) 訓練
GRPO 支援在包含文字和影像的多模態資料集上訓練視覺語言模型 (VLM)。
支援的模型
已測試的型號:
- Gemma3 — 例如,
google/gemma-3-4b-it
- LLaVA-NeXT — 例如,
llava-hf/llava-v1.6-mistral-7b-hf
- Qwen2-VL — 例如,
Qwen/Qwen2-VL-2B-Instruct
- Qwen2.5-VL — 例如,
Qwen/Qwen2.5-VL-3B-Instruct
- SmolVLM2 — 例如,
HuggingFaceTB/SmolVLM2-2.2B-Instruct
快速入門
使用 grpo_vlm.py 對 VLM 進行微調。在 lmms-lab/multimodal-open-r1-8k-verified
上訓練的示例命令
accelerate launch \
--config_file=examples/accelerate_configs/deepspeed_zero3.yaml \
examples/scripts/grpo_vlm.py \
--model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
--output_dir grpo-Qwen2.5-VL-3B-Instruct \
--learning_rate 1e-5 \
--gradient_checkpointing \
--torch_dtype bfloat16 \
--max_prompt_length 2048 \
--max_completion_length 1024 \
--use_vllm \
--vllm_mode colocate \
--use_peft \
--lora_target_modules "q_proj", "v_proj" \
--log_completions
配置提示
- 在視覺-語言投影層上使用 LoRA
- 啟用 4 位量化以減少記憶體使用
- VLM 是記憶體密集型的——從較小的批次大小開始
- 大多數模型與 vLLM 相容(
server
和colocate
模式)
資料集格式
每個訓練樣本應包括
prompt
:透過處理器聊天模板格式化的文字image
:單個影像(PIL 或 NumPy 陣列)
訓練器透過模型的影像處理器自動處理影像到張量的轉換。
GRPOTrainer
class trl.GRPOTrainer
< source >( model: typing.Union[str, transformers.modeling_utils.PreTrainedModel] reward_funcs: typing.Union[str, transformers.modeling_utils.PreTrainedModel, typing.Callable[[list, list], list[float]], list[typing.Union[str, transformers.modeling_utils.PreTrainedModel, typing.Callable[[list, list], list[float]]]]] args: typing.Optional[trl.trainer.grpo_config.GRPOConfig] = None train_dataset: typing.Union[datasets.arrow_dataset.Dataset, datasets.iterable_dataset.IterableDataset, NoneType] = None eval_dataset: typing.Union[datasets.arrow_dataset.Dataset, datasets.iterable_dataset.IterableDataset, dict[str, typing.Union[datasets.arrow_dataset.Dataset, datasets.iterable_dataset.IterableDataset]], NoneType] = None processing_class: typing.Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, transformers.processing_utils.ProcessorMixin, NoneType] = None reward_processing_classes: typing.Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, list[transformers.tokenization_utils_base.PreTrainedTokenizerBase], NoneType] = None callbacks: typing.Optional[list[transformers.trainer_callback.TrainerCallback]] = None optimizers: tuple = (None, None) peft_config: typing.Optional[ForwardRef('PeftConfig')] = None )
引數
- model (
Union[str, PreTrainedModel]
) — 要訓練的模型。可以是以下任一型別:- 字串:Hugging Face 模型庫中預訓練模型的*模型 ID*,或包含使用
save_pretrained
儲存的模型權重的*目錄*路徑,例如'./my_model_directory/'
。模型使用from_pretrained
和args.model_init_kwargs
中的關鍵字引數載入。 PreTrainedModel
物件。僅支援因果語言模型。
- 字串:Hugging Face 模型庫中預訓練模型的*模型 ID*,或包含使用
- reward_funcs (
Union[RewardFunc, list[RewardFunc]]
) — 用於計算獎勵的獎勵函式。為了計算獎勵,我們將所有獎勵函式與提示和補全一起呼叫並求和。可以是以下任一型別:-
單個獎勵函式,例如:
-
字串:Hugging Face 模型庫中預訓練模型的*模型 ID*,或包含使用
save_pretrained
儲存的模型權重的*目錄*路徑,例如'./my_model_directory/'
。模型使用from_pretrained
和num_labels=1
以及args.model_init_kwargs
中的關鍵字引數載入。 -
PreTrainedModel
物件:僅支援序列分類模型。 -
自定義獎勵函式:該函式提供提示和生成的補全,以及資料集中的任何附加列。它應該返回一個獎勵列表。當獎勵不適用於這些樣本時,自定義獎勵函式也可以返回
None
。這對於多工訓練非常有用,其中不同的獎勵函式適用於不同型別的樣本。當獎勵函式為樣本返回None
時,該獎勵函式將從該樣本的獎勵計算中排除。有關更多詳細資訊,請參閱 使用自定義獎勵函式。訓練器的狀態也傳遞給獎勵函式。訓練器的狀態是
TrainerState
的例項,可以透過訪問獎勵函式簽名的trainer_state
引數來訪問。
-
-
獎勵函式列表,其中每個項都可以獨立地是上述任何型別。允許列表中混合不同型別(例如,字串模型 ID 和自定義獎勵函式)。
-
- args (GRPOConfig,可選,預設為
None
) — 此訓練器的配置。如果為None
,則使用預設配置。 - train_dataset (Dataset 或 IterableDataset) — 用於訓練的資料集。它必須包含一個
"prompt"
列。資料集中任何附加列都將被忽略。樣本的格式可以是: - eval_dataset (Dataset, IterableDataset 或
dict[str, Union[Dataset, IterableDataset]]
) — 用於評估的資料集。它必須滿足與train_dataset
相同的要求。 - processing_class (
PreTrainedTokenizerBase
或ProcessorMixin
,可選,預設為None
) — 用於處理資料的處理類。填充側必須設定為“左”。如果為None
,則從模型的名稱中使用from_pretrained
載入處理類。必須設定填充標記tokenizer.pad_token
。如果處理類未設定填充標記,則tokenizer.eos_token
將用作預設值。 - reward_processing_classes (
Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]
,可選,預設為None
) — 與reward_funcs
中指定的獎勵函式對應的處理類。可以是以下任一型別:- 單個處理類:當
reward_funcs
只包含一個獎勵函式時使用。 - 處理類列表:必須與
reward_funcs
中獎勵函式的順序和長度匹配。如果設定為None
,或者列表中與PreTrainedModel
對應的元素為None
,則模型的 tokenizer 會自動使用from_pretrained
載入。對於reward_funcs
中是自定義獎勵函式(而不是PreTrainedModel
)的元素,reward_processing_classes
中對應的條目將被忽略。
- 單個處理類:當
- callbacks (
TrainerCallback
列表,可選,預設為None
) — 自定義訓練迴圈的回撥列表。這將新增到 此處 詳述的預設回撥列表中。如果要刪除使用的預設回撥之一,請使用
remove_callback
方法。 - optimizers (
tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]
,可選,預設為(None, None)
) — 包含要使用的最佳化器和排程器的元組。預設為模型上的AdamW
例項和由args
控制的get_linear_schedule_with_warmup
提供的排程器。 - peft_config (
~peft.PeftConfig
,可選,預設為None
) — 用於包裝模型的 PEFT 配置。如果為None
,則不包裝模型。
用於分組相對策略最佳化(GRPO)方法的訓練器。該演算法最初是在論文 DeepSeekMath:在開放語言模型中推動數學推理能力的極限 中提出的。
示例
from datasets import load_dataset
from trl import GRPOTrainer
dataset = load_dataset("trl-lib/tldr", split="train")
def reward_func(completions, **kwargs):
# Dummy reward function that rewards completions with more unique letters.
return [float(len(set(completion))) for completion in completions]
trainer = GRPOTrainer(
model="Qwen/Qwen2-0.5B-Instruct",
reward_funcs=reward_func,
train_dataset=dataset,
)
trainer.train()
train
< source >( resume_from_checkpoint: typing.Union[str, bool, NoneType] = None trial: typing.Union[ForwardRef('optuna.Trial'), dict[str, typing.Any], NoneType] = None ignore_keys_for_eval: typing.Optional[list[str]] = None **kwargs )
引數
- resume_from_checkpoint (
str
或bool
,可選) — 如果是str
,則為先前Trainer
例項儲存的檢查點的本地路徑。如果為bool
且等於True
,則載入先前Trainer
例項在 args.output_dir 中儲存的最新檢查點。如果存在,訓練將從此處載入的模型/最佳化器/排程器狀態恢復。 - trial (
optuna.Trial
或dict[str, Any]
,可選) — 用於超引數搜尋的試驗執行或超引數字典。 - ignore_keys_for_eval (
list[str]
,可選) — 模型輸出(如果是字典)中應在訓練期間收集預測以進行評估時忽略的鍵列表。 - kwargs (
dict[str, Any]
,可選) — 用於隱藏已棄用引數的附加關鍵字引數
主訓練入口點。
將儲存模型,以便您可以使用 `from_pretrained()` 重新載入它。
僅從主程序儲存。
push_to_hub
< source >( commit_message: typing.Optional[str] = 'End of training' blocking: bool = True token: typing.Optional[str] = None revision: typing.Optional[str] = None **kwargs )
將 `self.model` 和 `self.processing_class` 上傳到 🤗 模型中心的 `self.args.hub_model_id` 儲存庫。
GRPOConfig
class trl.GRPOConfig
< source >( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 1e-06 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: str = 'passive' log_level_replica: str = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 10 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: typing.Optional[bool] = None fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, list[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = False label_names: typing.Optional[list[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[list[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict[str, typing.Any], str, NoneType] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, list[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False hub_revision: typing.Optional[str] = None gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: list = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: int = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, list[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False liger_kernel_config: typing.Optional[dict[str, bool]] = None eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: typing.Optional[bool] = True model_init_kwargs: typing.Union[dict, str, NoneType] = None disable_dropout: bool = False max_prompt_length: typing.Optional[int] = 512 num_generations: typing.Optional[int] = 8 max_completion_length: typing.Optional[int] = 256 ds3_gather_for_generation: bool = True shuffle_dataset: typing.Optional[bool] = True generation_batch_size: typing.Optional[int] = None steps_per_generation: typing.Optional[int] = None temperature: float = 1.0 top_p: float = 1.0 top_k: typing.Optional[int] = None min_p: typing.Optional[float] = None generation_kwargs: typing.Optional[dict] = None repetition_penalty: float = 1.0 use_transformers_paged: bool = False cache_implementation: typing.Optional[str] = None use_vllm: bool = False vllm_server_base_url: typing.Optional[str] = None vllm_mode: str = 'server' vllm_model_impl: str = 'vllm' vllm_guided_decoding_regex: typing.Optional[str] = None vllm_server_host: str = '0.0.0.0' vllm_server_port: int = 8000 vllm_server_timeout: float = 240.0 vllm_gpu_memory_utilization: float = 0.3 vllm_tensor_parallel_size: int = 1 beta: float = 0.0 num_iterations: int = 1 epsilon: float = 0.2 delta: typing.Optional[float] = None epsilon_high: typing.Optional[float] = None importance_sampling_level: str = 'token' reward_weights: typing.Optional[list[float]] = None scale_rewards: bool = True loss_type: str = 'bnpo' mask_truncated_completions: bool = False sync_ref_model: bool = False ref_model_mixup_alpha: float = 0.6 ref_model_sync_steps: int = 512 top_entropy_quantile: float = 1.0 use_liger_loss: bool = False log_completions: bool = False num_completions_to_print: typing.Optional[int] = None wandb_log_unique_prompts: typing.Optional[bool] = False )
控制模型和參考模型的引數
- model_init_kwargs (
str
,dict[str, Any]
或None
, 可選, 預設為None
) —from_pretrained
的關鍵字引數,當 GRPOTrainer 的model
引數以字串形式提供時使用。 - disable_dropout (
bool
, 可選, 預設為False
) — 是否在模型中停用 dropout。這對於使用參考模型進行訓練很有用,因為它可以防止模型為相同輸入生成不同的 logprobs。
控制資料預處理的引數
- remove_unused_columns (
bool
, 可選, 預設為False
) — 是否僅保留資料集中"prompt"
列。如果您使用自定義獎勵函式,並且該函式除了"prompts"
和"completions"
之外還需要其他列,則應將其設定為False
。 - max_prompt_length (
int
或None
, 可選, 預設為512
) — 提示的最大長度。如果提示長度超過此值,將從左側截斷。 - num_generations (
int
或None
, 可選, 預設為8
) — 每個提示的生成樣本數。有效批處理大小(num_processes * per_device_batch_size * gradient_accumulation_steps)必須能被此值整除。 - max_completion_length (
int
或None
, 可選, 預設為256
) — 生成完成的最大長度。 - ds3_gather_for_generation (
bool
, 可選, 預設為True
) — 此設定適用於 DeepSpeed ZeRO-3。如果啟用,將收集策略模型權重以進行生成,從而提高生成速度。但是,停用此選項可以訓練超出單個 GPU 視訊記憶體容量的模型,儘管代價是生成速度較慢。停用此選項與 vLLM 生成不相容。 - shuffle_dataset (
bool
, 可選, 預設為True
) — 是否打亂訓練資料集。
控制生成的引數
- generation_batch_size — (
int
或None
, 可選, 預設為None
): 用於生成的批處理大小。如果為None
,則預設為有效訓練批處理大小:per_device_train_batch_size * num_processes * steps_per_generation
。換句話說,每個最佳化步驟處理一個生成批次。與steps_per_generation
互斥。 - steps_per_generation — (
int
或None
, 可選, 預設為None
): 每次生成步數。如果為None
,則預設為gradient_accumulation_steps
。與generation_batch_size
互斥。 - temperature (
float
, 預設為1.0
) — 取樣的溫度。溫度越高,完成度越隨機。 - top_p (
float
, 可選, 預設為1.0
) — 控制要考慮的最高機率標記的累積機率的浮點數。必須在 (0, 1] 範圍內。設定為1.0
以考慮所有標記。 - top_k (
int
或None
, 可選, 預設為None
) — 保留用於 top-k 過濾的最高機率詞彙標記數量。如果為None
,則停用 top-k 過濾,並考慮所有標記。 - min_p (
float
或None
, 可選, 預設為None
) — 最小標記機率,將按最可能標記的機率進行縮放。它必須是0.0
到1.0
之間的值。典型值在0.01-0.2
範圍內。 - repetition_penalty (
float
, 可選, 預設為1.0
) — 根據新標記是否出現在提示和已生成文字中來懲罰新標記的浮點數。值 >1.0
鼓勵模型使用新標記,而值 <1.0
鼓勵模型重複標記。 - use_transformers_paged (
bool
, 可選, 預設為False
) — 是否使用transformers
的分頁實現進行生成。如果設定為True
,將使用transformers
的分頁實現進行生成,而不是預設的填充實現。此引數僅在use_vllm
設定為False
時有效。 - cache_implementation (
str
或None
, 可選, 預設為None
) — 當use_vllm
設定為False
時,用於更快生成快取方法的實現。 - generation_kwargs (
dict[str, Any]
或None
, 可選, 預設為None
) — 取樣完成時傳遞給GenerationConfig
(如果使用 transformers)或SamplingParams
(如果使用 vLLM)的附加關鍵字引數。這可用於進一步自定義生成行為,例如設定supress_tokens
、num_beams
等。如果它包含與其它生成引數(如min_p
、top_p
等)衝突的鍵,它們將覆蓋這些引數。
vLLM 支援的生成加速控制引數
- use_vllm (
bool
, 可選, 預設為False
) — 是否使用 vLLM 生成完成。如果設定為True
,訓練器將使用 vLLM 進行生成,而不是預設的 model.generate()。需要安裝vllm
。 - vllm_mode (
str
, 可選, 預設為"server"
) — 當use_vllm
設定為True
時,用於 vLLM 整合的模式。必須是"server"
或"colocate"
之一。"server"
: 訓練器將生成請求傳送到單獨的 vLLM 伺服器。請確保 TRL vLLM 伺服器正在執行(使用trl vllm-serve
啟動)。"colocate"
: vLLM 將在同一程序中執行並共享訓練 GPU。這避免了對單獨伺服器的需求,但可能會導致與訓練的資源爭用。
- vllm_guided_decoding_regex (
str
或None
, 可選, 預設為None
) — vLLM 引導式解碼的正則表示式。如果為None
(預設),則停用引導式解碼。
控制 vLLM 伺服器的引數(僅當 `vllm_mode` 為 `"server"` 時使用)
- vllm_server_base_url (
str
或None
, 可選, 預設為None
) — vLLM 伺服器的基本 URL(例如,"https://:8000"
)。如果提供此引數,則vllm_server_host
和vllm_server_port
將被忽略。 - vllm_server_host (
str
, 可選, 預設為"0.0.0.0"
) — 要連線的 vLLM 伺服器主機。如果提供了vllm_server_base_url
,則忽略此引數。 - vllm_server_port (
int
, 可選, 預設為8000
) — 要連線的 vLLM 伺服器埠。如果提供了vllm_server_base_url
,則忽略此引數。 - vllm_server_timeout (
float
, 可選, 預設為240.0
) — 等待 vLLM 伺服器啟動的總超時時長(秒)。如果在超時後伺服器仍未啟動,則會引發ConnectionError
。
控制共置 vLLM 執行的引數(僅當 `vllm_mode` 為 `"colocate"` 時使用)
- vllm_gpu_memory_utilization (
float
, 可選, 預設為0.3
) — 控制 vLLM 的 GPU 記憶體利用率。此設定僅在vllm_mode
設定為"colocate"
時適用。如果您使用vllm_mode="server"
,則必須在透過--vllm_gpu_memory_utilization
標誌啟動 vLLM 伺服器時單獨傳遞此引數。 - vllm_tensor_parallel_size (
int
, 可選, 預設為1
) — 控制 vLLM 的張量並行大小。此設定僅在vllm_mode
設定為"colocate"
時適用。如果您使用vllm_mode="server"
,則必須在透過--vllm_tensor_parallel_size
標誌啟動 vLLM 伺服器時單獨傳遞此引數。 - vllm_model_impl (
str
, 可選, 預設為"vllm"
) — 用於 vLLM 的模型實現。必須是"transformers"
或"vllm"
之一。"transformers"
:使用transformers
後端進行模型實現。"vllm"
:使用vllm
庫進行模型實現。
控制訓練的引數
- beta (
float
, 可選, 預設為0.0
) — KL 係數。如果為0.0
(預設),則不載入參考模型,從而減少記憶體使用並提高訓練速度。 - num_iterations (
int
, 可選, 預設為1
) — 每批次的迭代次數(在演算法中表示為 μ)。 - epsilon (
float
, 可選, 預設為0.2
) — 用於裁剪的 Epsilon 值。 - delta — (
float
或None
, 可選, 預設為None
): 當設定為浮點數時,啟用兩邊 GRPO 損失中的上限裁剪。如果為None
(預設),則使用標準 GRPO 裁剪。建議在啟用時大於1 + ε
。此方法在INTELLECT-2 技術報告中引入。 - epsilon_high (
float
或None
, 可選, 預設為None
) — 裁剪的上限 epsilon 值。如果未指定,則預設為與引數epsilon
中指定的下限相同的值。DAPO 論文推薦使用0.28
。 - importance_sampling_level (
str
, 可選, 預設為"token"
) — 控制重要性取樣比率是在"token"
級別還是"sequence"
級別計算。"token"
保留原始的每令牌對數機率比率(每個令牌一個權重)。"sequence"
對有效令牌的對數機率比率進行平均,為每個序列生成一個比率。GSPO 論文表明,序列級取樣通常會帶來更穩定的訓練和更好的與序列級獎勵對齊。 - reward_weights (
list[float]
或None
, 可選, 預設為None
) — 每個獎勵函式的權重。必須與獎勵函式的數量匹配。如果為None
,所有獎勵都以1.0
的權重平均加權。 - scale_rewards (
bool
, 可選, 預設為True
) — 是否透過將獎勵除以其標準差來縮放獎勵。如果為True
(預設),獎勵將按標準差歸一化,確保它們具有單位方差。如果為False
,則不應用縮放。Dr. GRPO 論文建議不縮放獎勵,因為按標準差縮放會引入問題級難度偏差。 - loss_type (
str
, 可選, 預設為"bnpo"
) — 指定要使用的損失公式。支援的值有:"grpo"
:透過對序列長度進行歸一化來聚合令牌級損失。不推薦,因為它會產生長度偏差——這種方法傾向於在具有正優勢時偏好較短的補全,而在具有負優勢時偏好較長的補全。"bnpo"
:透過對本地批次中活躍令牌的數量進行歸一化來聚合令牌級損失。請注意,歸一化僅在本地批次上執行,因此結果可能會因本地批次大小而略有不同,儘管有效批次大小是恆定的。當使用per_device_train_batch_size==1
時,損失等效於 GRPO 損失。"dr_grpo"
:透過全域性常數進行歸一化來聚合令牌級損失。此方法在Dr. GRPO 論文中引入,以消除長度偏差。常數的值對應於max_completion_length
。
- mask_truncated_completions (
bool
, 可選, 預設為False
) — 啟用後,截斷的補全將從損失計算中排除,防止它們被錯誤地懲罰並在訓練期間引入噪聲。根據DAPO 論文,這是訓練穩定性的良好實踐。 - sync_ref_model (
bool
, 可選, 預設為False
) — 是否每ref_model_sync_steps
步使用ref_model_mixup_alpha
引數將參考模型與活躍模型同步。此同步源自TR-DPO 論文。 - ref_model_mixup_alpha (
float
, 可選, 預設為0.6
) — 來自TR-DPO 論文的 α 引數,它控制當前策略和先前參考策略在更新期間的混合。參考策略根據以下公式更新:π_ref = α * π_θ + (1 - α) * π_ref_prev
。要使用此引數,必須設定sync_ref_model=True
。 - ref_model_sync_steps (
int
, 可選, 預設為512
) — 來自TR-DPO 論文的 τ 引數,它決定了當前策略與參考策略同步的頻率。要使用此引數,必須設定sync_ref_model=True
。 - top_entropy_quantile (
float
, 可選, 預設為1.0
) — 來自Beyond the 80/20 Rule的 ρ 引數。只保留每個序列位置上機率分佈熵的最高 ρ 分位數令牌在策略損失項中,從而改善結果。範圍:[0.0-1.0]
。值為0.0
遮蔽除最高熵令牌外的所有令牌;1.0
保留所有令牌。論文推薦值為0.2
。如果與mask_truncated_completions=True
一起使用,則只考慮非截斷補全中的令牌。 - use_liger_loss (
bool
, 可選, 預設為False
) — 是否使用 Liger GRPO 損失。
控制日誌記錄的引數
- log_completions (
bool
, 可選, 預設為False
) — 是否每logging_steps
步記錄一組 (提示, 補全) 對。如果安裝了rich
,它會列印該樣本。如果啟用了wandb
日誌記錄,它會將其記錄到wandb
。 - num_completions_to_print (
int
或None
, 可選, 預設為None
) — 要使用rich
列印的補全數量。如果為None
,則記錄所有補全。 - wandb_log_unique_prompts (
bool
, 可選, 預設為False
) — 是否在 wandb 中記錄唯一提示。如果為True
,則只記錄唯一提示。如果為False
,則記錄所有提示。
GRPOTrainer 的配置類。
此類僅包含 GRPO 訓練特有的引數。有關訓練引數的完整列表,請參閱 TrainingArguments
文件。請注意,此類的預設值可能與 TrainingArguments
中的預設值不同。
使用 HfArgumentParser
,我們可以將此類別轉換為可在命令列中指定的 argparse 引數。