TRL 文件

KTO 訓練器

Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

KTO 訓練器

概覽

卡尼曼-特沃斯基最佳化(Kahneman-Tversky Optimization, KTO)由 Kawin EthayarajhWinnie XuNiklas Muennighoff、Dan Jurafsky 和 Douwe Kiela 在論文 KTO: Model Alignment as Prospect Theoretic Optimization 中提出。

論文摘要如下:

卡尼曼和特沃斯基的前景理論告訴我們,人類以一種有偏見但明確的方式感知隨機變數;例如,人類是出了名的損失厭惡。我們發現,用於使大型語言模型(LLM)與人類反饋對齊的目標函式,隱式地包含了許多這些偏見——這些目標函式(例如 DPO)相對於交叉熵最小化的成功,部分可以歸因於它們是“人類感知損失函式”(human-aware loss functions, HALOs)。然而,這些方法歸因於人類的效用函式仍然與前景理論文獻中的有所不同。我們使用一個卡尼曼-特沃斯基的人類效用模型,提出了一個直接最大化生成內容效用的 HALO,而不是像現有方法那樣最大化偏好資料的對數似然。我們將這種方法稱為卡尼曼-特沃斯基最佳化(KTO),在從 10 億到 300 億引數規模的模型上,它的效能與基於偏好的方法相當或更優。關鍵的是,KTO 不需要偏好資料——只需要一個二元訊號,即對於給定的輸入,輸出是理想的還是不理想的。這使得它在現實世界中更容易使用,因為偏好資料稀少且昂貴。

官方程式碼可以在 ContextualAI/HALOs 中找到。

此後訓練方法由 Kashif RasulYounes BelkadaLewis Tunstall 和 Pablo Vicente 貢獻。

快速入門

此示例演示瞭如何使用 KTO 方法訓練模型。我們使用 Qwen 0.5B 模型作為基礎模型。我們使用來自 KTO Mix 14k 的偏好資料。你可以在資料集中檢視資料。

以下是訓練模型的指令碼

# train_kto.py
from datasets import load_dataset
from trl import KTOConfig, KTOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
train_dataset = load_dataset("trl-lib/kto-mix-14k", split="train")

training_args = KTOConfig(output_dir="Qwen2-0.5B-KTO")
trainer = KTOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()

使用以下命令執行指令碼

accelerate launch train_kto.py

在 8 個 H100 GPU 上進行分散式訓練,大約需要 30 分鐘。你可以透過檢查獎勵圖來驗證訓練進度。獎勵邊際的上升趨勢表明模型正在改進並隨著時間的推移生成更好的響應。

要檢視訓練後的模型表現如何,可以使用 Transformers Chat CLI

$ transformers chat trl-lib/Qwen2-0.5B-KTO
<quentin_gallouedec>:
What is the best programming language?

<trl-lib/Qwen2-0.5B-KTO>:
The best programming language can vary depending on individual preferences, industry-specific requirements, technical skills, and familiarity with the specific use case or task. Here are some widely-used programming languages that have been noted as popular and widely used:                                                                                  

Here are some other factors to consider when choosing a programming language for a project:

 1 JavaScript: JavaScript is at the heart of the web and can be used for building web applications, APIs, and interactive front-end applications like frameworks like React and Angular. It's similar to C, C++, and F# in syntax structure and is accessible and easy to learn, making it a popular choice for beginners and professionals alike.                                                                   
 2 Java: Known for its object-oriented programming (OOP) and support for Java 8 and .NET, Java is used for developing enterprise-level software applications, high-performance games, as well as mobile apps, game development, and desktop applications.                                                                                                                                                            
 3 C++: Known for its flexibility and scalability, C++ offers comprehensive object-oriented programming and is a popular choice for high-performance computing and other technical fields. It's a powerful platform for building real-world applications and games at scale.                                                                                                                                         
 4 Python: Developed by Guido van Rossum in 1991, Python is a high-level, interpreted, and dynamically typed language known for its simplicity, readability, and versatility.   

預期資料集格式

KTO 需要一個非成對偏好資料集。或者,您也可以提供一個*成對*偏好資料集(也簡稱為*偏好資料集*)。在這種情況下,訓練器會自動將其轉換為非成對格式,方法是分拆“選擇的”和“拒絕的”響應,為“選擇的”完成分配 `label = True`,為“拒絕的”完成分配 `label = False`。

KTOTrainer 支援對話式標準兩種資料集格式。當提供對話式資料集時,訓練器會自動將聊天模板應用於該資料集。

理論上,資料集應至少包含一個“選擇的”和一個“拒絕的”完成。然而,一些使用者僅使用“選擇的”或僅使用“拒絕的”資料也成功運行了 KTO。如果僅使用“拒絕的”資料,建議採用保守的學習率。

示例指令碼

我們提供了一個使用 KTO 方法訓練模型的示例指令碼。該指令碼位於 trl/scripts/kto.py

要使用Qwen2 0.5B 模型UltraFeedback 資料集上測試 KTO 指令碼,請執行以下命令:

accelerate launch trl/scripts/kto.py \
    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
    --dataset_name trl-lib/kto-mix-14k \
    --num_train_epochs 1 \
    --output_dir Qwen2-0.5B-KTO

使用技巧

對於混合專家模型:啟用輔助損失

如果負載在專家之間大致均勻分佈,MOE(專家混合模型)效率最高。
為了確保在偏好調整期間類似地訓練 MOE,將負載均衡器的輔助損失新增到最終損失中是有益的。

透過在模型配置(例如 MixtralConfig)中設定 output_router_logits=True 來啟用此選項。
要調整輔助損失對總損失的貢獻程度,請在模型配置中使用超引數 router_aux_loss_coef=...(預設值:0.001)。

批次大小建議

使用每步批次大小至少為 4,有效批次大小在 16 到 128 之間。即使你的有效批次大小很大,如果每步批次大小不佳,KTO 中的 KL 估計也會很差。

學習率建議

每種 `beta` 的選擇都有一個在學習效能下降前可以容忍的最大學習率。對於預設設定 `beta = 0.1`,大多數模型的學習率通常不應超過 `1e-6`。隨著 `beta` 的減小,學習率也應相應降低。總的來說,我們強烈建議將學習率保持在 `5e-7` 到 `5e-6` 之間。即使對於小資料集,我們也建議不要使用超出此範圍的學習率。相反,應選擇更多的訓練輪次(epochs)以獲得更好的結果。

不平衡資料

KTOConfig 中的 `desirable_weight` 和 `undesirable_weight` 分別指對理想/正例和不理想/負例損失施加的權重。預設情況下,它們都是 1。然而,如果你的一種型別的樣本比另一種多,那麼你應該增加較少型別樣本的權重,使得(`desirable_weight`×\times正例數量)與(`undesirable_weight`×\times負例數量)的比率在 1:1 到 4:3 的範圍內。

記錄的指標

在訓練和評估期間,我們記錄以下獎勵指標

  • `rewards/chosen_sum`: 策略模型對“選擇的”響應的對數機率總和,按 beta 縮放
  • `rewards/rejected_sum`: 策略模型對“拒絕的”響應的對數機率總和,按 beta 縮放
  • `logps/chosen_sum`: “選擇的”完成的對數機率總和
  • `logps/rejected_sum`: “拒絕的”完成的對數機率總和
  • `logits/chosen_sum`: “選擇的”完成的 logits 總和
  • `logits/rejected_sum`: “拒絕的”完成的 logits 總和
  • `count/chosen`: 一個批次中“選擇的”樣本數量
  • `count/rejected`: 一個批次中“拒絕的”樣本數量

KTOTrainer

class trl.KTOTrainer

< >

( model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module, str] = None ref_model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module, str, NoneType] = None args: KTOConfig = None train_dataset: typing.Optional[datasets.arrow_dataset.Dataset] = None eval_dataset: typing.Union[datasets.arrow_dataset.Dataset, dict[str, datasets.arrow_dataset.Dataset], NoneType] = None processing_class: typing.Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, transformers.image_processing_utils.BaseImageProcessor, transformers.feature_extraction_utils.FeatureExtractionMixin, transformers.processing_utils.ProcessorMixin, NoneType] = None data_collator: typing.Optional[transformers.data.data_collator.DataCollator] = None model_init: typing.Optional[typing.Callable[[], transformers.modeling_utils.PreTrainedModel]] = None callbacks: typing.Optional[list[transformers.trainer_callback.TrainerCallback]] = None optimizers: tuple = (None, None) preprocess_logits_for_metrics: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None peft_config: typing.Optional[dict] = None compute_metrics: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalLoopOutput], dict]] = None model_adapter_name: typing.Optional[str] = None ref_adapter_name: typing.Optional[str] = None )

引數

  • model (transformers.PreTrainedModel) — 用於訓練的模型,最好是 `AutoModelForSequenceClassification`。
  • ref_model (PreTrainedModelWrapper) — 帶有因果語言建模頭的 Hugging Face Transformer 模型。用於隱式獎勵計算和損失。如果沒有提供參考模型,訓練器將建立一個與待最佳化模型相同架構的參考模型。
  • args (KTOConfig) — 用於訓練的引數。
  • train_dataset (datasets.Dataset) — 用於訓練的資料集。
  • eval_dataset (datasets.Dataset) — 用於評估的資料集。
  • processing_class (PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixinProcessorMixin, *可選*, 預設為 None) — 用於處理資料的處理類。如果提供,將用於自動處理模型的輸入,並將與模型一起儲存,以便更容易地重新執行中斷的訓練或重用微調後的模型。
  • data_collator (transformers.DataCollator, *可選*, 預設為 None) — 用於訓練的資料整理器。如果指定為 None,將使用預設的資料整理器 (`DPODataCollatorWithPadding`),它將把序列填充到批次中序列的最大長度,適用於成對序列的資料集。
  • model_init (Callable[[], transformers.PreTrainedModel]) — 用於訓練的模型初始化器。如果指定為 None,將使用預設的模型初始化器。
  • callbacks (list[transformers.TrainerCallback]) — 用於訓練的回撥。
  • optimizers (tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]) — 用於訓練的最佳化器和排程器。
  • preprocess_logits_for_metrics (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) — 在計算指標前用於預處理 logits 的函式。
  • peft_config (dict, 預設為 None) — 用於訓練的 PEFT 配置。如果傳遞 PEFT 配置,模型將被包裝在 PEFT 模型中。
  • compute_metrics (Callable[[EvalPrediction], dict], *可選*) — 用於計算指標的函式。必須接受一個 `EvalPrediction` 並返回一個從字串到指標值的字典。
  • model_adapter_name (str, 預設為 None) — 當使用帶有多個介面卡的 LoRA 時,訓練目標 PEFT 介面卡的名稱。
  • ref_adapter_name (str, 預設為 None) — 當使用帶有多個介面卡的 LoRA 時,參考 PEFT 介面卡的名稱。

初始化 KTOTrainer。

train

< >

( resume_from_checkpoint: typing.Union[str, bool, NoneType] = None trial: typing.Union[ForwardRef('optuna.Trial'), dict[str, typing.Any], NoneType] = None ignore_keys_for_eval: typing.Optional[list[str]] = None **kwargs )

引數

  • resume_from_checkpoint (strbool, *可選*) — 如果是 `str`,則為 `Trainer` 的先前例項儲存的檢查點的本地路徑。如果是 `bool` 且等於 `True`,則載入 *args.output_dir* 中由 `Trainer` 的先前例項儲存的最新檢查點。如果存在,訓練將從此處載入的模型/最佳化器/排程器狀態恢復。
  • trial (optuna.Trialdict[str, Any], *可選*) — 超引數搜尋的試驗執行或超引數字典。
  • ignore_keys_for_eval (list[str], *可選*) — 在訓練期間收集評估預測時,模型輸出中(如果輸出是字典)應忽略的鍵的列表。
  • kwargs (dict[str, Any], *可選*) — 用於隱藏已棄用引數的附加關鍵字引數。

主訓練入口點。

save_model

< >

( output_dir: typing.Optional[str] = None _internal_call: bool = False )

將儲存模型,以便您可以使用 `from_pretrained()` 重新載入它。

僅從主程序儲存。

push_to_hub

< >

( commit_message: typing.Optional[str] = 'End of training' blocking: bool = True token: typing.Optional[str] = None revision: typing.Optional[str] = None **kwargs )

引數

  • commit_message (str, *可選*, 預設為 "End of training") — 推送時的提交資訊。
  • blocking (bool, *可選*, 預設為 True) — 函式是否應在 `git push` 完成後才返回。
  • token (str, *可選*, 預設為 None) — 具有寫入許可權的令牌,用於覆蓋 Trainer 的原始引數。
  • revision (str, *可選*) — 要提交的 git 修訂版本。預設為“main”分支的頭部。
  • kwargs (dict[str, Any], *可選*) — 傳遞給 `~Trainer.create_model_card` 的附加關鍵字引數。

將 `self.model` 和 `self.processing_class` 上傳到 🤗 模型中心的 `self.args.hub_model_id` 儲存庫。

KTOConfig

class trl.KTOConfig

< >

( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 1e-06 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: str = 'passive' log_level_replica: str = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 10 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: typing.Optional[bool] = None fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, list[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[list[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[list[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict[str, typing.Any], str, NoneType] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, list[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False hub_revision: typing.Optional[str] = None gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: list = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: int = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, list[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False liger_kernel_config: typing.Optional[dict[str, bool]] = None eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: typing.Optional[bool] = True max_length: typing.Optional[int] = 1024 max_prompt_length: typing.Optional[int] = 512 max_completion_length: typing.Optional[int] = None beta: float = 0.1 loss_type: str = 'kto' desirable_weight: float = 1.0 undesirable_weight: float = 1.0 label_pad_token_id: int = -100 padding_value: typing.Optional[int] = None truncation_mode: str = 'keep_end' generate_during_eval: bool = False is_encoder_decoder: typing.Optional[bool] = None disable_dropout: bool = True precompute_ref_log_probs: bool = False model_init_kwargs: typing.Optional[dict[str, typing.Any]] = None ref_model_init_kwargs: typing.Optional[dict[str, typing.Any]] = None dataset_num_proc: typing.Optional[int] = None use_liger_loss: bool = False base_model_attribute_name: str = 'model' )

引數

  • max_length (intNone, 可選, 預設為 1024) — 批次中序列(提示 + 補全)的最大長度。如果要使用預設的資料整理器,則此引數為必需。
  • max_prompt_length (intNone, 可選, 預設為 512) — 提示的最大長度。如果要使用預設的資料整理器,則此引數為必需。
  • max_completion_length (intNone, 可選, 預設為 None) — 補全的最大長度。如果要使用預設的資料整理器且模型是編碼器-解碼器模型,則此引數為必需。
  • beta (float, 可選, 預設為 0.1) — 控制與參考模型偏差的引數。更高的 β 意味著與參考模型的偏差更小。
  • loss_type (str, 可選, 預設為 "kto") — 使用的損失型別。可能的值有:

    • "kto":來自 KTO 論文的 KTO 損失。
    • "apo_zero_unpaired":來自 APO 論文的 APO-zero 損失的非配對變體。
  • desirable_weight (float, 可選, 預設為 1.0) — 合意損失會乘以這個因子,以抵消合意和非合意對數量不均等的問題。
  • undesirable_weight (float, 可選, 預設為 1.0) — 不合意損失會乘以這個因子,以抵消合意和非合意對數量不均等的問題。
  • label_pad_token_id (int, 可選, 預設為 -100) — 標籤填充標記的 ID。如果要使用預設的資料整理器,則此引數為必需。
  • padding_value (intNone, 可選, 預設為 None) — 要使用的填充值。如果為 None,則使用分詞器的填充值。
  • truncation_mode (str, 可選, 預設為 "keep_end") — 當提示過長時使用的截斷模式。可能的值為 "keep_end""keep_start"。如果要使用預設的資料整理器,則此引數為必需。
  • generate_during_eval (bool, 可選, 預設為 False) — 如果為 True,則在評估期間從模型和參考模型生成並記錄補全到 W&B 或 Comet。
  • is_encoder_decoder (boolNone, 可選, 預設為 None) — 當使用 model_init 引數(可呼叫物件)來例項化模型而不是 model 引數時,您需要指定可呼叫物件返回的模型是否為編碼器-解碼器模型。
  • precompute_ref_log_probs (bool, 可選, 預設為 False) — 是否為訓練和評估資料集預先計算參考模型的對數機率。這在不使用參考模型進行訓練時非常有用,可以減少所需的總 GPU 記憶體。
  • model_init_kwargs (dict[str, Any]None, 可選, 預設為 None) — 當從字串例項化模型時,傳遞給 AutoModelForCausalLM.from_pretrained 的關鍵字引數。
  • ref_model_init_kwargs (dict[str, Any]None, 可選, 預設為 None) — 當從字串例項化參考模型時,傳遞給 AutoModelForCausalLM.from_pretrained 的關鍵字引數。
  • dataset_num_proc — (intNone, 可選, 預設為 None): 用於處理資料集的程序數。
  • disable_dropout (bool, 可選, 預設為 True) — 是否在模型和參考模型中停用 dropout。
  • use_liger_loss (bool, 可選, 預設為 False) — 是否使用 Liger 損失。這需要安裝 liger-kernel。
  • base_model_attribute_name (str, 可選, 預設為 "model") — 模型中包含基礎模型的屬性名稱。當 use_liger_lossTrue 且模型沒有 get_decoder 方法時,此引數用於從模型中獲取基礎模型。

KTOTrainer 的配置類。

這個類僅包含特定於 KTO 訓練的引數。有關訓練引數的完整列表,請參閱 TrainingArguments 文件。請注意,此類中的預設值可能與 TrainingArguments 中的不同。

使用 HfArgumentParser,我們可以將此類別轉換為可在命令列上指定的 argparse 引數。

< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.