TRL 文件
獎勵模型
並獲得增強的文件體驗
開始使用
獎勵模型
TRL 支援自定義獎勵模型,任何人都可以根據自己的資料集和模型進行獎勵建模。
請在 examples/scripts/reward_modeling.py
檢視一個完整且靈活的示例。
期望的資料集型別
RewardTrainer 需要一個隱式提示偏好資料集。這意味著資料集應該只包含 "chosen"
和 "rejected"
列(而不包含 "prompt"
)。RewardTrainer 支援對話式和標準兩種資料集格式。當提供對話式資料集時,訓練器會自動將聊天模板應用於資料集。
您也可以使用預分詞的資料集,在這種情況下,資料集應包含以下列:input_ids_chosen
, attention_mask_chosen
, input_ids_rejected
和 attention_mask_rejected
。
使用 RewardTrainer
準備好資料集後,您可以像使用 🤗 Transformers 的 Trainer
類一樣使用 RewardTrainer。您應該將一個 AutoModelForSequenceClassification
模型傳遞給 RewardTrainer,同時傳遞一個 RewardConfig 來配置訓練的超引數。
利用 🤗 PEFT 訓練獎勵模型
只需在 RewardTrainer 的關鍵字引數中傳遞一個 peft_config
,訓練器就會自動將模型轉換為 PEFT 模型!
from peft import LoraConfig, TaskType
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer, RewardConfig
model = AutoModelForSequenceClassification.from_pretrained("gpt2")
peft_config = LoraConfig(
task_type=TaskType.SEQ_CLS,
inference_mode=False,
r=8,
lora_alpha=32,
lora_dropout=0.1,
)
...
trainer = RewardTrainer(
model=model,
args=training_args,
processing_class=tokenizer,
train_dataset=dataset,
peft_config=peft_config,
)
trainer.train()
為損失函式新增邊際
就像在 Llama 2 論文 中一樣,您可以透過向資料集中新增 margin
列來為損失函式新增邊際。獎勵整理器會自動傳遞它,並相應地計算損失。
def add_margin(row):
# Assume you have a score_chosen and score_rejected columns that you want to use to compute the margin
return {'margin': row['score_chosen'] - row['score_rejected']}
dataset = dataset.map(add_margin)
中心化獎勵
在許多場景中,確保獎勵模型的輸出均值為零是更可取的。這通常透過首先計算模型的平均得分,然後減去它來實現。
[Eisenstein et al., 2023] 提出了一種輔助損失函式,旨在直接學習一箇中心化的獎勵模型。該輔助損失函式最小化獎勵的平方和,從而鼓勵模型自然地產生均值為零的輸出。
這個輔助損失函式與主損失函式結合,權重由 [RewardConfig]
中的引數 center_rewards_coefficient
控制。預設情況下,此功能是停用的(center_rewards_coefficient = None
)。
training_args = RewardConfig(
center_rewards_coefficient=0.01,
...
)
有關參考結果,請參閱 PR #1932。
RewardTrainer
class trl.RewardTrainer
< 原始碼 >( model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module, NoneType] = None args: typing.Optional[trl.trainer.reward_config.RewardConfig] = None data_collator: typing.Optional[transformers.data.data_collator.DataCollator] = None train_dataset: typing.Optional[datasets.arrow_dataset.Dataset] = None eval_dataset: typing.Union[datasets.arrow_dataset.Dataset, dict[str, datasets.arrow_dataset.Dataset], NoneType] = None processing_class: typing.Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, transformers.image_processing_utils.BaseImageProcessor, transformers.feature_extraction_utils.FeatureExtractionMixin, transformers.processing_utils.ProcessorMixin, NoneType] = None model_init: typing.Optional[typing.Callable[[], transformers.modeling_utils.PreTrainedModel]] = None compute_metrics: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalPrediction], dict]] = None callbacks: typing.Optional[list[transformers.trainer_callback.TrainerCallback]] = None optimizers: tuple = (None, None) preprocess_logits_for_metrics: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None peft_config: typing.Optional[dict] = None )
train
< 原始碼 >( resume_from_checkpoint: typing.Union[str, bool, NoneType] = None trial: typing.Union[ForwardRef('optuna.Trial'), dict[str, typing.Any], NoneType] = None ignore_keys_for_eval: typing.Optional[list[str]] = None **kwargs )
引數
- resume_from_checkpoint (
str
或bool
, 可選) — 如果是str
,則為之前Trainer
例項儲存的檢查點的本地路徑。如果是bool
且等於True
,則載入之前Trainer
例項在 *args.output_dir* 中儲存的最後一個檢查點。如果提供,訓練將從此處載入的模型/最佳化器/排程器狀態恢復。 - trial (
optuna.Trial
或dict[str, Any]
, 可選) — 用於超引數搜尋的試驗執行或超引數字典。 - ignore_keys_for_eval (
list[str]
, 可選) — 模型輸出(如果為字典)中在訓練期間收集評估預測時應忽略的鍵列表。 - kwargs (
dict[str, Any]
, 可選) — 用於隱藏已棄用引數的附加關鍵字引數。
主訓練入口點。
將儲存模型,以便您可以使用 `from_pretrained()` 重新載入它。
僅從主程序儲存。
push_to_hub
< 原始碼 >( commit_message: typing.Optional[str] = 'End of training' blocking: bool = True token: typing.Optional[str] = None revision: typing.Optional[str] = None **kwargs )
將 `self.model` 和 `self.processing_class` 上傳到 🤗 模型中心的 `self.args.hub_model_id` 儲存庫。
RewardConfig
class trl.RewardConfig
< 原始碼 >( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 5e-05 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: str = 'passive' log_level_replica: str = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 10 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: typing.Optional[bool] = None fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, list[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: bool = False label_names: typing.Optional[list[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[list[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict[str, typing.Any], str, NoneType] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, list[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False hub_revision: typing.Optional[str] = None gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: list = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: int = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, list[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False liger_kernel_config: typing.Optional[dict[str, bool]] = None eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: bool = True max_length: typing.Optional[int] = 1024 disable_dropout: bool = True dataset_num_proc: typing.Optional[int] = None center_rewards_coefficient: typing.Optional[float] = None )
引數
- max_length (
int
或None
,可選,預設為1024
) — 批次中序列(提示+補全)的最大長度,會過濾掉超過此限制的條目。如果要使用預設的資料整理器,則此引數是必需的。 - disable_dropout (
bool
, 可選,預設為True
) — 是否在模型中停用 dropout。 - dataset_num_proc (
int
, 可選,預設為None
) — 用於處理資料集的程序數。 - center_rewards_coefficient (
float
, 可選,預設為None
) — 激勵獎勵模型輸出均值為零的獎勵的係數(由 https://huggingface.co/papers/2312.09244 提出,公式 2)。推薦值:0.01
。 - remove_unused_columns (
bool
, 可選,預設為False
) — 是否移除模型前向傳播中未使用的列。僅當資料集已預先分詞時,此項才能為True
。
用於 RewardTrainer 的配置類。
此類僅包含特定於獎勵(Reward)訓練的引數。有關訓練引數的完整列表,請參閱 TrainingArguments
文件。請注意,此類中的預設值可能與 TrainingArguments
中的不同。
使用 HfArgumentParser
,我們可以將此類別轉換為可在命令列上指定的 argparse 引數。