TRL 文件

Nash-MD 訓練器

Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

Nash-MD 訓練器

概述

Nash-MD 由 Rémi Munos、Michal Valko、Daniele Calandriello、Mohammad Gheshlaghi Azar、Mark Rowland、Daniel Guo、Yunhao Tang、Matthieu Geist、Thomas Mésnard 和 Andrea Michi 在論文 《Nash Learning from Human Feedback》 中提出。

論文摘要如下:

從人類反饋中進行強化學習 (RLHF) 已成為使大型語言模型 (LLM) 與人類偏好對齊的主要正規化。通常,RLHF 的第一步是從人類反饋中學習一個獎勵模型,這些反饋通常表示為對預訓練 LLM 生成的一對文字之間的偏好。隨後,透過強化學習演算法最佳化 LLM 的策略,使其最大化獎勵模型。然而,當前獎勵模型的一個固有侷限性是它們無法完全表示人類偏好的豐富性,並且依賴於取樣分佈。在本研究中,我們介紹了一種使用成對人類反饋對 LLM 進行微調的替代流程。我們的方法包括首先學習一個基於給定提示的兩個輸入的偏好模型,然後尋求一種策略,該策略能持續生成優於任何競爭策略的響應,從而定義該偏好模型的納什均衡。我們將此方法稱為從人類反饋中進行納什學習 (NLHF)。在表格化策略表示的背景下,我們提出了一種新穎的演算法解決方案 Nash-MD,它基於映象下降的原理。該演算法產生一系列策略,最終迭代收斂到正則化的納什均衡。此外,我們探索了策略的引數化表示,併為深度學習架構引入了梯度下降演算法。為了展示我們方法的有效性,我們展示了在文字摘要任務中微調 LLM 的實驗結果。我們相信 NLHF 為偏好學習和策略最佳化提供了一條有吸引力的途徑,並有潛力推動使 LLM 與人類偏好對齊領域的發展。

此後訓練方法由 Kashif RasulDaniil TiapkinPierre Ménard、Daniele Calandriello 和 Quentin Gallouédec 貢獻。

快速開始

這個例子演示瞭如何使用 Nash-MD 方法訓練一個模型。我們使用 Qwen 0.5B 模型 作為基礎模型,並使用 PairRMJudge 作為評判器。我們使用來自 UltraFeedback 資料集 的提示。你可以在此處檢視資料集中的提示。

以下是訓練模型的指令碼

# train_nash_md.py
from datasets import load_dataset
from trl import NashMDConfig, NashMDTrainer, PairRMJudge
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
judge = PairRMJudge()
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")

training_args = NashMDConfig(output_dir="Qwen2-0.5B-NashMD")
trainer = NashMDTrainer(
    model=model, judge=judge, args=training_args, processing_class=tokenizer, train_dataset=train_dataset
)
trainer.train()

使用以下命令執行指令碼

accelerate launch train_nash_md.py

在 8 個 GPU 上進行分散式訓練,大約需要 3 小時。

要檢視 訓練後模型 的效能,您可以使用 Transformers Chat CLI

$ transformers chat trl-lib/Qwen2-0.5B-NashMD
<quentin_gallouedec>:
What is the best programming language?

<trl-lib/Qwen2-0.5B-NashMD>:
The best programming language depends on personal preference, the complexity of the project, and the specific requirements of the task. Some programming languages that are often recommended include Python, Java, and JavaScript, and there are many other languages to choose from depending on individual needs.

預期資料集型別

Nash-MD 需要一個僅含提示的資料集NashMDTrainer 支援對話式標準兩種資料集格式。當提供對話式資料集時,訓練器會自動將聊天模板應用於資料集。

使用技巧

使用獎勵模型

除了評判器,您也可以選擇使用獎勵模型——請參閱 Reward Bench 獲取可用的公開模型排行榜。以下程式碼示例展示瞭如何用 trl-lib/Qwen2-0.5B-Reward 模型替換評判器。

- from trl import PairRMJudge
+ from transformers import AutoModelForSequenceClassification

- judge = PairRMJudge()
+ reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1)

  trainer = NashMDTrainer(
      ...
-     judge=judge,
+     reward_model=reward_model,
  )

請確保 SFT 模型和獎勵模型使用相同的聊天模板和分詞器。否則,您可能會發現在訓練過程中模型補全的評分不正確。

鼓勵生成 EOS 令牌

我們可能希望模型在給定的長度內生成補全。在訓練期間,模型將生成補全,其長度最多為 NashMDConfigmax_new_tokens 引數指定的最大長度。如果您想懲罰模型在達到最大長度之前未生成 EOS 令牌,可以使用 NashMDConfigmissing_eos_penalty 引數。

training_args = NashMDConfig(..., max_new_tokens=128, missing_eos_penalty=1.0)

記錄補全

為了更好地理解模型在訓練過程中的行為,您可以使用 LogCompletionsCallback 定期記錄樣本補全。

trainer = NashMDTrainer(..., eval_dataset=eval_dataset)
completions_callback = LogCompletionsCallback(trainer, num_prompts=8)
trainer.add_callback(completions_callback)

此回撥函式直接將模型生成的補全記錄到 Weights & Biases。

Logged Completions

示例指令碼

我們提供了一個使用 Nash-MD 方法訓練模型的示例指令碼。該指令碼位於 examples/scripts/nash_md.py

要使用 Qwen2.5 0.5B 模型UltraFeedback 資料集 上測試線上 DPO 指令碼,請執行以下命令:

python examples/scripts/nash_md.py \
    --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
    --judge pair_rm \
    --dataset_name trl-lib/ultrafeedback-prompt \
    --learning_rate 5.0e-7 \
    --output_dir Qwen2.5-0.5B-NashMD-PairRM \
    --warmup_ratio 0.1 \
    --push_to_hub

記錄的指標

記錄的指標如下

  • loss/kl:模型與參考資料之間的平均 KL 散度。
  • objective/entropy:模型與參考資料的平均熵。
  • loss/score:平均強化分數損失。
  • rewards/chosen:模型補全的平均分數(根據獎勵模型)。
  • rewards/rejected:混合補全的平均分數(根據獎勵模型)。
  • rewards/probabilities:模型補全被選中與混合補全的平均機率(根據獎勵模型或評判器)。
  • rewards/accuracies:Nash-MD 隱式獎勵模型的準確率。
  • rewards/margins:被選中和混合補全之間的平均獎勵邊際(根據獎勵模型)。
  • logps/chosen:被選中補全的平均對數機率。
  • logps/rejected:參考補全的平均對數機率。
  • val/model_contain_eos_token:模型輸出包含 eos 令牌的次數。
  • val/ref_contain_eos_token:混合輸出包含 eos 令牌的次數。
  • beta:控制與參考模型偏差的損失項權重的引數。通常是固定的,但可以透過向 NashMDConfig 傳遞一個列表來使其動態化。
  • mixture_coef:模型和參考模型的 Logit 混合係數。通常是固定的,但可以透過向 NashMDConfig 傳遞一個列表來使其動態化。

NashMDTrainer

class trl.NashMDTrainer

< >

( model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None ref_model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None reward_model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module, NoneType] = None judge: typing.Optional[trl.trainer.judges.BasePairwiseJudge] = None args: typing.Optional[trl.trainer.nash_md_config.NashMDConfig] = None data_collator: typing.Optional[typing.Callable] = None train_dataset: typing.Union[datasets.arrow_dataset.Dataset, datasets.iterable_dataset.IterableDataset, NoneType] = None eval_dataset: typing.Union[datasets.arrow_dataset.Dataset, dict[str, datasets.arrow_dataset.Dataset], NoneType] = None processing_class: typing.Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, transformers.image_processing_utils.BaseImageProcessor, transformers.feature_extraction_utils.FeatureExtractionMixin, transformers.processing_utils.ProcessorMixin, NoneType] = None peft_config: typing.Optional[dict] = None compute_metrics: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalPrediction], dict]] = None callbacks: typing.Optional[list[transformers.trainer_callback.TrainerCallback]] = None optimizers: tuple = (None, None) preprocess_logits_for_metrics: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None )

引數

  • model (transformers.PreTrainedModel) — 用於訓練的模型,最好是 AutoModelForCausalLM
  • ref_model (PreTrainedModelWrapper) — 帶有因果語言模型頭的 Hugging Face transformer 模型。用於隱式獎勵計算和損失。如果未提供參考模型,訓練器將建立一個與待最佳化模型具有相同架構的參考模型。
  • reward_model (transformers.PreTrainedModel) — 用於對補全進行評分的獎勵模型,最好是 AutoModelForSequenceClassification
  • judge (BasePairwiseJudge) — 用於對模型補全進行成對比較的評判器。
  • args (NashMDConfig) — 用於訓練的 NashMD 配置引數。
  • data_collator (transformers.DataCollator) — 用於訓練的資料整理器。如果未指定,將使用預設的資料整理器 (DPODataCollatorWithPadding),該整理器會根據批次中序列的最大長度,對成對序列的資料集進行填充。
  • train_dataset (datasets.Dataset) — 用於訓練的資料集。
  • eval_dataset (datasets.Dataset) — 用於評估的資料集。
  • processing_class (PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixinProcessorMixin, 可選, 預設為 None) — 用於處理資料的處理類。如果提供,將用於自動處理模型的輸入,並與模型一起儲存,以便更容易地重新執行中斷的訓練或重用微調後的模型。
  • peft_config (dict) — 用於訓練的 peft 配置。
  • compute_metrics (Callable[[EvalPrediction], dict], 可選) — 用於計算指標的函式。必須接受一個 EvalPrediction 並返回一個從字串到指標值的字典。
  • callbacks (list[transformers.TrainerCallback]) — 用於訓練的回撥函式。
  • optimizers (tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]) — 用於訓練的最佳化器和排程器。
  • preprocess_logits_for_metrics (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) — 在計算指標前用於預處理 logits 的函式。

初始化 NashMDTrainer 作為 OnlineDPOConfig 的子類。

train

< >

( resume_from_checkpoint: typing.Union[str, bool, NoneType] = None trial: typing.Union[ForwardRef('optuna.Trial'), dict[str, typing.Any], NoneType] = None ignore_keys_for_eval: typing.Optional[list[str]] = None **kwargs )

引數

  • resume_from_checkpoint (strbool, 可選) — 如果是 str,則為之前 Trainer 例項儲存的檢查點的本地路徑。如果是 bool 且為 True,則載入之前 Trainer 例項在 *args.output_dir* 中儲存的最後一個檢查點。如果存在,訓練將從此處載入的模型/最佳化器/排程器狀態繼續。
  • trial (optuna.Trialdict[str, Any], 可選) — 試驗執行或用於超引數搜尋的超引數字典。
  • ignore_keys_for_eval (list[str], 可選) — 在訓練期間收集評估預測時,應忽略的模型輸出中的鍵列表(如果輸出是字典)。
  • kwargs (dict[str, Any], 可選) — 用於隱藏已棄用引數的附加關鍵字引數

主訓練入口點。

save_model

< >

( output_dir: typing.Optional[str] = None _internal_call: bool = False )

將儲存模型,以便您可以使用 `from_pretrained()` 重新載入它。

僅從主程序儲存。

push_to_hub

< >

( commit_message: typing.Optional[str] = 'End of training' blocking: bool = True token: typing.Optional[str] = None revision: typing.Optional[str] = None **kwargs )

引數

  • commit_message (str, 可選, 預設為 "End of training") — 推送時提交的訊息。
  • blocking (bool, 可選, 預設為 True) — 該函式是否應該僅在 git push 完成後返回。
  • token (str, 可選, 預設為 None) — 具有寫入許可權的令牌,用於覆蓋 Trainer 的原始引數。
  • revision (str, 可選) — 要提交的 git 修訂版本。預設為 "main" 分支的頭部。
  • kwargs (dict[str, Any], 可選) — 傳遞給 ~Trainer.create_model_card 的附加關鍵字引數。

將 `self.model` 和 `self.processing_class` 上傳到 🤗 模型中心的 `self.args.hub_model_id` 儲存庫。

NashMDConfig

class trl.NashMDConfig

< >

( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 5e-07 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: str = 'passive' log_level_replica: str = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 10 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: typing.Optional[bool] = None fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, list[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[list[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[list[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict[str, typing.Any], str, NoneType] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, list[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False hub_revision: typing.Optional[str] = None gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: list = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: int = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, list[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False liger_kernel_config: typing.Optional[dict[str, bool]] = None eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: typing.Optional[bool] = True reward_model_path: typing.Optional[str] = None judge: typing.Optional[str] = None max_new_tokens: int = 64 max_length: int = 512 temperature: float = 0.9 missing_eos_penalty: typing.Optional[float] = None beta: list = <factory> loss_type: str = 'sigmoid' dataset_num_proc: typing.Optional[int] = None disable_dropout: bool = True use_vllm: bool = False vllm_model_impl: str = 'vllm' gpu_memory_utilization: typing.Optional[float] = 0.55 ds3_gather_for_generation: bool = True model_init_kwargs: typing.Optional[dict[str, typing.Any]] = None mixture_coef: list = <factory> )

引數

  • mixture_coef (float or list[float], optional, defaults to 0.5) — 用於模型和參考模型的 Logit 混合係數。如果提供一個浮點數列表,則會為每個新週期選擇混合係數,最後一個係數將用於剩餘的週期。

NashMDTrainer 的配置類。

OnlineDPOConfig 的子類,我們可以使用其所有引數,並新增以下內容

< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.