Optimum 文件
ONNX 匯出配置類
並獲得增強的文件體驗
開始使用
ONNX 匯出配置類
將模型匯出為 ONNX 涉及指定
- 輸入名稱。
- 輸出名稱。
- 動態軸。這些指的是執行時可以動態更改的輸入維度(例如,批次大小或序列長度)。所有其他軸將被視為靜態軸,因此在執行時固定。
- 用於跟蹤模型的虛擬輸入。在 PyTorch 中需要此操作來記錄計算圖並將其轉換為 ONNX。
由於此資料取決於模型和任務的選擇,我們將其表示為*配置類*。每個配置類都與特定的模型架構相關聯,並遵循命名約定 `ArchitectureNameOnnxConfig`。例如,指定 BERT 模型 ONNX 匯出的配置是 `BertOnnxConfig`。
由於許多架構在 ONNX 配置上共享相似的屬性,🤗 Optimum 採用了 3 級類層次結構
- 抽象和通用的基類。這些類處理所有基本功能,同時與模態(文字、影像、音訊等)無關。
- 中端類。這些類知道模態,但同一模態可以存在多個,具體取決於它們支援的輸入。它們指定應使用哪些輸入生成器來生成虛擬輸入,但與模型無關。
- 特定於模型的類,例如上面提到的 `BertOnnxConfig`。這些是實際用於匯出模型的類。
基類
class optimum.exporters.onnx.OnnxConfig
< source >( config: PretrainedConfig task: str = 'feature-extraction' preprocessors: typing.Optional[typing.List[typing.Any]] = None int_dtype: str = 'int64' float_dtype: str = 'fp32' legacy: bool = False )
包含要提供給模型的輸入張量軸定義的字典。
包含要提供給模型的輸出張量軸定義的字典。
生成虛擬輸入
< source >( framework: str = 'pt' **kwargs ) → Dict[str, [tf.Tensor, torch.Tensor]]
引數
- framework (
str
, 預設為"pt"
) — 用於建立虛擬輸入的框架。 - batch_size (
int
, 預設為 2) — 用於虛擬輸入的批次大小。 - sequence_length (
int
, 預設為 16) — 用於虛擬輸入的序列長度。 - num_choices (
int
, 預設為 4) — 多項選擇任務的候選答案數量。 - image_width (
int
, 預設為 64) — 用於視覺任務虛擬輸入的寬度。 - image_height (
int
, 預設為 64) — 用於視覺任務虛擬輸入的高度。 - num_channels (
int
, 預設為 3) — 用於視覺任務虛擬輸入的通道數。 - feature_size (
int
, 預設為 80) — 用於音訊任務虛擬輸入的特徵數量,如果不是原始音訊。例如,這是 STFT bin 或 MEL bin 的數量。 - nb_max_frames (
int
, 預設為 3000) — 用於音訊任務虛擬輸入的幀數,如果輸入不是原始音訊。 - audio_sequence_length (
int
, 預設為 16000) — 用於音訊任務虛擬輸入的幀數,如果輸入是原始音訊。
返回
Dict[str, [tf.Tensor, torch.Tensor]]
將輸入名稱對映到正確框架格式的虛擬張量的字典。
生成跟蹤模型所需的虛擬輸入。如果未明確指定,則使用預設輸入形狀。
class optimum.exporters.onnx.OnnxConfigWithPast
< source >( config: PretrainedConfig task: str = 'feature-extraction' int_dtype: str = 'int64' float_dtype: str = 'fp32' use_past: bool = False use_past_in_inputs: bool = False preprocessors: typing.Optional[typing.List[typing.Any]] = None legacy: bool = False )
繼承自 OnnxConfig。處理僅解碼器模型 ONNX 配置的基類。
add_past_key_values
< source >( inputs_or_outputs: typing.Dict[str, typing.Dict[int, str]] direction: str )
根據方向使用 past_key_values 動態軸填充 `input_or_outputs` 對映。
class optimum.exporters.onnx.OnnxSeq2SeqConfigWithPast
< source >( config: PretrainedConfig task: str = 'feature-extraction' int_dtype: str = 'int64' float_dtype: str = 'fp32' use_past: bool = False use_past_in_inputs: bool = False behavior: ConfigBehavior = <ConfigBehavior.MONOLITH: 'monolith'> preprocessors: typing.Optional[typing.List[typing.Any]] = None legacy: bool = False )
繼承自 OnnxConfigWithPast。處理編碼器-解碼器模型 ONNX 配置的基類。
with_behavior
< source >( behavior: typing.Union[str, optimum.exporters.onnx.base.ConfigBehavior] use_past: bool = False use_past_in_inputs: bool = False ) → OnnxSeq2SeqConfigWithPast
建立當前 OnnxConfig 的副本,但具有不同的 `ConfigBehavior` 和 `use_past` 值。
中端類
文字
class optimum.exporters.onnx.TextEncoderOnnxConfig
< source >( config: PretrainedConfig task: str = 'feature-extraction' preprocessors: typing.Optional[typing.List[typing.Any]] = None int_dtype: str = 'int64' float_dtype: str = 'fp32' legacy: bool = False )
處理基於編碼器的文字架構。
class optimum.exporters.onnx.TextDecoderOnnxConfig
< source >( config: PretrainedConfig task: str = 'feature-extraction' int_dtype: str = 'int64' float_dtype: str = 'fp32' use_past: bool = False use_past_in_inputs: bool = False preprocessors: typing.Optional[typing.List[typing.Any]] = None legacy: bool = False )
處理基於解碼器的文字架構。
class optimum.exporters.onnx.TextSeq2SeqOnnxConfig
< source >( config: PretrainedConfig task: str = 'feature-extraction' int_dtype: str = 'int64' float_dtype: str = 'fp32' use_past: bool = False use_past_in_inputs: bool = False behavior: ConfigBehavior = <ConfigBehavior.MONOLITH: 'monolith'> preprocessors: typing.Optional[typing.List[typing.Any]] = None legacy: bool = False )
處理基於編碼器-解碼器的文字架構。
視覺
class optimum.exporters.onnx.config.VisionOnnxConfig
< source >( config: PretrainedConfig task: str = 'feature-extraction' preprocessors: typing.Optional[typing.List[typing.Any]] = None int_dtype: str = 'int64' float_dtype: str = 'fp32' legacy: bool = False )
處理視覺架構。
多模態
class optimum.exporters.onnx.config.TextAndVisionOnnxConfig
< source >( config: PretrainedConfig task: str = 'feature-extraction' preprocessors: typing.Optional[typing.List[typing.Any]] = None int_dtype: str = 'int64' float_dtype: str = 'fp32' legacy: bool = False )
處理多模態文字和視覺架構。