在 nanoVLM 中從零開始實現 KV Cache

釋出於 2025 年 6 月 4 日
在 GitHub 上更新

TL;DR

我們已經在 nanoVLM 倉庫(一個使用純 PyTorch 訓練自己的視覺語言模型的小型程式碼庫)中從零開始實現了 KV 快取。這使我們的生成速度提升了 38%。在這篇部落格文章中,我們將介紹 KV 快取以及我們在實現它時獲得的所有經驗。所學到的經驗是通用的,可以應用於所有自迴歸語言模型的生成。在一個小型程式碼庫上從零開始實現是一個很好的學習經驗,歡迎加入!

bar plot showcasing improvement in generation speed

引言

自迴歸語言模型透過**一次取樣一個 token** 來生成文字。在推理過程中,模型處理給定的輸入序列,預測下一個 token,將其附加到序列中,並重復此過程直到滿足某個停止條件。

diagram for autoregression

這種逐步生成本質上是順序的。

  • 為了生成 token ti+1 t_{i+1} ,模型必須考慮從 t0 t_0 ti t_i 的整個序列。在上述示例中,ti+1 t_{i+1} 將是 the,而所有之前的 token t0 t_0 ti t_i 將是 [What, is, in]
  • 儘管 Transformer 內部是並行的,但每個新的預測都需要對所有 Transformer 層進行一次完整的正向傳播,這會帶來與序列長度呈二次方的記憶體/計算開銷。

這種重複也會導致計算上的**冗餘**。在這篇文章中,我們將探討**KV 快取**,這是一種緩解這種低效率的最佳化技術。

目錄

重溫 Transformer 架構

在深入探討快取之前,讓我們回顧一下 Transformer 模型中注意力的運作方式。Transformer 語言模型由堆疊層組成,每層包含:

  • 多頭自注意力
  • 前饋網路 (MLP)
  • 殘差連線和層歸一化

為了理解**KV 快取的幫助之處**,我們重點關注**自注意力**機制,特別是單個注意力頭內部。

讓我們透過一個簡單的 PyTorch 實現來視覺化關鍵計算。

import torch

input_seq_length = 5
dim_model = 10

input_ids_emb = torch.randn(input_seq_length, dim_model)
W_q = torch.randn(dim_model, dim_model)
W_k = torch.randn(dim_model, dim_model)
W_v = torch.randn(dim_model, dim_model)

Q = input_ids_emb @ W_q
K = input_ids_emb @ W_k
V = input_ids_emb @ W_v

自注意力計算

對於 T T 個輸入嵌入序列,表示為 XRT×D X \in \mathbb{R}^{T \times D} ,自注意力計算如下:

  • Q=XWQ Q = XW_Q ,其中 WQRD×Dq W_Q \in \mathbb{R}^{D \times D_q}
  • K=XWK K = XW_K ,其中 WKRD×Dk W_K \in \mathbb{R}^{D \times D_k}
  • V=XWV V = XW_V ,其中 WVRD×Dv W_V \in \mathbb{R}^{D \times D_v}
  • 因果掩碼 M M 用於防止訪問未來 token。

最終輸出為

Attention(X;Q,K,V)=softmax(QKMdk)V \text{Attention}(X; Q, K, V) = \text{softmax}\left( \frac{QK^\top \cdot M}{\sqrt{d_k}} \right)V

這是一個使用因果掩碼的最小 PyTorch 等效實現。

import torch.nn.functional as F
import math

d_k = K.shape[-1]
attention_scores = (Q @ K.T) / math.sqrt(d_k)

# Lower triangular mask to prevent future token access
causal_mask = torch.tril(torch.ones(input_seq_length, input_seq_length))
masked_scores = attention_scores.masked_fill(causal_mask == 0, float('-inf'))

attention_weights = F.softmax(masked_scores, dim=-1)
output = attention_weights @ V

冗餘之處

在自迴歸生成中,模型每次生成一個 token。在每一步中,它都會**為整個序列**重新計算 Q Q K K V V ,即使較早的 token 並未改變。

new_token_emb = torch.randn(1, dim_model)
extended_input = torch.cat([input_ids_emb, new_token_emb], dim=0)

Q_ext = extended_input @ W_q
K_ext = extended_input @ W_k
V_ext = extended_input @ W_v

# (output_ext would be computed using Q_ext, K_ext, V_ext + masking)

為了確認冗餘

torch.testing.assert_close(K, K_ext[:input_seq_length]) # test pass
torch.testing.assert_close(V, V_ext[:input_seq_length]) # test pass

這些檢查表明,對於除最新 token 之外的所有 token,K K V V 與先前計算的值相同。

Original (5×5):         Extended (6×6):
■ ■ ■ ■ ■              ■ ■ ■ ■ ■ □
■ ■ ■ ■ ■              ■ ■ ■ ■ ■ □
■ ■ ■ ■ ■    →         ■ ■ ■ ■ ■ □
■ ■ ■ ■ ■              ■ ■ ■ ■ ■ □
■ ■ ■ ■ ■              ■ ■ ■ ■ ■ □
                       □ □ □ □ □ □
  • **■** = 已計算並重復使用
  • **□** = 不必要地重新計算

大部分注意力計算被不必要地重複。隨著序列的增長,這會變得更加昂貴。

KV 快取如何解決它

為了消除這種低效率,我們使用 **KV 快取**:

  • 在處理完初始提示後,我們**快取**每個層計算出的鍵 (K K ) 和值 (V V )。
  • 在生成過程中,我們**只計算新 token 的** K K **和** V V ,並**將其附加**到快取中。
  • 我們計算當前 token 的 Q Q ,並將其與**快取的 K K V V ** 一起使用以獲得輸出。

這使得生成從全序列重新計算變為輕量級的增量更新。

✅ 在實踐中,此快取是一個逐層字典,包含“key”和“value”,每個形狀為 (batch_size, num_heads, seq_len_cached, head_dim)。

這是現代 LLM 如何高效生成長輸出的基礎。

nanoVLM 中的 KV 快取:從理論到實踐

既然我們已經理解了 KV 快取背後的理論,接下來讓我們看看它在我們的 nanoVLM 倉庫中是如何實際實現的。這是一個理想的測試平臺,因為它是一個超級簡潔且自包含的程式碼庫。

KV 快取體現在我們模型的三個關鍵元件中:

  1. 使用和更新 KV 快取的**注意力塊**
  2. 跟蹤每層快取的**語言模型**
  3. 區分**預填充**(使用輸入提示的初始傳遞)和順序**解碼**階段的**生成迴圈**

1. 在注意力塊中更新 KV 快取

在 `LanguageModelGroupedAttention` 類中,我們修改了 `forward` 函式,使其接受並更新鍵和值(`block_kv_cache`)的快取。

以前,模型在每個生成步驟都會重新計算 K K V V 。現在我們只計算當前 token 的 Knew K_{\text{new}} Vnew V_{\text{new}} ,並將其附加到快取的值中。

def forward(self, x, cos, sin, attention_mask=None, block_kv_cache=None):
    is_prefill = block_kv_cache is None
    B, T_curr, C = x.size()

    # Project inputs to Q, K, V
    q_curr, k_curr, v_curr = project_current_tokens(x)
    q, k_rotated = apply_rotary_pos_embd(q_curr, k_curr, cos, sin)

    if not is_prefill and block_kv_cache['key'] is not None:
        # Append new keys and values to the cache
        k = torch.cat([block_kv_cache['key'], k_rotated], dim=2)
        v = torch.cat([block_kv_cache['value'], v_curr], dim=2)
    else:
        # First pass (prefill) — no cache
        k, v = k_rotated, v_curr

    block_kv_cache = {'key': k, 'value': v}
    return attention_output, block_kv_cache

2. 跨層跟蹤快取

在 `LanguageModel` 類中,我們引入了**逐層快取跟蹤**。`start_pos` 引數有助於模型為新生成的 token 計算正確的**旋轉位置編碼**。

def forward(self, x, kv_cache=None, start_pos=0):
    T_curr = x.size(1)
    position_ids = torch.arange(start_pos, start_pos + T_curr, device=x.device)
    cos, sin = self.rotary_embd(position_ids)

    for i, block in enumerate(self.blocks):
        # Pass per-layer KV cache
        x, kv_cache[i] = block(x, cos, sin, attention_mask, kv_cache[i])

    return x, kv_cache
  • `kv_cache`:一個字典列表,每個 transformer 層一個,儲存著先前的鍵和值。
  • `start_pos`:確保旋轉嵌入與當前生成索引對齊。

3. 生成迴圈中的預填充與解碼

`VisionLanguageModel` 的 `generate()` 方法發生了最大的架構變化。

我們將**生成分為兩個階段**:

  • **預填充階段:**編碼完整提示並構建初始快取。
  • **解碼階段:**使用快取的鍵/值一次生成一個 token。
PREFILL PHASE (cache construction)
[Prompt: "What is"] → [Transformer] → [Cache: K, V for all layers]

DECODE PHASE (token-by-token)
[Token: "the"] → [Q("the") + cached K/V] → [next token: "?"] → ...

相應的程式碼如下:

# PREFILL: Process the input prompt, fill the cache
prompt_output, kv_cache_list = self.forward(
    inputs,
    kv_cache=None,
    start_pos=0
)

# DECODE: Generate one token at a time using cached K/V
for i in range(max_new_tokens):
    next_token = sample_from(prompt_output)

    decode_output, kv_cache_list = self.forward(
        next_token,
        kv_cache=kv_cache_list,
        start_pos=current_position  # updated with each step
    )

    prompt_output = decode_output

透過分離這些階段,我們避免了冗餘計算,並顯著加快了推理速度,特別是對於長提示。

更改總結

模組 原始行為 新行為
LanguageModelGroupedAttention.forward 每步重新計算 Q Q K K V V 使用並更新 KV 快取
LanguageModel.forward 沒有之前的狀態記憶 跟蹤逐層 KV 快取,處理 `start_pos`
VisionLanguageModel.generate 單階段生成迴圈 分為**預填充**和**解碼**階段

總結:KV 快取的重要性

益處 說明
增量增長 快取每增加一個新 token 就增加一行
位置感知解碼 `start_pos` 確保位置編碼計算的正確性
效率 將每個 token 的推理時間複雜度從二次方降低到 O(`seq len`)

KV 快取消除了自迴歸生成過程中不必要的計算,從而實現了更快、更高效的推理,尤其是在長序列和即時應用中。這是速度與記憶體之間的權衡,其缺點可能是程式碼更復雜,並限制了更高階的推理方案,如束搜尋等。KV 快取是加速 LLM 推理的一種流行方法,使得它們可以在消費級硬體上執行,現在你也知道它是如何工作的了!

社群

感謝這篇精彩的文章!我從 nanoVLM 專案中學到了很多。
我不是生成式 AI 方面的專家,但我注意到注意力計算示例似乎缺少縮放 √(d_k)。這是為了簡化而故意省略的嗎?

d_k = K.shape[-1]
attention_scores = (Q @ K.T) / math.sqrt(d_k)

據我理解,這種縮放可以防止點積變得過大,並控制 softmax 區域。

·
文章作者

這個發現太棒了!

你願意為部落格文章提交一個包含這些更改的 PR 嗎?

這是部落格文章的原始碼:https://github.com/huggingface/blog/blob/main/kv-cache.md

讀得真好,我發現預填充和解碼的解釋非常直觀。幹得漂亮 👏

這是我用 kv 快取時注意力機制內部發生的情況的視覺化表示。
我想與社群分享 🤗

·
文章作者

太酷了!感謝分享。

註冊登入 以評論

© . This site is unofficial and not affiliated with Hugging Face, Inc.