Optimum 文件
概覽
並獲得增強的文件體驗
開始使用
概述
🤗 Optimum 提供了一個名為 BetterTransformer 的 API,它是標準 PyTorch Transformer API 的快速路徑,可以透過稀疏性和融合核(如 Flash Attention)在 CPU 和 GPU 上獲得顯著的速度提升。目前,BetterTransformer 支援原生 nn.TransformerEncoderLayer
的快速路徑,以及 torch.nn.functional.scaled_dot_product_attention
的 Flash Attention 和 Memory-Efficient Attention。
快速入門
自 1.13 版本以來,PyTorch 釋出了其標準 Transformer API 的快速路徑穩定版本,為基於 Transformer 的模型提供了開箱即用的效能改進。您可以在大多數消費型裝置(包括 CPU、NVIDIA GPU 的舊版本和新版本)上獲得顯著的速度提升。現在,您可以在 🤗 Optimum 中將此功能與 Transformers 一起使用,並將其用於 Hugging Face 生態系統中的主要模型。
在 2.0 版本中,PyTorch 包含一個原生縮放點積注意力運算子(SDPA),作為 torch.nn.functional
的一部分。此函式包含多種實現,可根據輸入和所使用的硬體進行應用。有關更多資訊,請參閱官方文件,並參閱此部落格文章以瞭解基準測試。
我們在 🤗 Optimum 中提供了對這些最佳化的開箱即用整合,以便您可以轉換任何受支援的 🤗 Transformers 模型,以便在相關時使用最佳化路徑和 scaled_dot_product_attention
函式。
因此,預設情況下,在訓練模式下,BetterTransformer 整合會**放棄對掩碼的支援,並且只能用於不需要填充掩碼的批處理訓練**。例如,這適用於掩碼語言建模或因果語言建模。BetterTransformer 不適用於對需要填充掩碼的任務進行模型微調。
在推理模式下,為了正確性保留了填充掩碼,因此只有在批處理大小 = 1 的情況下才能預期速度提升。
支援的模型
以下是支援的模型列表
- AlBERT
- Bark
- BART
- BERT
- BERT-generation
- BLIP-2
- BLOOM
- CamemBERT
- CLIP
- CodeGen
- Data2VecText
- DistilBert
- DeiT
- Electra
- Ernie
- Falcon (無需使用 BetterTransformer,它直接受 Transformers 支援)
- FSMT
- GPT2
- GPT-j
- GPT-neo
- GPT-neo-x
- GPT BigCode (SantaCoder, StarCoder - 無需使用 BetterTransformer,它直接受 Transformers 支援)
- HuBERT
- LayoutLM
- Llama & Llama2 (無需使用 BetterTransformer,它直接受 Transformers 支援)
- MarkupLM
- Marian
- MBart
- M2M100
- OPT
- ProphetNet
- RemBERT
- RoBERTa
- RoCBert
- RoFormer
- Splinter
- Tapas
- ViLT
- ViT
- ViT-MAE
- ViT-MSN
- Wav2Vec2
- Whisper (無需使用 BetterTransformer,它直接受 Transformers 支援)
- XLMRoberta
- YOLOS
如果您希望支援更多模型,請在 🤗 Optimum 中提出 issue,或者如果您想自己新增模型,請檢視貢獻指南!
快速使用
要使用 BetterTransformer
API,只需執行以下命令
>>> from transformers import AutoModelForSequenceClassification
>>> from optimum.bettertransformer import BetterTransformer
>>> model_hf = AutoModelForSequenceClassification.from_pretrained("bert-base-cased")
>>> model = BetterTransformer.transform(model_hf, keep_original_model=True)
如果您想用 BetterTransformer
版本覆蓋當前模型,可以將 keep_original_model=False
。
有關如何深入理解其用法的更多詳細資訊,請參閱 tutorials
部分,或檢視 Google colab 演示!