Optimum 文件

訓練器

您正在檢視的是需要從原始碼安裝。如果您想進行常規 pip 安裝,請檢視最新穩定版本 (v1.27.0)。
Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

訓練器

ORTTrainer

class optimum.onnxruntime.ORTTrainer

< >

( 模型: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None 引數: ORTTrainingArguments = None 資料收集器: typing.Optional[transformers.data.data_collator.DataCollator] = None 訓練資料集: typing.Optional[torch.utils.data.dataset.Dataset] = None 評估資料集: typing.Union[torch.utils.data.dataset.Dataset, typing.Dict[str, torch.utils.data.dataset.Dataset], NoneType] = None 分詞器: typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None 模型初始化: typing.Optional[typing.Callable[[], transformers.modeling_utils.PreTrainedModel]] = None 計算指標: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalPrediction], typing.Dict]] = None 回撥: typing.Optional[typing.List[transformers.trainer_callback.TrainerCallback]] = None 最佳化器: typing.Tuple[torch.optim.optimizer.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None) 用於指標的預處理logits: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None )

引數

  • 模型PreTrainedModeltorch.nn.Module可選)— 用於訓練、評估或預測的模型。如果未提供,則必須傳入一個 model_init

    ORTTrainer 經過最佳化,可與 transformers 庫提供的 PreTrainedModel 配合使用。只要您自己的模型(定義為 torch.nn.Module)與 🤗 Transformers 模型的工作方式相同,您仍然可以使用它們進行 ONNX Runtime 後端訓練和 PyTorch 後端推理。

  • 引數ORTTrainingArguments可選)— 用於訓練的調整引數。如果未提供,將預設為 ORTTrainingArguments 的一個基本例項,其中 output_dir 設定為當前目錄中名為 tmp_trainer 的目錄。
  • 資料收集器DataCollator可選)— 用於從 train_dataseteval_dataset 元素列表中形成批次的功能。如果未提供 tokenizer,將預設為 default_data_collator,否則將預設為 DataCollatorWithPadding 的例項。
  • 訓練資料集torch.utils.data.Datasettorch.utils.data.IterableDataset可選)— 用於訓練的資料集。如果是 Dataset,則會自動刪除 model.forward() 方法不接受的列。請注意,如果它是一個帶有隨機化的 torch.utils.data.IterableDataset,並且您以分散式方式進行訓練,那麼您的可迭代資料集應該要麼使用一個內部屬性 generator(它是一個 torch.Generator),用於所有程序上必須相同的隨機化,要麼有一個 set_epoch() 方法,該方法在內部設定所使用的 RNG 的種子。
  • 評估資料集(Union[torch.utils.data.Dataset, Dict[str, torch.utils.data.Dataset]],可選)— 用於評估的資料集。如果是 Dataset,則會自動刪除 model.forward() 方法不接受的列。如果是一個字典,則會在每個資料集上進行評估,並將字典鍵作為度量名稱的字首。
  • 分詞器PreTrainedTokenizerBase可選)— 用於預處理資料用的分詞器。如果提供,它將在批處理輸入時自動將輸入填充到最大長度,並與模型一起儲存,以便更容易地重新執行中斷的訓練或重用微調後的模型。
  • 模型初始化Callable[[], PreTrainedModel]可選)— 例項化要使用的模型的功能。如果提供,每次呼叫 ORTTrainer.train 都將從該功能給出的新模型例項開始。該功能可以不帶引數,或帶有一個引數(包含 optuna/Ray Tune/SigOpt 試驗物件),以便能夠根據超引數(如層數、內部層大小、dropout 機率等)選擇不同的架構。
  • 計算指標Callable[[EvalPrediction], Dict]可選)— 將用於在評估時計算指標的功能。必須接受一個 EvalPrediction 並返回一個從字串到指標值的字典。
  • 回撥(List of TrainerCallback可選)— 用於自定義訓練迴圈的回撥列表。這些回撥將被新增到此處詳述的預設回撥列表中。如果您想刪除其中一個預設回撥,請使用 ORTTrainer.remove_callback 方法。
  • 最佳化器Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]可選)— 包含要使用的最佳化器和排程器的元組。將預設為您的模型上的 AdamW 例項和由 args 控制的 get_linear_schedule_with_warmup 給出的排程器。
  • 用於指標的預處理logitsCallable[[torch.Tensor, torch.Tensor], torch.Tensor]可選)— 一個在每個評估步驟快取 logits 之前立即預處理 logits 的函式。必須接受兩個張量,即 logits 和標籤,並返回經過處理的 logits。此函式所做的修改將反映在 compute_metrics 收到的預測中。請注意,如果資料集中沒有標籤,則標籤(第二個引數)將為 None

ORTTrainer 是一個簡單但功能完備的 ONNX Runtime 訓練和評估迴圈,針對 🤗 Transformers 進行了最佳化。

重要屬性

  • 模型 — 始終指向核心模型。如果使用 transformers 模型,它將是 PreTrainedModel 的子類。
  • 模型包裝器 — 如果一個或多個其他模組包裝了原始模型,則始終指向最外部的模型。這是應該用於前向傳播的模型。例如,在 DeepSpeed 下,內部模型首先被 ORTModule 包裝,然後被 DeepSpeed 包裝,然後再次被 torch.nn.DistributedDataParallel 包裝。如果內部模型尚未包裝,則 self.model_wrappedself.model 相同。
  • is_model_parallel — 模型是否已切換到模型並行模式(與資料並行不同,這意味著一些模型層分佈在不同的 GPU 上)。
  • place_model_on_device — 是否自動將模型放置在裝置上 - 如果使用模型並行或 DeepSpeed,或者如果預設的 ORTTrainingArguments.place_model_on_device 被覆蓋為返回 False,則此項將設定為 False
  • is_in_train — 模型當前是否正在執行 train(例如,在 train 中呼叫 evaluate 時)

建立最佳化器

< >

( )

設定最佳化器。

我們提供了一個效果良好的合理預設值。如果您想使用其他最佳化器,可以透過 optimizers 在 ORTTrainer 的初始化中傳遞一個元組,或者在子類中重寫此方法。

獲取 ort_optimizer_cls_and_kwargs

< >

( 引數: ORTTrainingArguments )

引數

  • 引數ORTTrainingArguments)— 訓練會話的訓練引數。

根據 ORTTrainingArguments 返回在 ONNX Runtime 中實現的最佳化器類和最佳化器引數。

訓練

< >

( 從檢查點恢復: typing.Union[bool, str, NoneType] = None 試用: typing.Union[ForwardRef('optuna.Trial'), typing.Dict[str, typing.Any]] = None 忽略評估鍵: typing.Optional[typing.List[str]] = None **kwargs )

引數

  • 從檢查點恢復strbool可選)— 如果是 str,則為上一個 ORTTrainer 例項儲存的檢查點的本地路徑。如果是 bool 且等於 True,則載入上一個 ORTTrainer 例項儲存在 args.output_dir 中的最後一個檢查點。如果存在,訓練將從此處載入的模型/最佳化器/排程器狀態恢復。
  • 試用optuna.TrialDict[str, Any]可選)— 用於超引數搜尋的試用執行或超引數字典。
  • 忽略評估鍵List[str]可選)— 模型輸出(如果是字典)中應在訓練期間收集評估預測時忽略的鍵列表。
  • kwargsDict[str, Any]可選)— 用於隱藏已棄用引數的附加關鍵字引數。

使用 ONNX Runtime 加速器進行訓練的主要入口點。

ORTSeq2SeqTrainer

class optimum.onnxruntime.ORTSeq2SeqTrainer

< >

( 模型: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None 引數: ORTTrainingArguments = None 資料收集器: typing.Optional[transformers.data.data_collator.DataCollator] = None 訓練資料集: typing.Optional[torch.utils.data.dataset.Dataset] = None 評估資料集: typing.Union[torch.utils.data.dataset.Dataset, typing.Dict[str, torch.utils.data.dataset.Dataset], NoneType] = None 分詞器: typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None 模型初始化: typing.Optional[typing.Callable[[], transformers.modeling_utils.PreTrainedModel]] = None 計算指標: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalPrediction], typing.Dict]] = None 回撥: typing.Optional[typing.List[transformers.trainer_callback.TrainerCallback]] = None 最佳化器: typing.Tuple[torch.optim.optimizer.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None) 用於指標的預處理logits: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None )

評估

< >

( 評估資料集: typing.Optional[torch.utils.data.dataset.Dataset] = None 忽略鍵: typing.Optional[typing.List[str]] = None 指標鍵字首: str = 'eval' **gen_kwargs )

引數

  • 評估資料集Dataset可選)— 如果您希望覆蓋 self.eval_dataset,請傳遞一個數據集。如果是 Dataset,則會自動刪除 model.forward() 方法不接受的列。它必須實現 __len__ 方法。
  • 忽略鍵List[str]可選)— 模型輸出(如果是字典)中應在收集預測時忽略的鍵列表。
  • 指標鍵字首str可選,預設為 "eval")— 用於作為指標鍵字首的可選字首。例如,如果字首為 "eval"(預設),則指標“bleu”將命名為“eval_bleu”。
  • 最大長度int可選)— 使用生成方法預測時要使用的最大目標長度。
  • 光束數量int可選)— 使用生成方法預測時將使用的光束搜尋的光束數量。1 表示不進行光束搜尋。
  • gen_kwargs — 其他特定於 generate 的關鍵字引數。

執行評估並返回指標。

呼叫指令碼將負責提供一個計算指標的方法,因為它們是依賴於任務的(將其傳遞給初始化 compute_metrics 引數)。

您還可以透過子類化並覆蓋此方法來注入自定義行為。

預測

< >

( 測試資料集: Dataset 忽略鍵: typing.Optional[typing.List[str]] = None 指標鍵字首: str = 'test' **gen_kwargs )

引數

  • 測試資料集Dataset)— 用於執行預測的資料集。如果是 Dataset,則會自動刪除 model.forward() 方法不接受的列。必須實現 __len__ 方法。
  • 忽略鍵List[str]可選)— 模型輸出(如果是字典)中應在收集預測時忽略的鍵列表。
  • 指標鍵字首str可選,預設為 "eval")— 用於作為指標鍵字首的可選字首。例如,如果字首為 "eval"(預設),則指標“bleu”將命名為“eval_bleu”。
  • 最大長度int可選)— 使用生成方法預測時要使用的最大目標長度。
  • 光束數量int可選)— 使用生成方法預測時將使用的光束搜尋的光束數量。1 表示不進行光束搜尋。
  • gen_kwargs — 其他特定於 generate 的關鍵字引數。

執行預測並返回預測和潛在指標。

根據資料集和您的用例,您的測試資料集可能包含標籤。在這種情況下,此方法也將返回指標,就像在 evaluate() 中一樣。

如果您的預測或標籤具有不同的序列長度(例如,因為您在令牌分類任務中進行動態填充),則預測將被填充(右側)以允許連線成一個數組。填充索引為 -100。

返回:NamedTuple 具有以下鍵的命名元組

  • 預測(np.ndarray):在 test_dataset 上的預測。
  • 標籤ID(np.ndarray可選):標籤(如果資料集中包含)。
  • 指標(Dict[str, float]可選):潛在的指標字典(如果資料集中包含標籤)。

ORTTrainingArguments

class optimum.onnxruntime.ORTTrainingArguments

< >

( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 5e-05 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: str = 'passive' log_level_replica: str = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 500 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: bool = False fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, list[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[list[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[list[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict[str, typing.Any], str, NoneType] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Optional[str] = 'adamw_hf' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, list[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False hub_revision: typing.Optional[str] = None gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: list[str] = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: int = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, list[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False liger_kernel_config: typing.Optional[dict[str, bool]] = None eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: typing.Optional[bool] = False use_module_with_loss: typing.Optional[bool] = False save_onnx: typing.Optional[bool] = False onnx_prefix: typing.Optional[str] = None onnx_log_level: typing.Optional[str] = 'WARNING' )

引數

  • optim (strtraining_args.ORTOptimizerNamestransformers.training_args.OptimizerNames, 可選, 預設為 "adamw_hf") — 要使用的最佳化器,包括 Transformers 中的最佳化器:adamw_hf、adamw_torch、adamw_apex_fused 或 adafactor。以及 ONNX Runtime 實現的最佳化器:adamw_ort_fused。

ORTSeq2SeqTrainingArguments

class optimum.onnxruntime.ORTSeq2SeqTrainingArguments

< >

( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 5e-05 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: str = 'passive' log_level_replica: str = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 500 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: bool = False fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, list[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[list[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[list[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict[str, typing.Any], str, NoneType] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Optional[str] = 'adamw_hf' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, list[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False hub_revision: typing.Optional[str] = None gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: list[str] = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: int = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, list[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False liger_kernel_config: typing.Optional[dict[str, bool]] = None eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: typing.Optional[bool] = False use_module_with_loss: typing.Optional[bool] = False save_onnx: typing.Optional[bool] = False onnx_prefix: typing.Optional[str] = None onnx_log_level: typing.Optional[str] = 'WARNING' sortish_sampler: bool = False predict_with_generate: bool = False generation_max_length: typing.Optional[int] = None generation_num_beams: typing.Optional[int] = None generation_config: typing.Union[str, pathlib.Path, transformers.generation.configuration_utils.GenerationConfig, NoneType] = None )

引數

  • optim (strtraining_args.ORTOptimizerNamestransformers.training_args.OptimizerNames, 可選, 預設為 "adamw_hf") — 要使用的最佳化器,包括 Transformers 中的最佳化器:adamw_hf、adamw_torch、adamw_apex_fused 或 adafactor。以及 ONNX Runtime 實現的最佳化器:adamw_ort_fused。
< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.