TRL 文件
回撥
並獲得增強的文件體驗
開始使用
回撥函式
SyncRefModelCallback
class trl.SyncRefModelCallback
< 原始碼 >( ref_model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] accelerator: typing.Optional[accelerate.accelerator.Accelerator] )
用於將模型與參考模型同步的回撥函式。
RichProgressCallback
一個使用 Rich 顯示訓練或評估進度的 `TrainerCallback`。
WinRateCallback
class trl.WinRateCallback
< 原始碼 >( judge: BasePairwiseJudge trainer: Trainer generation_config: typing.Optional[transformers.generation.configuration_utils.GenerationConfig] = None num_prompts: typing.Optional[int] = None shuffle_order: bool = True use_soft_judge: bool = False )
引數
- judge (
BasePairwiseJudge
) — 用於比較生成內容的評判器。 - trainer (
Trainer
) — 回撥函式將附加到的訓練器。訓練器的評估資料集必須包含一個名為 `"prompt"` 的列,其中包含用於生成內容的提示。如果 `Trainer` 有一個參考模型(透過 `ref_model` 屬性),它將使用此參考模型生成參考內容;否則,它將預設使用初始模型。 - generation_config (
GenerationConfig
, 可選) — 用於生成內容的生成配置。 - num_prompts (
int
或None
, 可選, 預設為None
) — 要為其生成內容的提示數量。如果未提供,則預設為評估資料集中的示例數量。 - shuffle_order (
bool
, 可選, 預設為True
) — 是否在評判前打亂生成內容的順序。 - use_soft_judge (
bool
, 可選, 預設為False
) — 是否使用一個軟評判器,它為第一個生成內容與第二個生成內容的對比返回一個 0 到 1 之間的獲勝機率。
一個 `TrainerCallback`,用於根據參考模型計算模型的勝率。
它使用來自評估資料集的提示生成內容,並將訓練好的模型的輸出與參考模型進行比較。參考模型要麼是模型的初始版本(訓練前),要麼是訓練器中可用的參考模型。在每個評估步驟中,評判器會確定訓練好的模型生成的內容相比參考模型獲勝的頻率。然後,勝率會記錄在訓練器的日誌中,鍵為 `"eval_win_rate"`。
LogCompletionsCallback
class trl.LogCompletionsCallback
< 原始碼 >( trainer: Trainer generation_config: typing.Optional[transformers.generation.configuration_utils.GenerationConfig] = None num_prompts: typing.Optional[int] = None freq: typing.Optional[int] = None )
一個 `TrainerCallback`,用於將生成的內容記錄到 Weights & Biases 和/或 Comet 中。
MergeModelCallback
class trl.MergeModelCallback
< 原始碼 >( merge_config: typing.Optional[ForwardRef('MergeConfig')] = None merge_at_every_checkpoint: bool = False push_to_hub: bool = False )
一個 `TrainerCallback`,用於根據合併配置將策略模型(正在訓練的模型)與另一個模型進行合併。