TRL 文件
迭代訓練器 (Iterative Trainer)
並獲得增強的文件體驗
開始使用
迭代訓練器 (Iterative Trainer)
迭代式微調是一種訓練方法,它允許在最佳化步驟之間執行自定義操作(例如生成和過濾)。在 TRL 中,我們提供了一個易於使用的 API,只需幾行程式碼即可迭代地微調您的模型。
快速入門
要快速開始,您可以將模型識別符號或預例項化的模型傳遞給訓練器
from trl import IterativeSFTConfig, IterativeSFTTrainer
# Using a model identifier
trainer = IterativeSFTTrainer(
"facebook/opt-350m",
args=IterativeSFTConfig(
max_length=512,
output_dir="./output",
),
)
# Or using a pre-instantiated model
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
trainer = IterativeSFTTrainer(
model,
args=IterativeSFTConfig(
max_length=512,
output_dir="./output",
),
processing_class=tokenizer,
)
用法
IterativeSFTTrainer 支援兩種向 step
函式提供輸入資料的方式
使用張量列表作為輸入:
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
trainer.step(**inputs)
使用字串列表作為輸入:
inputs = {
"texts": texts,
"texts_labels": texts_labels, # Optional, defaults to texts
}
trainer.step(**inputs)
對於因果語言模型,標籤將自動從 input_ids
或 texts
建立。使用序列到序列模型時,您必須提供自己的標籤或 text_labels
。
配置
IterativeSFTConfig 類提供了幾個引數來定製訓練
from trl import IterativeSFTConfig
config = IterativeSFTConfig(
# Model initialization parameters
model_init_kwargs={"torch_dtype": "bfloat16"},
# Data preprocessing parameters
max_length=512,
truncation_mode="keep_end",
# Training parameters
output_dir="./output",
learning_rate=2e-5,
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
max_steps=1000,
save_steps=100,
optim="adamw_torch",
report_to="wandb",
)
模型初始化
您可以透過向 model_init_kwargs
傳遞關鍵字引數來控制模型的初始化方式
config = IterativeSFTConfig(
model_init_kwargs={
"torch_dtype": "bfloat16",
"device_map": "auto",
"trust_remote_code": True,
}
)
資料預處理
訓練器支援兩種截斷模式
keep_end
:從序列的開頭截斷keep_start
:從序列的末尾截斷
config = IterativeSFTConfig(
max_length=512,
truncation_mode="keep_end", # or "keep_start"
)
訓練最佳化
您可以最佳化 CUDA 快取使用,以實現更節省記憶體的訓練
config = IterativeSFTConfig(
optimize_device_cache=True,
)
IterativeSFTTrainer
class trl.IterativeSFTTrainer
< 原始碼 >( model: typing.Union[str, transformers.modeling_utils.PreTrainedModel] args: typing.Union[trl.trainer.iterative_sft_config.IterativeSFTConfig, transformers.training_args.TrainingArguments, NoneType] = None data_collator: typing.Optional[transformers.data.data_collator.DataCollator] = None eval_dataset: typing.Union[datasets.arrow_dataset.Dataset, dict[str, datasets.arrow_dataset.Dataset], NoneType] = None processing_class: typing.Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, transformers.image_processing_utils.BaseImageProcessor, transformers.feature_extraction_utils.FeatureExtractionMixin, transformers.processing_utils.ProcessorMixin, NoneType] = None optimizers: tuple = (None, None) preprocess_logits_for_metrics: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None compute_metrics: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalLoopOutput], dict]] = None )
引數
- model (
Union[str, PreTrainedModel]
) — 待訓練的模型。可以是:- 一個字串,表示 huggingface.co 上模型倉庫中預訓練模型的 model id,或包含使用
save_pretrained
儲存的模型權重的 目錄 路徑,例如'./my_model_directory/'
。模型使用from_pretrained
和args.model_init_kwargs
中的關鍵字引數載入。 - 一個
PreTrainedModel
物件。僅支援因果語言模型。
- 一個字串,表示 huggingface.co 上模型倉庫中預訓練模型的 model id,或包含使用
- args (IterativeSFTConfig, 可選, 預設為
None
) — 此訓練器的配置。如果為None
,則使用預設配置。 - data_collator (
DataCollator
, 可選) — 用於從處理過的train_dataset
或eval_dataset
的元素列表中形成批次的函式。如果未提供processing_class
,將預設為default_data_collator
;如果 processing_class 是特徵提取器或分詞器,則預設為DataCollatorWithPadding
的例項。 - eval_dataset (
datasets.Dataset
) — 用於評估的資料集。 - processing_class (
PreTrainedTokenizerBase
,BaseImageProcessor
,FeatureExtractionMixin
或ProcessorMixin
, 可選, 預設為None
) — 用於處理資料的處理類。如果為None
,則處理類將從模型的名稱使用from_pretrained
載入。 - optimizers (
tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]
) — 用於訓練的最佳化器和排程器。 - preprocess_logits_for_metrics (
Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
) — 在計算指標之前用於預處理 logits 的函式。 - compute_metrics (
Callable[[EvalPrediction], dict]
, 可選) — 用於計算指標的函式。必須接受一個EvalPrediction
並返回一個從字串到指標值的字典。
IterativeSFTTrainer 可用於透過需要在最佳化之間執行某些步驟的方法來微調模型。
訓練 (train)
< 原始碼 >( resume_from_checkpoint: typing.Union[str, bool, NoneType] = None trial: typing.Union[ForwardRef('optuna.Trial'), dict[str, typing.Any], NoneType] = None ignore_keys_for_eval: typing.Optional[list[str]] = None **kwargs )
引數
- resume_from_checkpoint (
str
orbool
, 可選) — 如果是str
,則為先前Trainer
例項儲存的檢查點的本地路徑。如果是bool
且等於True
,則載入先前Trainer
例項在 *args.output_dir* 中儲存的最後一個檢查點。如果存在,訓練將從此載入的模型/最佳化器/排程器狀態恢復。 - trial (
optuna.Trial
或dict[str, Any]
, 可選) — 用於超引數搜尋的試驗執行或超引數字典。 - ignore_keys_for_eval (
list[str]
, 可選) — 你的模型輸出中(如果它是一個字典)的鍵列表,在訓練期間收集評估預測時應忽略這些鍵。 - kwargs (
dict[str, Any]
, 可選) — 用於隱藏已棄用引數的額外關鍵字引數
主訓練入口點。
將儲存模型,以便您可以使用 `from_pretrained()` 重新載入它。
僅從主程序儲存。
推送到 Hub (push_to_hub)
< 原始碼 >( commit_message: typing.Optional[str] = 'End of training' blocking: bool = True token: typing.Optional[str] = None revision: typing.Optional[str] = None **kwargs )
將 `self.model` 和 `self.processing_class` 上傳到 🤗 模型中心的 `self.args.hub_model_id` 儲存庫。
IterativeSFTConfig
class trl.IterativeSFTConfig
< 原始碼 >( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 5e-05 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: str = 'passive' log_level_replica: str = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 10 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: typing.Optional[bool] = None fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, list[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[list[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[list[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict[str, typing.Any], str, NoneType] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, list[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False hub_revision: typing.Optional[str] = None gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: list = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: int = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, list[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False liger_kernel_config: typing.Optional[dict[str, bool]] = None eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: typing.Optional[bool] = True model_init_kwargs: typing.Optional[dict[str, typing.Any]] = None max_length: typing.Optional[int] = None truncation_mode: str = 'keep_end' optimize_device_cache: bool = False )
控制模型的引數
- model_init_kwargs (
dict[str, Any]
或None
, 可選, 預設為None
) — 用於from_pretrained
的關鍵字引數,當 IterativeSFTTrainer 的model
引數以字串形式提供時使用。
控制資料預處理的引數
IterativeSFTTrainer 的配置類。
此類僅包含特定於迭代式 SFT 訓練的引數。有關訓練引數的完整列表,請參閱 TrainingArguments
文件。請注意,此類中的預設值可能與 TrainingArguments
中的預設值不同。
使用 HfArgumentParser
,我們可以將此類別轉換為可在命令列上指定的 argparse 引數。