理解 Gemma 3n:MatFormer 如何讓你在一個模型中擁有多個模型
當我們談論部署大型語言模型時,話題幾乎總是落在一個熟悉的權衡上:你可以擁有一個更大、更智慧的模型,或者一個更小、更快、能適應你硬體的模型。這似乎是常識,對吧?你在效能與資源的曲線上選擇一個點,然後就固定下來。
但如果你不必如此呢?如果你可以訓練一個大模型,並免費獲得一整套小而高效能的模型呢?
這就是谷歌 Gemma 3n 背後的核心思想,它建立在一個名為 Matryoshka Transformer (套娃 Transformer),或稱 MatFormer 的迷人架構之上。這是一項巧妙的工程設計,改變了我們對模型效率的看法。
讓我們像上次在 https://huggingface.co/blog/rishiraj/kld-guided-quantization 中一樣,一起來分解這個問題。我們將從核心架構思想開始,逐步瞭解它如何在推理時為我們提供如此大的靈活性。
套娃原則:一個模型,多種尺寸
你知道下圖中的俄羅斯套娃嗎?開啟一個,會發現裡面有一個更小的、一模一樣的娃娃,再開啟那個,裡面還有一個。這正是 MatFormer 的完美心智模型。
在標準的 Transformer 模組中,前饋網路 (FFN) 有一個固定的中間層大小。例如,它可能接收一個 4096 維的輸入,將其擴充套件到一個 16384 維的中間層 (W_in
),然後再將其投影回 4096 維 (W_out
)。這些維度是固定的。
MatFormer 改變了這一點。在每個 Transformer 層內部,它不只有一個 FFN,而是一系列 *巢狀* 的 FFN。這不僅僅是概念上的巢狀,而是字面上的。較小 FFN 的權重矩陣是較大 FFN 的子矩陣。
讓我們具體點。如果最大的 FFN (我們稱之為尺寸 S
) 的權重矩陣是 W_in
(4096x16384) 和 W_out
(16384x4096),那麼下一個較小的 FFN (S/2
) 將只使用這些矩陣的左上部分——比如說,W_in
的前 8192 列和 W_out
的前 8192 行。S/4
FFN 將使用前 4096 列/行,依此類推。它們物理上嵌入在同一個引數塊內。
那麼,如何訓練這樣的東西而不會讓較小的網路落後呢?
訣竅在於訓練過程,這是一種隨機深度或隨機路徑訓練的形式。在每個訓練步驟中,對於每一層,模型會隨機選擇一個“容量因子”——S
、S/2
、S/4
等。該層的輸入隨後只通過那個特定的子網路進行前向傳播。有一次,一個輸入可能在第 1 層透過 S/2
FFN,在第 2 層透過 S/8
FFN。下一次,它可能在兩層都使用完整的 S
FFN。
透過讓每個子模組都有平等的機會看到資料、計算梯度和更新其權重,訓練確保了 *所有* 子模組都變得有能力。較小的網路不僅僅是弱的近似;它們是經過明確和穩健訓練的。結果是,你不僅僅是在訓練一個大模型。你同時在訓練指數級數量的、更小的、有效的、並且都巢狀在同一組權重內的子模型。
回報:推理時“選擇你的戰士”
現在看看下面的架構,因為這是架構優雅在實踐中得到回報的地方。因為每個子模型都是一個經過充分訓練、可行的網路,所以在執行模型時你會獲得令人難以置信的靈活性。
1. 簡單的縮小: 假設你訓練了一個大模型,但需要將其部署在只有四分之一記憶體的裝置上。使用 MatFormer,你可以簡單地決定在 *每一層* 都使用 S/4
大小的 FFN 子模組。你會立即得到一個大小約為原始模型 1/4 的模型。至關重要的是,由於這個配置是經過明確訓練的,它的效能明顯優於在那個較小尺寸下從頭開始訓練的獨立模型。它得益於與更大、更強大的路徑共同訓練所帶來的“知識轉移”。
2. “混搭”傑作: 這才是真正有趣的地方。在 Transformer 中,並非所有層對每個任務的貢獻都相同。早期層可能處理語法和區域性模式,而深層則管理更抽象的語義推理。
使用 MatFormer,你可以在不同層之間“混搭”子模組,以建立定製的架構。你可以對模型進行效能分析,找到對你的任務最關鍵的層,併為它們分配更大的 FFN (如 S
或 S/2
),同時透過使用較小的 FFN (如 S/8
) 來節省不那麼關鍵的層的容量。
例如,如果你確定第 5 層對於處理翻譯任務中的語法細微差別至關重要,你可以為其分配完整的 S
FFN。但如果第 20 層的影響較小,你可以將其縮小到 S/8
,從而在對該特定任務的效能損失最小的情況下,節省大量的計算和記憶體。這使你能夠構建一個定製的模型,以最佳方式平衡效能和資源使用。
記憶體魔法:50 億引數如何裝入 20 億引數的記憶體佔用空間
所以,我們有了 MatFormer 這種靈活的計算結構。但 Gemma 3n 還有另一個錦囊妙計,而且完全關乎記憶體。你可能已經看到 Gemma 3n 2B 模型 (E2B) 實際上有大約 50 億個真實引數,但它佔用的 GPU 記憶體卻與典型的 2B 模型相當。這怎麼可能?
答案是 逐層嵌入 (Per-Layer Embeddings, PLE)。
在標準的語言模型中,詞元嵌入表是一個單一、龐大的記憶體塊。它是一個大小為 詞彙表大小 x 隱藏層維度
的巨型查詢表,必須駐留在你的 GPU 視訊記憶體 (VRAM) 中。讓我們用數字來說明。對於一個擁有 256,000 個詞元的詞彙表和 2048 維隱藏層的模型,使用 bfloat16 (每個引數 2 位元組),僅嵌入表就需要 256,000 * 2048 * 2 位元組 ≈ 1.05 GB
。在你處理單個詞元之前,這是一個巨大的、靜態的成本。
PLE 巧妙地避開了這個問題,它將嵌入權重從高速但稀缺的 GPU VRAM 解除安裝到容量大得多但速度較慢的 CPU RAM。當模型需要處理一個輸入序列時,它不會載入整個表。相反,它只通過 PCIe 匯流排將該序列中詞元特定的嵌入向量從 CPU 拉到 GPU。
這是一個經典的工程權衡。你接受了從 CPU 到 GPU 資料傳輸帶來的微小延遲,但作為回報,你釋放了大量的 VRAM。這使得一個擁有更大 *真實* 引數數量的模型能夠在受限的記憶體預算內執行。
這正是 Gemma 3n 家族的構建方式。4B 模型 (E4B,實際上是 5.44B 引數) 是完整的模型。2B 模型 (E2B) 是其內部的一個子網路,透過結合兩樣東西建立:
- MatFormer: 選擇較小的 FFN 子模組以減少計算量和活動引數數量。
- 逐層嵌入 (Per-Layer Embeddings): 使用記憶體解除安裝來管理完整的 5B 引數集的記憶體佔用。
最後一塊拼圖:利用 KV 快取共享加速長上下文
對於涉及長序列的任務,如總結文件或處理長音訊剪輯,鍵值 (KV) 快取通常是主要瓶頸。在自迴歸生成中,模型會儲存所有先前詞元的計算出的鍵 (Key) 和值 (Value),這樣就不必為每個新詞元重新計算它們。
這個快取的大小與序列長度成線性增長,並且可能變得非常巨大:序列長度 * 層數 * 注意力頭數 * 注意力頭維度 * 2
。對於非常長的上下文,這個快取很容易超過可用的 VRAM。
Gemma 3n 使用 KV 快取共享 來緩解這個問題,尤其是在多模態輸入中。這項技術允許模型的不同部分或不同模態 (例如,音訊和文字) 重用或共享此快取的部分。透過避免冗餘儲存,它顯著減少了記憶體壓力並加速了“預填充”階段——即對整個輸入提示的初始、昂貴的處理。不過,我目前對這部分的技術理解還不夠深入,希望以後能瞭解更多。
融會貫通
Gemma 3n 不僅僅是模型排行榜上的又一個點。它展示了智慧、高效的架構設計。透過結合:
- MatFormer: 實現靈活、巢狀的計算結構,在一個模型中為你提供指數級數量的模型。
- 逐層嵌入 (Per-Layer Embeddings): 實現巧妙的記憶體管理,讓更大的模型能適應更小的空間。
- KV 快取共享: 加速長上下文、多模態任務。
...你得到了一個天生具有適應性的系統。它讓我們擺脫了僵化的“一刀切”方法,賦予開發者權力,可以為他們的特定應用、硬體,甚至是特定輸入選擇正確的權衡。這是一個強有力的提醒:最激動人心的創新不總是關於規模的擴大,也關乎 *更聰明* 地擴充套件。