通用輔助生成:使用任意輔助模型實現更快解碼
TL;DR:許多 LLM,如
gemma-2-9b
和 Mixtral-8x22B-Instruct-v0.1
,缺乏一個更小的版本可用於輔助生成。在本篇部落格文章中,我們介紹了**通用輔助生成**:一種由 Intel Labs 和 Hugging Face 開發的方法,它將輔助生成擴充套件到**任意模型家族**的小型語言模型 🤯。因此,現在可以通過幾乎零開銷的方式將**任何**解碼器或混合專家模型的推理速度提升 **1.5 倍至 2.0 倍** 🔥🔥🔥。讓我們深入瞭解!
引言
如今,最強大的開源 LLM 通常擁有數十億到數千億個引數(你好 Llama-3.1-405B 👋),在生產環境中部署這些龐然大物帶來了一系列工程挑戰。其中一個挑戰是這些大型模型的文字生成速度慢,這促使社群開發了各種技術來加速解碼過程。輔助生成,也稱為推測解碼,是一種非常流行且實用的加速 LLM 推理而不損失準確性的方法。在這篇部落格文章中,我們將探討輔助生成的工作原理,並分享我們將其擴充套件到 Hugging Face Hub 上14 萬個語言模型中任意模型的研究成果 🚀!
輔助生成
輔助生成的核心思想是使用一對模型,即目標模型和輔助模型。輔助模型是目標模型的一個更小、更高效的版本,例如,你可以使用 Llama-3.2-1B
作為更大的 Llama-3.1-70b
目標模型的輔助模型。輔助生成是一個迭代過程。每個迴圈中,輔助模型自迴歸地逐個生成一系列 token。然後,目標模型在一次前向傳播中驗證序列中的所有輔助 token。透過在目標模型的每次前向傳播中確認多個 token,而不是每次只生成一個 token,從而實現加速。有關更詳細的解釋,請參閱原始部落格文章。結合最近引入的動態推測策略,輔助生成可將文字生成速度提高 1.5 倍到 3 倍,具體取決於任務和所使用的模型。
輔助生成帶來的顯著加速伴隨著一個顯著的缺點:目標模型和輔助模型必須共享相同的分詞器,這意味著它們需要來自同一個模型家族。然而,許多廣泛使用的模型缺乏足夠小巧且準確的小型版本,無法大幅減少延遲。根據我們的經驗,當輔助模型比目標模型小 50-100 倍時,通常才能看到有意義的加速。例如,CodeLlama-13b
缺乏一個更小的版本,而 gemma-2-9b
只有一個 2b
變體,它仍然不夠小/快,無法實現顯著的效能提升。
通用輔助生成
為了緩解這一痛點,Intel Labs 與 Hugging Face 的朋友們共同開發了通用輔助生成(UAG)。UAG 允許選擇任意一對目標模型和輔助模型,無論它們的分詞器如何。例如,可以將 gemma-2-9b
用作目標模型,而使用微小的 vicuna-68m
作為輔助模型。
我們提出的方法主要思想是雙向分詞器轉換。一旦輔助模型完成一次生成迭代,輔助 token 將被轉換為文字,然後使用目標模型的分詞器進行分詞以生成目標 token。在驗證步驟之後,目標 token 同樣被轉換回輔助 token 格式,然後將其附加到輔助模型的上下文,再開始下一次迭代。
由於輔助分詞器和目標分詞器使用不同的詞彙表,因此有必要處理它們之間的差異。為了準確地重新編碼新生成的輔助 token,需要預先新增一個包含幾個先前 token 的上下文視窗。然後將整個序列重新編碼為目標 token 格式,並與最新的目標 token 對齊,以精確定位新生成的 token 應附加的位置。這個過程在下面的影片中有所說明。
雖然上述影片中未顯示,但從目標模型到輔助模型的 token 重編碼遵循類似的過程。然而,不匹配的 token 必須從輔助模型的鍵值 (KV) 快取中丟棄,以確保資料完整性。
基準測試
下表顯示了目標模型與使用不同分詞器的輔助模型配對時觀察到的延遲改進。
目標模型 | 輔助模型 | 資料集 | 任務 | 加速比 |
---|---|---|---|---|
codellama/CodeLlama-13b-Instruct-hf |
bigcode/tiny_starcoder_py |
openai/humaneval |
程式碼生成 | 1.90倍 |
mistralai/Mixtral-8x22B-Instruct-v0.1 |
double7/vicuna-68m |
cnn_dailymail |
摘要 | 1.52倍 |
google/gemma-2-9b |
double7/vicuna-68m |
cnn_dailymail |
摘要 | 1.76倍 |
mistralai/Mixtral-8x22B-Instruct-v0.1 |
Qwen/Qwen2-0.5B-Instruct |
tau/scrolls |
長上下文摘要 | 1.78倍 |
meta-llama/Llama-3.1-70B |
Qwen/Qwen2-0.5B-Instruct |
tau/scrolls |
長上下文摘要 | 1.78倍 |
microsoft/Phi-3-medium-128k-instruct |
Qwen/Qwen2-0.5B-Instruct |
tau/scrolls |
長上下文摘要 | 1.91倍 |
請注意,上述目標模型沒有適合使用標準輔助生成進行加速的小型變體(小於 10 億引數)。
每個實驗都在 100 個隨機選擇的示例上進行。使用 Llama
和 Mixtral
目標模型的實驗分別使用 2 塊和 4 塊 A100 GPU。所有其他實驗均使用單塊 A6000 GPU 執行。
程式碼
通用輔助生成已整合到 🤗 Transformers 的 4.46.0 版本中。
要使用,將 tokenizer
和 assistant_tokenizer
傳遞給 generate()
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> prompt = "Alice and Bob"
>>> checkpoint = "google/gemma-2-9b"
>>> assistant_checkpoint = "double7/vicuna-68m"
>>> assistant_tokenizer = AutoTokenizer.from_pretrained(assistant_checkpoint)
>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
>>> outputs = model.generate(**inputs, assistant_model=assistant_model, tokenizer=tokenizer, assistant_tokenizer=assistant_tokenizer)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
未來方向
雖然在使用標準輔助生成時,傳遞 `do_sample=True` 會使用推測性取樣演算法(論文中的演算法 1),但 UAG 目前僅支援多項式取樣。在多項式取樣中,如果目標模型沒有采樣到與輔助模型相同的 token,則該 token 會自動被拒絕,而推測性取樣則不然。實際上,這意味著 UAG 在 `do_sample=True` 模式下的吞吐量將低於輔助模型具有相同分詞器的情況。未來,我們計劃為 UAG 新增推測性取樣支援。此外,我們打算將 UAG 整合到 🤗 Transformers 管道中,以實現更簡潔和流線型的使用方式。