Transformers 文件

快取

Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

快取

想象一下你正在和某人聊天,他們沒有記住之前說過的話,而是每次你回應時都必須從頭開始。這會非常慢且效率低下,對嗎?

你可以將這個類比擴充套件到 transformer 模型。自迴歸模型生成可能很慢,因為它一次預測一個 token。每個新的預測都依賴於所有先前的上下文。

為了預測第 1000 個 token,模型需要來自前 999 個 token 的資訊。這些資訊表示為 token 表示之間的矩陣乘法。

為了預測第 1001 個 token,除了來自第 1000 個 token 的任何資訊之外,你還需要來自前 999 個 token 的相同資訊。模型必須為每個 token 反覆計算大量的矩陣乘法!

鍵值 (KV) 快取透過儲存從先前處理的 token 的注意力層派生出的 kv 對來消除這種低效率。儲存的 kv 對從快取中檢索並重新用於後續 token,避免了重新計算的需要。

快取只應用於推理。如果在訓練期間啟用它,可能會導致意外錯誤。

為了更好地理解快取的工作原理和原因,讓我們仔細看看注意力矩陣的結構。

注意力矩陣

批處理大小為 b,注意力頭數為 h,到目前為止的序列長度為 T,每個注意力頭的維度為 d_head縮放點積注意力計算如下:Attention(Q,K,V)=softmax(QKdhead×mask)V \text{Attention}(Q, K, V) = \text{softmax}\left( \frac{Q K^\top}{\sqrt{d_{\text{head}}}} \times \text{mask} \right) V

查詢 (Q)、鍵 (K) 和值 (V) 矩陣是輸入嵌入的投影,形狀為 (b, h, T, d_head)

對於因果注意力,掩碼阻止模型關注未來的 token。一旦 token 被處理,它的表示相對於未來的 token 就不會改變,這意味著Kpast K_{\text{past}} Vpast V_{\text{past}} 可以快取並重新用於計算最後一個 token 的表示。Attention(qt,[k1,k2,,kt1cached,kt],[v1,v2,,vt1cached,vt]) \text{Attention}(q_t, [\underbrace{k_1, k_2, \dots, k_{t-1}}_{\text{cached}}, k_{t}], [\underbrace{v_1, v_2, \dots, v_{t-1}}_{\text{cached}}, v_{t}])

在推理時,你只需要最後一個 token 的查詢來計算表示xt x_t ,它預測下一個 tokent+1 t+1 。在每個步驟中,新的鍵和值向量都被儲存在快取中,並附加到過去的鍵和值中。Kcacheconcat(Kpast,kt),Vcacheconcat(Vpast,vt) K_{\text{cache}} \leftarrow \text{concat}(K_{\text{past}}, k_t), \quad V_{\text{cache}} \leftarrow \text{concat}(V_{\text{past}}, v_t)

注意力在模型的每一層獨立計算,並且快取是逐層進行的。

請參考下表比較快取如何提高效率。

不使用快取 使用快取
對於每個步驟,重新計算所有先前的 KV 對於每個步驟,僅計算當前的 KV
每個步驟的注意力成本與序列長度呈二次關係 每個步驟的注意力成本與序列長度呈線性關係(記憶體線性增長,但計算/token 保持較低)

快取類

一個基本的 KV 快取介面接收當前 token 的鍵張量和值張量,並返回更新後的 KV 張量。這由模型的 forward 方法內部管理。

new_K, new_V = cache.update(k_t, v_t, layer_idx)
attn_output = attn_layer_idx_fn(q_t, new_K, new_V)

當你使用 Transformers 的 Cache 類時,自注意力模組執行幾個關鍵步驟來整合過去和現在的資訊。

  1. 注意力模組將當前 kv 對與快取中儲存的過去 kv 對連線起來。這會建立形狀為 (new_tokens_length, past_kv_length + new_tokens_length) 的注意力權重。當前和過去的 kv 對本質上是組合起來計算注意力分數,確保模型瞭解以前的上下文和當前輸入。

  2. forward 方法迭代呼叫時,注意力掩碼的形狀與過去和當前 kv 對的組合長度匹配至關重要。注意力掩碼的形狀應為 (batch_size, past_kv_length + new_tokens_length)。這通常在 generate() 中內部處理,但如果你想使用 Cache 實現自己的生成迴圈,請記住這一點!注意力掩碼應包含過去和當前 token 值。

  3. 同樣重要的是要注意 cache_position。如果你想用 forward 方法重用預填充的 Cache,這很重要,因為你必須傳遞一個有效的 cache_position 值。這表示序列中的輸入位置。cache_position 不受填充影響,並且它總是為每個 token 增加一個位置。例如,如果 kv 快取包含 10 個 token - 無論填充 token 如何 - 下一個 token 的快取位置應該是 torch.tensor([10])

快取儲存實現

鍵值對的實際儲存在不同的快取實現之間有所不同。例如,考慮 DynamicCache

DynamicCache 中,鍵值對作為兩個張量列表儲存。列表中的每個張量都具有形狀 [batch_size, num_heads, seq_len, head_dim]

  • key_cache:一個張量列表,每層一個。
  • value_cache:一個張量列表,每層一個。

當處理新 token 時

  1. 對於每一層,新的鍵和值狀態與現有快取連線。
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
  1. 隨著更多 token 的處理,快取會動態增長。序列長度維度(seq_len)隨著每個新 token 的增加而增加。

  2. 快取透過 self._seen_tokens 維護已看到的 token 計數。當第一層處理新 token 時,此計數會更新。

以下示例演示瞭如何使用 DynamicCache 建立生成迴圈。如前所述,注意力掩碼是過去和當前 token 值的連線,並且為下一個 token 的快取位置新增 1

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache

model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda:0")
tokenizer = AutoTokenizer.from_pretrained(model_id)

past_key_values = DynamicCache()
messages = [{"role": "user", "content": "Hello, what's your name."}]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to("cuda:0")

generated_ids = inputs.input_ids
cache_position = torch.arange(inputs.input_ids.shape[1], dtype=torch.int64, device="cuda:0")
max_new_tokens = 10

for _ in range(max_new_tokens):
    outputs = model(**inputs, cache_position=cache_position, past_key_values=past_key_values, use_cache=True)
    # Greedily sample one next token
    next_token_ids = outputs.logits[:, -1:].argmax(-1)
    generated_ids = torch.cat([generated_ids, next_token_ids], dim=-1)
    # Prepare inputs for the next generation step by leaving unprocessed tokens, in our case we have only one new token
    # and expanding attn mask for the new token, as explained above
    attention_mask = inputs["attention_mask"]
    attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
    inputs = {"input_ids": next_token_ids, "attention_mask": attention_mask}
    cache_position = cache_position[-1:] + 1 # add one more position for the next token

print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])
"[INST] Hello, what's your name. [/INST]  Hello! My name is LLaMA,"

傳統快取格式

Cache 類之前,快取通常以張量元組的元組形式儲存。這種格式是動態的,因為它會隨著文字的生成而增長,類似於 DynamicCache

傳統格式本質上是相同的資料結構,但組織方式不同。

  • 它是一個元組的元組,其中每個內部元組包含一層的鍵和值張量。
  • 張量具有相同的形狀 [batch_size, num_heads, seq_len, head_dim]
  • 這種格式靈活性較低,不支援量化或解除安裝等功能。

如果你的專案依賴於此傳統格式,你可以使用 from_legacy_cache()DynamicCache.to_legacy_cache() 函式在 DynamicCache 和元組的元組之間進行轉換。如果你有用於以特定格式操作快取的自定義邏輯,這將很有幫助。

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto")
inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)

# `return_dict_in_generate=True` is required to return the cache and `return_legacy_cache` forces the returned cache
# in the legacy format
generation_outputs = model.generate(**inputs, return_dict_in_generate=True, return_legacy_cache=True, max_new_tokens=5)

cache = DynamicCache.from_legacy_cache(generation_outputs.past_key_values)
legacy_format_cache = cache.to_legacy_cache()
< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.