Optimum 文件
訓練器
並獲得增強的文件體驗
開始使用
訓練器
ORTTrainer
class optimum.onnxruntime.ORTTrainer
< 來源 >( 模型: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None 引數: ORTTrainingArguments = None 資料收集器: typing.Optional[transformers.data.data_collator.DataCollator] = None 訓練資料集: typing.Optional[torch.utils.data.dataset.Dataset] = None 評估資料集: typing.Union[torch.utils.data.dataset.Dataset, typing.Dict[str, torch.utils.data.dataset.Dataset], NoneType] = None 分詞器: typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None 模型初始化: typing.Optional[typing.Callable[[], transformers.modeling_utils.PreTrainedModel]] = None 計算指標: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalPrediction], typing.Dict]] = None 回撥: typing.Optional[typing.List[transformers.trainer_callback.TrainerCallback]] = None 最佳化器: typing.Tuple[torch.optim.optimizer.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None) 用於指標的預處理logits: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None )
引數
- 模型(PreTrainedModel 或
torch.nn.Module
,可選)— 用於訓練、評估或預測的模型。如果未提供,則必須傳入一個model_init
。ORTTrainer
經過最佳化,可與 transformers 庫提供的 PreTrainedModel 配合使用。只要您自己的模型(定義為torch.nn.Module
)與 🤗 Transformers 模型的工作方式相同,您仍然可以使用它們進行 ONNX Runtime 後端訓練和 PyTorch 後端推理。 - 引數(
ORTTrainingArguments
,可選)— 用於訓練的調整引數。如果未提供,將預設為ORTTrainingArguments
的一個基本例項,其中output_dir
設定為當前目錄中名為 tmp_trainer 的目錄。 - 資料收集器(
DataCollator
,可選)— 用於從train_dataset
或eval_dataset
元素列表中形成批次的功能。如果未提供tokenizer
,將預設為 default_data_collator,否則將預設為 DataCollatorWithPadding 的例項。 - 訓練資料集(
torch.utils.data.Dataset
或torch.utils.data.IterableDataset
,可選)— 用於訓練的資料集。如果是 Dataset,則會自動刪除model.forward()
方法不接受的列。請注意,如果它是一個帶有隨機化的torch.utils.data.IterableDataset
,並且您以分散式方式進行訓練,那麼您的可迭代資料集應該要麼使用一個內部屬性generator
(它是一個torch.Generator
),用於所有程序上必須相同的隨機化,要麼有一個set_epoch()
方法,該方法在內部設定所使用的 RNG 的種子。 - 評估資料集(Union[
torch.utils.data.Dataset
, Dict[str,torch.utils.data.Dataset
]],可選)— 用於評估的資料集。如果是 Dataset,則會自動刪除model.forward()
方法不接受的列。如果是一個字典,則會在每個資料集上進行評估,並將字典鍵作為度量名稱的字首。 - 分詞器(PreTrainedTokenizerBase,可選)— 用於預處理資料用的分詞器。如果提供,它將在批處理輸入時自動將輸入填充到最大長度,並與模型一起儲存,以便更容易地重新執行中斷的訓練或重用微調後的模型。
- 模型初始化(
Callable[[], PreTrainedModel]
,可選)— 例項化要使用的模型的功能。如果提供,每次呼叫ORTTrainer.train
都將從該功能給出的新模型例項開始。該功能可以不帶引數,或帶有一個引數(包含 optuna/Ray Tune/SigOpt 試驗物件),以便能夠根據超引數(如層數、內部層大小、dropout 機率等)選擇不同的架構。 - 計算指標(
Callable[[EvalPrediction], Dict]
,可選)— 將用於在評估時計算指標的功能。必須接受一個EvalPrediction
並返回一個從字串到指標值的字典。 - 回撥(List of
TrainerCallback
,可選)— 用於自定義訓練迴圈的回撥列表。這些回撥將被新增到此處詳述的預設回撥列表中。如果您想刪除其中一個預設回撥,請使用ORTTrainer.remove_callback
方法。 - 最佳化器(
Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]
,可選)— 包含要使用的最佳化器和排程器的元組。將預設為您的模型上的AdamW
例項和由args
控制的get_linear_schedule_with_warmup
給出的排程器。 - 用於指標的預處理logits(
Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
,可選)— 一個在每個評估步驟快取 logits 之前立即預處理 logits 的函式。必須接受兩個張量,即 logits 和標籤,並返回經過處理的 logits。此函式所做的修改將反映在compute_metrics
收到的預測中。請注意,如果資料集中沒有標籤,則標籤(第二個引數)將為None
。
ORTTrainer 是一個簡單但功能完備的 ONNX Runtime 訓練和評估迴圈,針對 🤗 Transformers 進行了最佳化。
重要屬性
- 模型 — 始終指向核心模型。如果使用 transformers 模型,它將是 PreTrainedModel 的子類。
- 模型包裝器 — 如果一個或多個其他模組包裝了原始模型,則始終指向最外部的模型。這是應該用於前向傳播的模型。例如,在
DeepSpeed
下,內部模型首先被ORTModule
包裝,然後被DeepSpeed
包裝,然後再次被torch.nn.DistributedDataParallel
包裝。如果內部模型尚未包裝,則self.model_wrapped
與self.model
相同。 - is_model_parallel — 模型是否已切換到模型並行模式(與資料並行不同,這意味著一些模型層分佈在不同的 GPU 上)。
- place_model_on_device — 是否自動將模型放置在裝置上 - 如果使用模型並行或 DeepSpeed,或者如果預設的
ORTTrainingArguments.place_model_on_device
被覆蓋為返回False
,則此項將設定為False
。 - is_in_train — 模型當前是否正在執行
train
(例如,在train
中呼叫evaluate
時)
設定最佳化器。
我們提供了一個效果良好的合理預設值。如果您想使用其他最佳化器,可以透過 optimizers
在 ORTTrainer 的初始化中傳遞一個元組,或者在子類中重寫此方法。
獲取 ort_optimizer_cls_and_kwargs
< 來源 >( 引數: ORTTrainingArguments )
根據 ORTTrainingArguments
返回在 ONNX Runtime 中實現的最佳化器類和最佳化器引數。
訓練
< 來源 >( 從檢查點恢復: typing.Union[bool, str, NoneType] = None 試用: typing.Union[ForwardRef('optuna.Trial'), typing.Dict[str, typing.Any]] = None 忽略評估鍵: typing.Optional[typing.List[str]] = None **kwargs )
引數
- 從檢查點恢復(
str
或bool
,可選)— 如果是str
,則為上一個ORTTrainer
例項儲存的檢查點的本地路徑。如果是bool
且等於True
,則載入上一個ORTTrainer
例項儲存在 args.output_dir 中的最後一個檢查點。如果存在,訓練將從此處載入的模型/最佳化器/排程器狀態恢復。 - 試用(
optuna.Trial
或Dict[str, Any]
,可選)— 用於超引數搜尋的試用執行或超引數字典。 - 忽略評估鍵(
List[str]
,可選)— 模型輸出(如果是字典)中應在訓練期間收集評估預測時忽略的鍵列表。 - kwargs(
Dict[str, Any]
,可選)— 用於隱藏已棄用引數的附加關鍵字引數。
使用 ONNX Runtime 加速器進行訓練的主要入口點。
ORTSeq2SeqTrainer
class optimum.onnxruntime.ORTSeq2SeqTrainer
< 來源 >( 模型: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None 引數: ORTTrainingArguments = None 資料收集器: typing.Optional[transformers.data.data_collator.DataCollator] = None 訓練資料集: typing.Optional[torch.utils.data.dataset.Dataset] = None 評估資料集: typing.Union[torch.utils.data.dataset.Dataset, typing.Dict[str, torch.utils.data.dataset.Dataset], NoneType] = None 分詞器: typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None 模型初始化: typing.Optional[typing.Callable[[], transformers.modeling_utils.PreTrainedModel]] = None 計算指標: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalPrediction], typing.Dict]] = None 回撥: typing.Optional[typing.List[transformers.trainer_callback.TrainerCallback]] = None 最佳化器: typing.Tuple[torch.optim.optimizer.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None) 用於指標的預處理logits: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None )
評估
< 來源 >( 評估資料集: typing.Optional[torch.utils.data.dataset.Dataset] = None 忽略鍵: typing.Optional[typing.List[str]] = None 指標鍵字首: str = 'eval' **gen_kwargs )
引數
- 評估資料集(
Dataset
,可選)— 如果您希望覆蓋self.eval_dataset
,請傳遞一個數據集。如果是 Dataset,則會自動刪除model.forward()
方法不接受的列。它必須實現__len__
方法。 - 忽略鍵(
List[str]
,可選)— 模型輸出(如果是字典)中應在收集預測時忽略的鍵列表。 - 指標鍵字首(
str
,可選,預設為"eval"
)— 用於作為指標鍵字首的可選字首。例如,如果字首為"eval"
(預設),則指標“bleu”將命名為“eval_bleu”。 - 最大長度(
int
,可選)— 使用生成方法預測時要使用的最大目標長度。 - 光束數量(
int
,可選)— 使用生成方法預測時將使用的光束搜尋的光束數量。1 表示不進行光束搜尋。 - gen_kwargs — 其他特定於
generate
的關鍵字引數。
執行評估並返回指標。
呼叫指令碼將負責提供一個計算指標的方法,因為它們是依賴於任務的(將其傳遞給初始化 compute_metrics
引數)。
您還可以透過子類化並覆蓋此方法來注入自定義行為。
預測
< 來源 >( 測試資料集: Dataset 忽略鍵: typing.Optional[typing.List[str]] = None 指標鍵字首: str = 'test' **gen_kwargs )
引數
- 測試資料集(
Dataset
)— 用於執行預測的資料集。如果是 Dataset,則會自動刪除model.forward()
方法不接受的列。必須實現__len__
方法。 - 忽略鍵(
List[str]
,可選)— 模型輸出(如果是字典)中應在收集預測時忽略的鍵列表。 - 指標鍵字首(
str
,可選,預設為"eval"
)— 用於作為指標鍵字首的可選字首。例如,如果字首為"eval"
(預設),則指標“bleu”將命名為“eval_bleu”。 - 最大長度(
int
,可選)— 使用生成方法預測時要使用的最大目標長度。 - 光束數量(
int
,可選)— 使用生成方法預測時將使用的光束搜尋的光束數量。1 表示不進行光束搜尋。 - gen_kwargs — 其他特定於
generate
的關鍵字引數。
執行預測並返回預測和潛在指標。
根據資料集和您的用例,您的測試資料集可能包含標籤。在這種情況下,此方法也將返回指標,就像在 evaluate()
中一樣。
如果您的預測或標籤具有不同的序列長度(例如,因為您在令牌分類任務中進行動態填充),則預測將被填充(右側)以允許連線成一個數組。填充索引為 -100。
返回:NamedTuple 具有以下鍵的命名元組
- 預測(
np.ndarray
):在test_dataset
上的預測。 - 標籤ID(
np.ndarray
,可選):標籤(如果資料集中包含)。 - 指標(
Dict[str, float]
,可選):潛在的指標字典(如果資料集中包含標籤)。
ORTTrainingArguments
class optimum.onnxruntime.ORTTrainingArguments
< 來源 >( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 5e-05 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: str = 'passive' log_level_replica: str = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 500 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: bool = False fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, list[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[list[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[list[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict[str, typing.Any], str, NoneType] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Optional[str] = 'adamw_hf' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, list[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False hub_revision: typing.Optional[str] = None gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: list[str] = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: int = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, list[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False liger_kernel_config: typing.Optional[dict[str, bool]] = None eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: typing.Optional[bool] = False use_module_with_loss: typing.Optional[bool] = False save_onnx: typing.Optional[bool] = False onnx_prefix: typing.Optional[str] = None onnx_log_level: typing.Optional[str] = 'WARNING' )
ORTSeq2SeqTrainingArguments
class optimum.onnxruntime.ORTSeq2SeqTrainingArguments
< 來源 >( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 5e-05 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: str = 'passive' log_level_replica: str = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 500 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: bool = False fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, list[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[list[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[list[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict[str, typing.Any], str, NoneType] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Optional[str] = 'adamw_hf' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, list[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False hub_revision: typing.Optional[str] = None gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: list[str] = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: int = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, list[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False liger_kernel_config: typing.Optional[dict[str, bool]] = None eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: typing.Optional[bool] = False use_module_with_loss: typing.Optional[bool] = False save_onnx: typing.Optional[bool] = False onnx_prefix: typing.Optional[str] = None onnx_log_level: typing.Optional[str] = 'WARNING' sortish_sampler: bool = False predict_with_generate: bool = False generation_max_length: typing.Optional[int] = None generation_num_beams: typing.Optional[int] = None generation_config: typing.Union[str, pathlib.Path, transformers.generation.configuration_utils.GenerationConfig, NoneType] = None )