TRL 文件

回撥

Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

回撥函式

SyncRefModelCallback

class trl.SyncRefModelCallback

< >

( ref_model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] accelerator: typing.Optional[accelerate.accelerator.Accelerator] )

用於將模型與參考模型同步的回撥函式。

RichProgressCallback

class trl.RichProgressCallback

< >

( )

一個使用 Rich 顯示訓練或評估進度的 `TrainerCallback`。

WinRateCallback

class trl.WinRateCallback

< >

( judge: BasePairwiseJudge trainer: Trainer generation_config: typing.Optional[transformers.generation.configuration_utils.GenerationConfig] = None num_prompts: typing.Optional[int] = None shuffle_order: bool = True use_soft_judge: bool = False )

引數

  • judge (BasePairwiseJudge) — 用於比較生成內容的評判器。
  • trainer (Trainer) — 回撥函式將附加到的訓練器。訓練器的評估資料集必須包含一個名為 `"prompt"` 的列,其中包含用於生成內容的提示。如果 `Trainer` 有一個參考模型(透過 `ref_model` 屬性),它將使用此參考模型生成參考內容;否則,它將預設使用初始模型。
  • generation_config (GenerationConfig, 可選) — 用於生成內容的生成配置。
  • num_prompts (intNone, 可選, 預設為 None) — 要為其生成內容的提示數量。如果未提供,則預設為評估資料集中的示例數量。
  • shuffle_order (bool, 可選, 預設為 True) — 是否在評判前打亂生成內容的順序。
  • use_soft_judge (bool, 可選, 預設為 False) — 是否使用一個軟評判器,它為第一個生成內容與第二個生成內容的對比返回一個 0 到 1 之間的獲勝機率。

一個 `TrainerCallback`,用於根據參考模型計算模型的勝率。

它使用來自評估資料集的提示生成內容,並將訓練好的模型的輸出與參考模型進行比較。參考模型要麼是模型的初始版本(訓練前),要麼是訓練器中可用的參考模型。在每個評估步驟中,評判器會確定訓練好的模型生成的內容相比參考模型獲勝的頻率。然後,勝率會記錄在訓練器的日誌中,鍵為 `"eval_win_rate"`。

用法

trainer = DPOTrainer(...)
judge = PairRMJudge()
win_rate_callback = WinRateCallback(judge=judge, trainer=trainer)
trainer.add_callback(win_rate_callback)

LogCompletionsCallback

class trl.LogCompletionsCallback

< >

( trainer: Trainer generation_config: typing.Optional[transformers.generation.configuration_utils.GenerationConfig] = None num_prompts: typing.Optional[int] = None freq: typing.Optional[int] = None )

引數

  • trainer (Trainer) — 回撥函式將附加到的訓練器。訓練器的評估資料集必須包含一個名為 `"prompt"` 的列,其中包含用於生成內容的提示。
  • generation_config (GenerationConfig, 可選) — 用於生成內容的生成配置。
  • num_prompts (intNone, 可選) — 要為其生成內容的提示數量。如果未提供,則預設為評估資料集中的示例數量。
  • freq (intNone, 可選) — 記錄生成內容的頻率。如果未提供,則預設為訓練器的 `eval_steps`。

一個 `TrainerCallback`,用於將生成的內容記錄到 Weights & Biases 和/或 Comet 中。

用法

trainer = DPOTrainer(...)
completions_callback = LogCompletionsCallback(trainer=trainer)
trainer.add_callback(completions_callback)

MergeModelCallback

class trl.MergeModelCallback

< >

( merge_config: typing.Optional[ForwardRef('MergeConfig')] = None merge_at_every_checkpoint: bool = False push_to_hub: bool = False )

引數

  • merge_config (MergeConfig, 可選, 預設為 None) — 用於合併過程的配置。如果未提供,則使用預設的 `MergeConfig`。
  • merge_at_every_checkpoint (bool, 可選, 預設為 False) — 是否在每個檢查點合併模型。
  • push_to_hub (bool, 可選, 預設為 False) — 合併後是否將合併後的模型推送到 Hub。

一個 `TrainerCallback`,用於根據合併配置將策略模型(正在訓練的模型)與另一個模型進行合併。

示例

# pip install mergekit

from trl.mergekit_utils import MergeConfig
from trl import MergeModelCallback

config = MergeConfig()
merge_callback = MergeModelCallback(config)
trainer = DPOTrainer(..., callbacks=[merge_callback])
< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.