TRL 文件
PRM 訓練器
並獲得增強的文件體驗
開始使用
PRM 訓練器
PRM 訓練器是一個實驗性 API,可能隨時會發生變化。
概述
過程監督獎勵模型(Process-supervised Reward Models,PRM)由 Jonathan Uesato, Nate Kushman, Ramana Kumar, Francis Song, Noah Siegel, Lisa Wang, Antonia Creswell, Geoffrey Irving 和 Irina Higgins 在論文 《利用過程和結果反饋解決數學應用題》 中提出。
論文摘要如下:
最近的研究表明,要求語言模型生成推理步驟可以提高其在許多推理任務上的表現。當超越提示工程時,這就引出了一個問題:我們應該如何監督這些模型?是採用監督最終結果的基於結果的方法,還是採用監督推理過程本身的基於過程的方法?這兩種方法之間的差異不僅體現在最終答案的錯誤上,還可能體現在推理錯誤上,後者難以檢測,並且在教育等許多現實世界領域中存在問題。我們對基於過程和基於結果的方法在自然語言任務 GSM8K 上進行了首次全面比較。我們發現,純粹基於結果的監督在較少的標籤監督下,產生了相似的最終答案錯誤率。然而,對於正確的推理步驟,我們發現必須使用基於過程的監督或來自模仿過程反饋的習得獎勵模型的監督。總的來說,我們改進了之前的最佳結果,將最終答案錯誤率從 16.8% 降至 12.7%,並將最終答案正確的解決方案中的推理錯誤率從 14.0% 降至 3.4%。
此後訓練方法由 Gaetan Lopez、Lewis Tunstall、Quentin Gallouédec 和 Agustín Piqueres 貢獻。
快速入門
這個例子演示瞭如何使用 PRM 方法訓練一個模型。我們使用 Qwen 0.5B 模型 作為基礎模型。我們使用來自 Math Shepherd 資料集 的逐步監督資料。你可以在這裡檢視資料集中的資料
以下是訓練模型的指令碼
# train_prm.py
from datasets import load_dataset
from trl import PRMConfig, PRMTrainer
from transformers import AutoModelForTokenClassification, AutoTokenizer
model = AutoModelForTokenClassification.from_pretrained("Qwen/Qwen2-0.5B", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
train_dataset = load_dataset("trl-lib/math_shepherd", split="train[:10%]")
training_args = PRMConfig(output_dir="Qwen2-0.5B-Reward-Math-Sheperd")
trainer = PRMTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()
使用以下命令執行指令碼
accelerate launch train_prm.py
在 8 個 GPU 上分散式訓練大約需要 1 小時。
要檢視 訓練好的模型 的表現,你可以使用以下指令碼。
from datasets import load_dataset
from transformers import pipeline
pipe = pipeline("token-classification", model="trl-lib/Qwen2-0.5B-Reward-Math-Sheperd")
dataset = load_dataset("trl-lib/math_shepherd")
example = {
"prompt": "Musa is the class teacher of a class of 45 students. He wants to split them into three groups by age. If a third of the class is under 11 years, and two-fifths are above 11 but under 13, how many students will be in the third group (13 years and above)?",
"completions": [
"Step 1: A third of the class is under 11 years because 11 - 1/3 = <<11-1/3=7>>7.",
"Step 2: Two-fifths of the class are above 11 but under 13 because 2/5 * 11 = <<2/5*11=8>>8.",
"Step 3: There are 45 students, so the third group will have 45 - 7 - 8 = <<45-7-8=20>>20 students. The answer is: 20",
],
"labels": [True, False, False],
}
separator = "\n" # It's important to use the same separator as the one used during training
for idx in range(1, len(example["completions"]) + 1):
steps = example["completions"][0:idx]
text = separator.join((example["prompt"], *steps)) + separator # Add a separator between the prompt and each steps
pred_entity = pipe(text)[-1]["entity"]
pred = {"LABEL_0": False, "LABEL_1": True}[pred_entity]
label = example["labels"][idx - 1]
print(f"Step {idx}\tPredicted: {pred} \tLabel: {label}")
Step 1 Predicted: True Label: True Step 2 Predicted: False Label: False Step 3 Predicted: False Label: False
成功了!
預期的資料集型別
PRM 需要逐步監督。資料集應包含以下列:prompt
、completions
和 labels
,其中 completions
包含一個推理步驟列表,labels
包含一個布林值或浮點數列表,表示每個步驟的正確性。
PRMTrainer 僅支援標準資料集格式。
示例指令碼
我們提供一個示例指令碼來使用 PRM 方法訓練模型。該指令碼可在 examples/scripts/prm.py
中找到。
要在 Math Shepherd 資料集 上使用 PRM 指令碼訓練 Qwen2 0.5B 模型,請執行以下命令
accelerate launch examples/scripts/prm.py \ --model_name_or_path Qwen/Qwen2-0.5B \ --dataset_name trl-lib/math_shepherd \ --num_train_epochs 1 \ --output_dir Qwen2-0.5B-Reward-Math-Sheperd
PRMTrainer
class trl.PRMTrainer
< 原始碼 >( model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module, NoneType] = None args: typing.Optional[trl.trainer.prm_config.PRMConfig] = None data_collator: typing.Optional[transformers.data.data_collator.DataCollator] = None train_dataset: typing.Optional[datasets.arrow_dataset.Dataset] = None eval_dataset: typing.Union[datasets.arrow_dataset.Dataset, dict[str, datasets.arrow_dataset.Dataset], NoneType] = None processing_class: typing.Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, transformers.image_processing_utils.BaseImageProcessor, transformers.feature_extraction_utils.FeatureExtractionMixin, transformers.processing_utils.ProcessorMixin, NoneType] = None model_init: typing.Optional[typing.Callable[[], transformers.modeling_utils.PreTrainedModel]] = None compute_metrics: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalPrediction], dict]] = None callbacks: typing.Optional[list[transformers.trainer_callback.TrainerCallback]] = None optimizers: tuple = (None, None) preprocess_logits_for_metrics: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None peft_config: typing.Optional[dict] = None )
引數
- model (
transformers.PreTrainedModel
) — 待訓練的模型,最好是AutoModelForTokenClassification
。 - args (
PRMConfig
) — 訓練時使用的引數。 - data_collator (
transformers.DataCollator
) — 訓練時使用的資料整理器。如果未指定,將使用預設的資料整理器 (DataCollatorForTokenClassification
),它會根據批次中序列的最大長度對序列進行填充,適用於成對序列的資料集。 - train_dataset (
datasets.Dataset
) — 用於訓練的資料集。 - eval_dataset (
datasets.Dataset
) — 用於評估的資料集。 - processing_class (
PreTrainedTokenizerBase
、BaseImageProcessor
、FeatureExtractionMixin
或ProcessorMixin
,*可選*,預設為 `None`) — 用於處理資料的處理類。如果提供,將用於自動處理模型的輸入,並會與模型一起儲存,以便更容易地重新執行中斷的訓練或重用微調後的模型。 - model_init (
Callable[[], transformers.PreTrainedModel]
) — 用於訓練的模型初始化器。如果未指定,將使用預設的模型初始化器。 - compute_metrics (
Callable[[transformers.EvalPrediction], dict]
, *可選*,預設為 `compute_accuracy`) — 用於評估的指標。如果未指定指標,將使用預設指標 (`compute_accuracy`)。 - callbacks (
list[transformers.TrainerCallback]
) — 用於訓練的回撥函式。 - optimizers (
tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]
) — 用於訓練的最佳化器和排程器。 - preprocess_logits_for_metrics (
Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
) — 用於在計算指標前預處理 logits 的函式。 - peft_config (
dict
, 預設為 `None`) — 用於訓練的 PEFT 配置。如果傳遞 PEFT 配置,模型將被包裝在 PEFT 模型中。
初始化 PRMTrainer。
訓練
< 原始碼 >( resume_from_checkpoint: typing.Union[str, bool, NoneType] = None trial: typing.Union[ForwardRef('optuna.Trial'), dict[str, typing.Any], NoneType] = None ignore_keys_for_eval: typing.Optional[list[str]] = None **kwargs )
引數
- resume_from_checkpoint (
str
或bool
,*可選*) — 如果是str
,則為上一個 `Trainer` 例項儲存的檢查點的本地路徑。如果是bool
且等於 `True`,則載入上一個 `Trainer` 例項在 *args.output_dir* 中儲存的最後一個檢查點。如果存在,訓練將從此處載入的模型/最佳化器/排程器狀態恢復。 - trial (
optuna.Trial
或dict[str, Any]
,*可選*) — 用於超引數搜尋的試驗執行或超引數字典。 - ignore_keys_for_eval (
list[str]
, *可選*) — 在訓練期間收集評估預測時,模型輸出中(如果為字典)應忽略的鍵的列表。 - kwargs (
dict[str, Any]
, *可選*) — 用於隱藏已棄用引數的附加關鍵字引數
主訓練入口點。
將儲存模型,以便您可以使用 `from_pretrained()` 重新載入它。
僅從主程序儲存。
push_to_hub
< 原始碼 >( commit_message: typing.Optional[str] = 'End of training' blocking: bool = True token: typing.Optional[str] = None revision: typing.Optional[str] = None **kwargs )
引數
- commit_message (
str
, *可選*,預設為 `"End of training"`) — 推送時要提交的訊息。 - blocking (
bool
, *可選*,預設為 `True`) — 該函式是否僅在 `git push` 完成後才返回。 - token (
str
, *可選*,預設為 `None`) — 具有寫入許可權的令牌,用於覆蓋 Trainer 的原始引數。 - revision (
str
, *可選*) — 要提交的 git 修訂版本。預設為“main”分支的頭部。 - kwargs (
dict[str, Any]
, *可選*) — 傳遞給 `~Trainer.create_model_card` 的附加關鍵字引數。
將 `self.model` 和 `self.processing_class` 上傳到 🤗 模型中心的 `self.args.hub_model_id` 儲存庫。
PRMConfig
class trl.PRMConfig
< 原始碼 >( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 1e-05 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: str = 'passive' log_level_replica: str = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 10 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: typing.Optional[bool] = None fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, list[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[list[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[list[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict[str, typing.Any], str, NoneType] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, list[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False hub_revision: typing.Optional[str] = None gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: list = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: int = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, list[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False liger_kernel_config: typing.Optional[dict[str, bool]] = None eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: bool = True max_length: typing.Optional[int] = 1024 max_prompt_length: typing.Optional[int] = 512 max_completion_length: typing.Optional[int] = None disable_dropout: bool = True step_separator: str = '\n' train_on_last_step_only: bool = False dataset_num_proc: typing.Optional[int] = None )
引數
- max_length (
int
或None
, 可選, 預設為1024
) — 用於截斷的序列(提示+補全)的最大長度。 - max_prompt_length (
int
或None
, 可選, 預設為512
) — 用於截斷的提示的最大長度。 - max_completion_length (
int
或None
, 可選, 預設為None
) — 用於截斷的補全的最大長度。補全是所有步驟的串聯。 - disable_dropout (
bool
, 可選, 預設為True
) — 是否在模型中停用 dropout。 - step_separator (
str
, 可選, 預設為"\n"
) — 用於分隔推理過程中每個步驟的分隔符。 - train_on_last_step_only (
bool
, 可選, 預設為False
) — 是否僅在最後一步進行訓練。 - dataset_num_proc (
int
, 可選, 預設為None
) — 用於處理資料集的程序數。
PRMTrainer 的配置類。
此類僅包含 PRM 訓練專用的引數。有關訓練引數的完整列表,請參閱 TrainingArguments
文件。請注意,此類中的預設值可能與 TrainingArguments
中的預設值不同。
使用 HfArgumentParser
,我們可以將此類別轉換為可在命令列上指定的 argparse 引數。