在 nanoVLM 中從零開始實現 KV Cache
TL;DR
我們已經在 nanoVLM 倉庫(一個使用純 PyTorch 訓練自己的視覺語言模型的小型程式碼庫)中從零開始實現了 KV 快取。這使我們的生成速度提升了 38%。在這篇部落格文章中,我們將介紹 KV 快取以及我們在實現它時獲得的所有經驗。所學到的經驗是通用的,可以應用於所有自迴歸語言模型的生成。在一個小型程式碼庫上從零開始實現是一個很好的學習經驗,歡迎加入!
引言
自迴歸語言模型透過**一次取樣一個 token** 來生成文字。在推理過程中,模型處理給定的輸入序列,預測下一個 token,將其附加到序列中,並重復此過程直到滿足某個停止條件。
這種逐步生成本質上是順序的。
- 為了生成 token ,模型必須考慮從 到 的整個序列。在上述示例中, 將是
the,而所有之前的 token 到 將是[What, is, in]。 - 儘管 Transformer 內部是並行的,但每個新的預測都需要對所有 Transformer 層進行一次完整的正向傳播,這會帶來與序列長度呈二次方的記憶體/計算開銷。
這種重複也會導致計算上的**冗餘**。在這篇文章中,我們將探討**KV 快取**,這是一種緩解這種低效率的最佳化技術。
目錄
重溫 Transformer 架構
在深入探討快取之前,讓我們回顧一下 Transformer 模型中注意力的運作方式。Transformer 語言模型由堆疊層組成,每層包含:
- 多頭自注意力
- 前饋網路 (MLP)
- 殘差連線和層歸一化
為了理解**KV 快取的幫助之處**,我們重點關注**自注意力**機制,特別是單個注意力頭內部。
讓我們透過一個簡單的 PyTorch 實現來視覺化關鍵計算。
import torch
input_seq_length = 5
dim_model = 10
input_ids_emb = torch.randn(input_seq_length, dim_model)
W_q = torch.randn(dim_model, dim_model)
W_k = torch.randn(dim_model, dim_model)
W_v = torch.randn(dim_model, dim_model)
Q = input_ids_emb @ W_q
K = input_ids_emb @ W_k
V = input_ids_emb @ W_v
自注意力計算
對於 個輸入嵌入序列,表示為 ,自注意力計算如下:
- ,其中
- ,其中
- ,其中
- 因果掩碼 用於防止訪問未來 token。
最終輸出為
這是一個使用因果掩碼的最小 PyTorch 等效實現。
import torch.nn.functional as F
import math
d_k = K.shape[-1]
attention_scores = (Q @ K.T) / math.sqrt(d_k)
# Lower triangular mask to prevent future token access
causal_mask = torch.tril(torch.ones(input_seq_length, input_seq_length))
masked_scores = attention_scores.masked_fill(causal_mask == 0, float('-inf'))
attention_weights = F.softmax(masked_scores, dim=-1)
output = attention_weights @ V
冗餘之處
在自迴歸生成中,模型每次生成一個 token。在每一步中,它都會**為整個序列**重新計算 、 和 ,即使較早的 token 並未改變。
new_token_emb = torch.randn(1, dim_model)
extended_input = torch.cat([input_ids_emb, new_token_emb], dim=0)
Q_ext = extended_input @ W_q
K_ext = extended_input @ W_k
V_ext = extended_input @ W_v
# (output_ext would be computed using Q_ext, K_ext, V_ext + masking)
為了確認冗餘
torch.testing.assert_close(K, K_ext[:input_seq_length]) # test pass
torch.testing.assert_close(V, V_ext[:input_seq_length]) # test pass
這些檢查表明,對於除最新 token 之外的所有 token, 和 與先前計算的值相同。
Original (5×5): Extended (6×6):
■ ■ ■ ■ ■ ■ ■ ■ ■ ■ □
■ ■ ■ ■ ■ ■ ■ ■ ■ ■ □
■ ■ ■ ■ ■ → ■ ■ ■ ■ ■ □
■ ■ ■ ■ ■ ■ ■ ■ ■ ■ □
■ ■ ■ ■ ■ ■ ■ ■ ■ ■ □
□ □ □ □ □ □
- **■** = 已計算並重復使用
- **□** = 不必要地重新計算
大部分注意力計算被不必要地重複。隨著序列的增長,這會變得更加昂貴。
KV 快取如何解決它
為了消除這種低效率,我們使用 **KV 快取**:
- 在處理完初始提示後,我們**快取**每個層計算出的鍵 () 和值 ()。
- 在生成過程中,我們**只計算新 token 的** **和** ,並**將其附加**到快取中。
- 我們計算當前 token 的 ,並將其與**快取的 和 ** 一起使用以獲得輸出。
這使得生成從全序列重新計算變為輕量級的增量更新。
✅ 在實踐中,此快取是一個逐層字典,包含“key”和“value”,每個形狀為 (
batch_size,num_heads,seq_len_cached,head_dim)。
這是現代 LLM 如何高效生成長輸出的基礎。
nanoVLM 中的 KV 快取:從理論到實踐
既然我們已經理解了 KV 快取背後的理論,接下來讓我們看看它在我們的 nanoVLM 倉庫中是如何實際實現的。這是一個理想的測試平臺,因為它是一個超級簡潔且自包含的程式碼庫。
KV 快取體現在我們模型的三個關鍵元件中:
- 使用和更新 KV 快取的**注意力塊**
- 跟蹤每層快取的**語言模型**
- 區分**預填充**(使用輸入提示的初始傳遞)和順序**解碼**階段的**生成迴圈**
1. 在注意力塊中更新 KV 快取
在 `LanguageModelGroupedAttention` 類中,我們修改了 `forward` 函式,使其接受並更新鍵和值(`block_kv_cache`)的快取。
以前,模型在每個生成步驟都會重新計算 和 。現在我們只計算當前 token 的 和 ,並將其附加到快取的值中。
def forward(self, x, cos, sin, attention_mask=None, block_kv_cache=None):
is_prefill = block_kv_cache is None
B, T_curr, C = x.size()
# Project inputs to Q, K, V
q_curr, k_curr, v_curr = project_current_tokens(x)
q, k_rotated = apply_rotary_pos_embd(q_curr, k_curr, cos, sin)
if not is_prefill and block_kv_cache['key'] is not None:
# Append new keys and values to the cache
k = torch.cat([block_kv_cache['key'], k_rotated], dim=2)
v = torch.cat([block_kv_cache['value'], v_curr], dim=2)
else:
# First pass (prefill) — no cache
k, v = k_rotated, v_curr
block_kv_cache = {'key': k, 'value': v}
return attention_output, block_kv_cache
2. 跨層跟蹤快取
在 `LanguageModel` 類中,我們引入了**逐層快取跟蹤**。`start_pos` 引數有助於模型為新生成的 token 計算正確的**旋轉位置編碼**。
def forward(self, x, kv_cache=None, start_pos=0):
T_curr = x.size(1)
position_ids = torch.arange(start_pos, start_pos + T_curr, device=x.device)
cos, sin = self.rotary_embd(position_ids)
for i, block in enumerate(self.blocks):
# Pass per-layer KV cache
x, kv_cache[i] = block(x, cos, sin, attention_mask, kv_cache[i])
return x, kv_cache
- `kv_cache`:一個字典列表,每個 transformer 層一個,儲存著先前的鍵和值。
- `start_pos`:確保旋轉嵌入與當前生成索引對齊。
3. 生成迴圈中的預填充與解碼
`VisionLanguageModel` 的 `generate()` 方法發生了最大的架構變化。
我們將**生成分為兩個階段**:
- **預填充階段:**編碼完整提示並構建初始快取。
- **解碼階段:**使用快取的鍵/值一次生成一個 token。
PREFILL PHASE (cache construction)
[Prompt: "What is"] → [Transformer] → [Cache: K, V for all layers]
DECODE PHASE (token-by-token)
[Token: "the"] → [Q("the") + cached K/V] → [next token: "?"] → ...
相應的程式碼如下:
# PREFILL: Process the input prompt, fill the cache
prompt_output, kv_cache_list = self.forward(
inputs,
kv_cache=None,
start_pos=0
)
# DECODE: Generate one token at a time using cached K/V
for i in range(max_new_tokens):
next_token = sample_from(prompt_output)
decode_output, kv_cache_list = self.forward(
next_token,
kv_cache=kv_cache_list,
start_pos=current_position # updated with each step
)
prompt_output = decode_output
透過分離這些階段,我們避免了冗餘計算,並顯著加快了推理速度,特別是對於長提示。
更改總結
| 模組 | 原始行為 | 新行為 |
|---|---|---|
LanguageModelGroupedAttention.forward |
每步重新計算 、、 | 使用並更新 KV 快取 |
LanguageModel.forward |
沒有之前的狀態記憶 | 跟蹤逐層 KV 快取,處理 `start_pos` |
VisionLanguageModel.generate |
單階段生成迴圈 | 分為**預填充**和**解碼**階段 |
總結:KV 快取的重要性
| 益處 | 說明 |
|---|---|
| 增量增長 | 快取每增加一個新 token 就增加一行 |
| 位置感知解碼 | `start_pos` 確保位置編碼計算的正確性 |
| 效率 | 將每個 token 的推理時間複雜度從二次方降低到 O(`seq len`) |
KV 快取消除了自迴歸生成過程中不必要的計算,從而實現了更快、更高效的推理,尤其是在長序列和即時應用中。這是速度與記憶體之間的權衡,其缺點可能是程式碼更復雜,並限制了更高階的推理方案,如束搜尋等。KV 快取是加速 LLM 推理的一種流行方法,使得它們可以在消費級硬體上執行,現在你也知道它是如何工作的了!

