利用自推測解碼加速文字生成
自推測解碼,由LayerSkip: Enabling Early Exit Inference and Self-Speculative Decoding提出,是一種新穎的文字生成方法。它結合了推測解碼和大型語言模型(LLM)的提前退出優勢。該方法透過使用**同一模型**的早期層來起草標記,並使用後續層進行驗證,從而實現高效生成。
這項技術不僅加速了文字生成,還在記憶體和計算延遲方面取得了顯著的節省。為了實現端到端加速,早期層的輸出需要足夠接近最後一層。這透過一種訓練方案來實現,如論文所述,該方案可在預訓練期間應用,也可在特定領域進行微調時應用。自推測解碼對於實際應用特別高效,它使得模型能夠在較小的 GPU 上部署,並降低了**大規模推理**所需的總體硬體佔用。
在這篇部落格文章中,我們將探討自推測解碼的概念、其實現以及使用 🤗 transformers 庫的實際應用。您將瞭解其技術基礎,包括**提前退出層**、**反嵌入**和**訓練修改**。為了將這些概念付諸實踐,我們提供了程式碼示例、與傳統推測解碼的基準比較以及效能權衡的見解。
請直接檢視以下 Hugging Face 資源,以瞭解更多有關該方法的資訊並親身體驗:
推測解碼與自推測解碼
LayerSkip 推理在
facebook/layerskip-llama2-7B
(使用 LayerSkip 方案持續預訓練的 Llama2 7B)上的演示。
傳統的推測解碼使用**兩個**模型:一個較小的模型(草稿模型)生成一系列草稿標記,一個較大的模型(驗證模型)驗證草稿的準確性。較小的模型執行大部分生成工作,而較大的模型則進行結果修正。這加快了文字生成速度,因為較大的模型可以一次性驗證完整的序列,而不是一次生成一個草稿。
在自推測解碼中,作者在此概念的基礎上,使用大型模型的早期層來生成草稿標記,然後由模型的深層進行驗證。這種推測解碼的“自我”方面需要特定的訓練,使得模型能夠同時執行起草和驗證。反過來,這與傳統的推測解碼相比,提高了速度並降低了計算成本。
使用 transformers
為了在 🤗 transformers 庫中啟用提前退出自推測解碼,我們只需要在 generate()
函式中新增 assistant_early_exit
引數。
這是一個展示該功能的簡單程式碼片段。
pip install transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
early_exit_layer = 4
prompt = "Alice and Bob"
checkpoint = "facebook/layerskip-llama2-7B"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
model = AutoModelForCausalLM.from_pretrained(checkpoint).to("cuda")
outputs = model.generate(**inputs, assistant_early_exit=early_exit_layer)
注意: 雖然
assistant_early_exit
引數理論上可以為任何僅解碼器 Transformer 啟用提前退出自推測解碼,但中間層的 logits 無法進行**反嵌入**(透過 LM Head 解碼的過程,稍後在部落格文章中描述),除非模型經過專門訓練。此外,只有當檢查點經過特殊訓練以提高早期層準確性時,您才能**獲得加速**。LayerSkip 論文提出了一種訓練方法來實現這一點(即,應用提前退出損失,並逐步增加層 dropout 率)。此處提供了一系列使用 LayerSkip 訓練方法持續預訓練的 Llama2、Llama3 和 Code Llama 檢查點。
基準測試
我們進行了一系列廣泛的基準測試,以衡量 LayerSkip 的自推測解碼相對於各種模型上的自迴歸解碼的加速效果。我們還比較了自推測解碼(基於提前退出)與標準推測解碼技術。要重現結果,您可以在此處找到程式碼,並在此電子表格中找到執行每個實驗的命令。所有實驗均在單個 80GB A100 GPU 上執行,Llama2 70B 實驗除外,它們在包含 8 個 A100 GPU 的節點上執行。
Llama3.2 1B
模型變體 | 層數 | 輔助模型 | 輔助層 | 任務 | 總層數 | FLOPs/輸入 (G) | 時間/輸入 (s) | FLOPs/輸出 (G) | 時間/輸出 (s) | 效率 |
---|---|---|---|---|---|---|---|---|---|---|
facebook/layerskip-llama3.2-1B | 1 | 提前退出 @ 第 4 層 | 摘要 | 1 | 1195.28 | 9.96 | 2147.7 | 17.9 | 1.80 |
Llama3 8B
模型變體 | 層數 | 輔助模型 | 輔助層 | 任務 | 總層數 | FLOPs/輸入 (G) | 時間/輸入 (s) | FLOPs/輸出 (G) | 時間/輸出 (s) | 效率 |
---|---|---|---|---|---|---|---|---|---|---|
meta-llama/Meta-Llama-3-8B | 8 | meta-llama/Llama-3.2-1B | 1 | 摘要 | 9 | 1872.46 | 19.04 | 2859.35 | 29.08 | 1.53 |
meta-llama/Meta-Llama-3-8B | 8 | meta-llama/Llama-3.2-3B | 3 | 摘要 | 11 | 2814.82 | 28.63 | 2825.36 | 28.73 | 1.00 |
facebook/layerskip-llama3-8B | 8 | 提前退出 @ 第 4 層 | 摘要 | 8 | 1949.02 | 15.75 | 3571.81 | 28.87 | 1.83 |
Llama2 70B
模型變體 | 層數 | 輔助模型 | 輔助層 | 任務 | 總層數 | FLOPs/輸入 (G) | 時間/輸入 (s) | FLOPs/輸出 (G) | 時間/輸出 (s) | 效率 |
---|---|---|---|---|---|---|---|---|---|---|
meta-llama/Llama-2-70b-hf | 70 | meta-llama/Llama-2-13b-hf | 13 | 摘要 | 83 | 5036.54 | 46.3 | 12289.01 | 112.97 | 2.44 |
meta-llama/Llama-2-70b-hf | 70 | meta-llama/Llama-2-7b-hf | 7 | 摘要 | 77 | 4357.55 | 40.06 | 12324.19 | 113.3 | 2.83 |
meta-llama/Llama-2-70b-hf | 70 | TinyLlama/TinyLlama_v1.1 | 1 | 摘要 | 71 | 4356.21 | 40.05 | 12363.22 | 113.66 | 2.84 |
facebook/layerskip-llama2-70B | 70 | 提前退出 @ 第 10 層 | 摘要 | 70 | 6012.04 | 54.96 | 1283.34 | 113.2 | 2.06 |
Llama2 13B
模型變體 | 層數 | 輔助模型 | 輔助層 | 任務 | 總層數 | FLOPs/輸入 (G) | 時間/輸入 (s) | FLOPs/輸出 (G) | 時間/輸出 (s) | 效率 |
---|---|---|---|---|---|---|---|---|---|---|
meta-llama/Llama-2-13b-hf | 13 | meta-llama/Llama-2-7b-hf | 7 | 摘要 | 20 | 3557.07 | 27.79 | 4088.48 | 31.94 | 1.15 |
meta-llama/Llama-2-13b-hf | 13 | TinyLlama/TinyLlama_v1.1 | 1 | 摘要 | 14 | 2901.92 | 22.67 | 4190.42 | 32.74 | 1.44 |
meta-llama/Llama-2-13b-hf | 13 | apple/OpenELM-270M | 0.27 | 摘要 | 13.27 | 2883.33 | 22.53 | 4521.12 | 35.32 | 1.57 |
meta-llama/Llama-2-13b-hf | 13 | apple/OpenELM-450M | 0.45 | 摘要 | 13.45 | 3267.69 | 25.53 | 4321.75 | 33.76 | 1.32 |
facebook/layerskip-llama2-13B | 13 | 提前退出 @ 第 4 層 | 摘要 | 13 | 4238.45 | 33.11 | 4217.78 | 32.95 | 0.995 | |
facebook/layerskip-llama2-13B | 13 | 提前退出 @ 第 8 層 | 摘要 | 13 | 2459.61 | 19.22 | 4294.98 | 33.55 | 1.746 |
Llama2 7B
模型變體 | 層數 | 輔助模型 | 輔助層 | 任務 | 總層數 | FLOPs/輸入 (G) | 時間/輸入 (s) | FLOPs/輸出 (G) | 時間/輸出 (s) | 效率 |
---|---|---|---|---|---|---|---|---|---|---|
meta-llama/Llama-2-7b-hf | 7 | TinyLlama/TinyLlama_v1.1 | 1 | 摘要 | 8 | 2771.54 | 21.65 | 3368.48 | 26.32 | 1.22 |
meta-llama/Llama-2-7b-hf | 7 | apple/OpenELM-270M | 0.27 | 摘要 | 7.27 | 2607.82 | 20.37 | 4221.14 | 32.98 | 1.62 |
meta-llama/Llama-2-7b-hf | 7 | apple/OpenELM-450M | 0.45 | 摘要 | 7.45 | 3324.68 | 25.97 | 4178.66 | 32.65 | 1.26 |
facebook/layerskip-llama2-7B | 7 | 提前退出 @ 第 4 層 | 摘要 | 7 | 2548.4 | 19.91 | 3306.73 | 25.83 | 1.297 |
我們可以從結果中得出以下觀察:
- 如“總引數數量”列所示,自推測解碼消耗的記憶體更少,因為它不需要單獨的草稿模型,並且草稿階段層的權重被重用。
- 對於除 Llama2 70B 之外的所有模型大小和生成,提前退出自推測解碼都比常規的雙模型推測解碼更快。
Llama2 70B 上自推測解碼速度提升相對有限的原因可能有多種,例如 LayerSkip Llama2 70B 檢查點持續預訓練的 token 數量較少(Llama2 70B 為 328M token,而 Llama2 7B 為 52B token)。但這仍然是未來研究需要深入探索的改進領域。儘管如此,70B 模型的自推測解碼仍顯著快於自迴歸解碼。
提前退出與反嵌入
自推測解碼的一項關鍵技術是提前退出,即生成過程可以在預設層停止。為了實現這一點,我們透過將這些層的 logits 投射到語言模型(LM)頭部來預測下一個 token,從而**反嵌入**它們。這使得模型能夠跳過後續層並提高推理時間。
反嵌入可以在任何 Transformer 層進行,將提前退出轉變為一種高效的 token 預測機制。一個自然而然的問題是:LM 頭如何適應反嵌入早期層的 logits,因為它最初是經過訓練只與最終層一起工作的?這就是訓練修改發揮作用的地方。
訓練修改:層 dropout 和 提前退出損失
在訓練階段,我們引入了**層 dropout**,它允許模型在訓練期間跳過某些層。dropout 率在深層中逐漸增加,使模型對後續層的依賴性降低,同時增強模型的泛化能力並加速訓練。
除了層 dropout,還應用了**提前退出損失**,以確保 LM 頭部學習反嵌入不同層。用於訓練具有提前退出功能的模型的總損失函式由每次退出(中間層)的標準化損失之和給出。該技術透過在所有層之間分配學習任務來實現高效訓練。
自起草與自驗證
訓練完成後,我們可以在推理期間應用自推測解碼。該過程始於**自起草**,其中透過從某個中間層提前退出生成標記。推測標記的數量定義了在此階段生成的草稿標記數量,而我們退出的層定義了草稿階段的大小和準確性。這兩個引數都可以在推理時根據速度和草稿階段準確性之間的權衡進行指定。
下一階段是**自驗證**,其中使用完整模型來驗證草稿標記。驗證模型重用了草稿模型的部分快取。如果草稿標記與驗證標記一致,它們就會被新增到最終輸出中,從而更好地利用我們系統的記憶體頻寬,因為使用完整模型生成一系列標記比驗證草稿的成本要高得多,只要有幾個標記匹配。
在自驗證階段,由於早期層的結果在起草階段已快取,因此僅計算剩餘層進行驗證。
最佳化:共享權重、共享 KV 快取和共享計算
自推測解碼顯著受益於快取重用,特別是**KV 快取**,它儲存在起草階段計算的鍵值對。此快取允許模型跳過冗餘計算,因為起草和驗證階段都使用相同的早期層。此外,**退出查詢快取**儲存退出層中的查詢向量,允許驗證從起草階段無縫繼續。
與傳統的雙模型推測解碼相比,提前退出自推測解碼可以從以下節省中受益:
- 共享權重:重用前 層的權重,用於起草和驗證。
- 共享 KV 快取:重用前 層的鍵值對,用於起草和驗證。
- 共享計算:透過使用**退出查詢快取**(只儲存退出層 的查詢向量),重用前 層的計算,從而使驗證過程無需計算從 到 的層。
KV 和退出查詢快取的結合,稱為**KVQ 快取**,減少了記憶體開銷並改善了推理延遲。
到目前為止,🤗 transformers 庫已在此拉取請求中實現了第一項最佳化(共享權重)。隨著使用此方法的模型數量增加,我們將考慮其他最佳化。如果您感興趣,請隨時提交 PR!
我們能多早退出?
草稿階段的提前退出層是一個超引數,我們可以在推理過程中進行調整或修改。
- 我們越早退出,草稿 token 的生成速度越快,但準確性越低。
- 我們越晚退出,草稿 token 的準確性越高,但生成速度越慢。
我們編寫了一個指令碼,以在不同的提前退出層上測量 A100 GPU 上的每秒 token 數。在下表中,我們繪製了不同 Llama 模型(包括 LayerSkip 和基線檢查點)的每秒 token 數與提前退出層之間的關係(您可以此處檢視完整的日誌)。
Llama3.2 1B
Llama3 8B
Code Llama3 34B
Code Llama3 7B
Llama2 70B
Llama2 13B
Llama2 7B
我們可以觀察到以下幾點:
- 對於未經過 LayerSkip 訓練方案預訓練或持續預訓練的基線檢查點,提前退出自推測解碼比自迴歸解碼慢。這是因為在大多數 LLM 的訓練過程中,早期層沒有被激勵去學習預測輸出,因此使用早期層生成標記的接受率會非常低。
- 另一方面,對於使用 LayerSkip 訓練持續預訓練的 Llama 檢查點,提前退出自推測解碼在至少部分層中比自迴歸解碼具有更高的加速效果。
- 對於大多數模型,除了 Llama3.2 1B,我們注意到在遍歷層時有一個規律:加速效果在前幾層開始較低,逐漸增加到最佳點,然後再次下降。
- 提前退出層的最佳點是我們在高預測準確性和低標記生成開銷之間達到最佳權衡的時候。這個最佳點取決於每個模型,也可能取決於提示或提示的領域。
這些觀察結果為進一步的實驗和探索提供了有趣的機會。我們鼓勵讀者在這些想法的基礎上進行構建,測試變體,並追求自己的研究。這些努力可以帶來寶貴的見解,併為該領域做出有意義的貢獻。
結論
LayerSkip 利用提前退出、層 dropout 和快取重用之間的協同作用,建立了一個快速高效的文字生成管道。透過訓練模型以反嵌入不同層的輸出,並利用快取最佳化驗證過程,該方法在速度和準確性之間取得了平衡。因此,它顯著縮短了大型語言模型的推理時間,同時保持了高質量的輸出。由於只使用一個模型作為草稿和驗證模型,它還減少了與傳統推測解碼技術相比的記憶體消耗。
自推測是一個令人興奮的領域,同一個 LLM 既可以生成草稿標記,又可以自行修正。其他自推測方法包括:
- 起草與驗證 (Draft & Verify):草稿階段涉及跳過預定的注意力層和前饋層。
- MagicDec:草稿階段使用 KV 快取的一個子集,這對於長上下文輸入很有用。
- 雅可比解碼 (Jacobi Decoding) 和 前瞻解碼 (Lookahead Decoding):草稿階段是一系列“猜測標記”,這些標記可以是隨機的,也可以是從 N-gram 查詢表中獲得的。