Optimum 文件
如何使用 ONNX Runtime 加速訓練
並獲得增強的文件體驗
開始使用
如何使用 ONNX Runtime 加速訓練
Optimum 透過 ORTTrainer
API 集成了 ONNX Runtime 訓練,該 API 擴充套件了 Transformers 中的 Trainer
。透過此擴充套件,與 Eager 模式下的 PyTorch 相比,許多流行的 Hugging Face 模型的訓練時間可以減少 35% 以上。
ORTTrainer
和 ORTSeq2SeqTrainer
API 使得將 ONNX Runtime (ORT) 與 Trainer
中的其他功能輕鬆組合。它包含功能完整的訓練迴圈和評估迴圈,並支援超引數搜尋、混合精度訓練以及使用多個 NVIDIA 和 AMD GPU 的分散式訓練。藉助 ONNX Runtime 後端,ORTTrainer
和 ORTSeq2SeqTrainer
可利用以下優勢:
- 計算圖最佳化:常量摺疊、節點消除、節點融合
- 高效記憶體規劃
- 核心最佳化
- ORT 融合 Adam 最佳化器:將應用於所有模型引數的逐元素更新批處理為一個或幾個核心啟動
- 更高效的 FP16 最佳化器:消除了大量的裝置到主機記憶體複製
- 混合精度訓練
嘗試一下,在 🤗 Transformers 中訓練模型時實現更低延遲、更高吞吐量和更大的最大批大小!
效能
下表顯示了當使用 ONNX Runtime 和 DeepSpeed ZeRO Stage 1 進行訓練時,Optimum 對 Hugging Face 模型實現了令人印象深刻的加速,從 39% 到 130%。效能測量是在選定的 Hugging Face 模型上進行的,其中 PyTorch 作為基線執行,僅 ONNX Runtime 進行訓練作為第二次執行,ONNX Runtime + DeepSpeed ZeRO Stage 1 作為最終執行,顯示出最大收益。基線 PyTorch 執行使用的最佳化器是 AdamW 最佳化器,而 ORT 訓練執行使用融合 Adam 最佳化器(在 ORTTrainingArguments
中可用)。執行在一臺具有 8 個 GPU 的 Nvidia A100 節點上執行。

這些執行中使用的版本資訊如下:
PyTorch: 1.14.0.dev20221103+cu116; ORT: 1.14.0.dev20221103001+cu116; DeepSpeed: 0.6.6; HuggingFace: 4.24.0.dev0; Optimum: 1.4.1.dev0; Cuda: 11.6.2
開始設定環境
要使用 ONNX Runtime 進行訓練,您需要一臺至少配備一個 NVIDIA 或 AMD GPU 的機器。
要使用 ORTTrainer
或 ORTSeq2SeqTrainer
,您需要安裝 ONNX Runtime Training 模組和 Optimum。
安裝 ONNX Runtime
為了設定環境,我們**強烈建議**您使用 Docker 安裝依賴項,以確保版本正確且配置良好。您可以在此處找到各種組合的 Dockerfile。
NVIDIA GPU 的設定
以下我們以安裝 onnxruntime-training 1.14.0
為例
- 如果您想透過 Dockerfile 安裝
onnxruntime-training 1.14.0
docker build -f Dockerfile-ort1.14.0-cu116 -t ort/train:1.14.0 .
pip install onnx ninja pip install torch==1.13.1+cu116 torchvision==0.14.1 -f https://download.pytorch.org/whl/cu116/torch_stable.html pip install onnxruntime-training==1.14.0 -f https://download.onnxruntime.ai/onnxruntime_stable_cu116.html pip install torch-ort pip install --upgrade protobuf==3.20.2
並執行安裝後配置
python -m torch_ort.configure
AMD GPU 的設定
以下我們以安裝 onnxruntime-training
nightly 為例
- 如果您想透過 Dockerfile 安裝
onnxruntime-training
docker build -f Dockerfile-ort-nightly-rocm57 -t ort/train:nightly .
- 如果您想在本地 Python 環境中安裝其他依賴項。您可以在成功安裝 ROCM 5.7 後,透過 pip 安裝它們。
pip install onnx ninja pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7 pip install pip install --pre onnxruntime-training -f https://download.onnxruntime.ai/onnxruntime_nightly_rocm57.html pip install torch-ort pip install --upgrade protobuf==3.20.2
並執行安裝後配置
python -m torch_ort.configure
安裝 Optimum
您可以透過 pypi 安裝 Optimum
pip install optimum
或從原始碼安裝
pip install git+https://github.com/huggingface/optimum.git
此命令安裝 Optimum 的當前主開發版本,其中可能包含最新的開發(新功能、錯誤修復)。但是,主版本可能不太穩定。如果遇到任何問題,請提出 問題,以便我們儘快解決。
ORTTrainer
ORTTrainer
類繼承了 Transformers 的 Trainer
。您可以透過用 ORTTrainer
替換 transformers 的 Trainer
來輕鬆調整程式碼,以利用 ONNX Runtime 帶來的加速。以下是與 Trainer
相比如何使用 ORTTrainer
的示例:
-from transformers import Trainer, TrainingArguments
+from optimum.onnxruntime import ORTTrainer, ORTTrainingArguments
# Step 1: Define training arguments
-training_args = TrainingArguments(
+training_args = ORTTrainingArguments(
output_dir="path/to/save/folder/",
- optim = "adamw_hf",
+ optim="adamw_ort_fused",
...
)
# Step 2: Create your ONNX Runtime Trainer
-trainer = Trainer(
+trainer = ORTTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
+ feature="text-classification",
...
)
# Step 3: Use ONNX Runtime for training!🤗
trainer.train()
請檢視 Optimum 倉庫中更詳細的示例指令碼。
ORTSeq2SeqTrainer
ORTSeq2SeqTrainer
類類似於 Transformers 的 Seq2SeqTrainer
。您可以透過用 ORTSeq2SeqTrainer
替換 transformers 的 Seq2SeqTrainer
來輕鬆調整程式碼,以利用 ONNX Runtime 帶來的加速。以下是與 Seq2SeqTrainer
相比如何使用 ORTSeq2SeqTrainer
的示例
-from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
+from optimum.onnxruntime import ORTSeq2SeqTrainer, ORTSeq2SeqTrainingArguments
# Step 1: Define training arguments
-training_args = Seq2SeqTrainingArguments(
+training_args = ORTSeq2SeqTrainingArguments(
output_dir="path/to/save/folder/",
- optim = "adamw_hf",
+ optim="adamw_ort_fused",
...
)
# Step 2: Create your ONNX Runtime Seq2SeqTrainer
-trainer = Seq2SeqTrainer(
+trainer = ORTSeq2SeqTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
+ feature="text2text-generation",
...
)
# Step 3: Use ONNX Runtime for training!🤗
trainer.train()
請檢視 Optimum 倉庫中更詳細的示例指令碼。
ORTTrainingArguments
ORTTrainingArguments
類繼承了 Transformers 中的 TrainingArguments
類。除了 Transformers 中實現的最佳化器外,它還允許您使用 ONNX Runtime 中實現的最佳化器。將 Seq2SeqTrainingArguments
替換為 ORTSeq2SeqTrainingArguments
。
-from transformers import TrainingArguments
+from optimum.onnxruntime import ORTTrainingArguments
-training_args = TrainingArguments(
+training_args = ORTTrainingArguments(
output_dir="path/to/save/folder/",
num_train_epochs=1,
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
warmup_steps=500,
weight_decay=0.01,
logging_dir="path/to/save/folder/",
- optim = "adamw_hf",
+ optim="adamw_ort_fused", # Fused Adam optimizer implemented by ORT
)
ONNX Runtime 支援 DeepSpeed(目前僅支援 ZeRO 階段 1 和 2)。您可以在 Optimum 儲存庫中找到一些 DeepSpeed 配置示例。
ORTSeq2SeqTrainingArguments
ORTSeq2SeqTrainingArguments
類繼承了 Transformers 中的 Seq2SeqTrainingArguments
類。除了 Transformers 中實現的最佳化器外,它還允許您使用 ONNX Runtime 中實現的最佳化器。將 Seq2SeqTrainingArguments
替換為 ORTSeq2SeqTrainingArguments
。
-from transformers import Seq2SeqTrainingArguments
+from optimum.onnxruntime import ORTSeq2SeqTrainingArguments
-training_args = Seq2SeqTrainingArguments(
+training_args = ORTSeq2SeqTrainingArguments(
output_dir="path/to/save/folder/",
num_train_epochs=1,
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
warmup_steps=500,
weight_decay=0.01,
logging_dir="path/to/save/folder/",
- optim = "adamw_hf",
+ optim="adamw_ort_fused", # Fused Adam optimizer implemented by ORT
)
ONNX Runtime 支援 DeepSpeed(目前僅支援 ZeRO 階段 1 和 2)。您可以在 Optimum 儲存庫中找到一些 DeepSpeed 配置示例。
ORTModule+StableDiffusion
Optimum 支援在此示例中透過 ONNX Runtime 加速 Hugging Face Diffusers。啟用 ONNX Runtime 訓練所需的核心更改總結如下:
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel
+from onnxruntime.training.ortmodule import ORTModule
+from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer
unet = UNet2DConditionModel.from_pretrained(
"CompVis/stable-diffusion-v1-4",
subfolder="unet",
...
)
text_encoder = CLIPTextModel.from_pretrained(
"CompVis/stable-diffusion-v1-4",
subfolder="text_encoder",
...
)
vae = AutoencoderKL.from_pretrained(
"CompVis/stable-diffusion-v1-4",
subfolder="vae",
...
)
optimizer = torch.optim.AdamW(
unet.parameters(),
...
)
+vae = ORTModule(vae)
+text_encoder = ORTModule(text_encoder)
+unet = ORTModule(unet)
+optimizer = ORT_FP16_Optimizer(optimizer)
其他資源
如果您對 ORTTrainer
有任何問題,請在 Optimum Github 上提出問題或在 HuggingFace 的社群論壇上與我們討論,祝好 🤗!