Transformers 文件
快取
並獲得增強的文件體驗
開始使用
快取
想象一下你正在和某人聊天,他們沒有記住之前說過的話,而是每次你回應時都必須從頭開始。這會非常慢且效率低下,對嗎?
你可以將這個類比擴充套件到 transformer 模型。自迴歸模型生成可能很慢,因為它一次預測一個 token。每個新的預測都依賴於所有先前的上下文。
為了預測第 1000 個 token,模型需要來自前 999 個 token 的資訊。這些資訊表示為 token 表示之間的矩陣乘法。
為了預測第 1001 個 token,除了來自第 1000 個 token 的任何資訊之外,你還需要來自前 999 個 token 的相同資訊。模型必須為每個 token 反覆計算大量的矩陣乘法!
鍵值 (KV) 快取透過儲存從先前處理的 token 的注意力層派生出的 kv 對來消除這種低效率。儲存的 kv 對從快取中檢索並重新用於後續 token,避免了重新計算的需要。
快取只應用於推理。如果在訓練期間啟用它,可能會導致意外錯誤。
為了更好地理解快取的工作原理和原因,讓我們仔細看看注意力矩陣的結構。
注意力矩陣
批處理大小為 b
,注意力頭數為 h
,到目前為止的序列長度為 T
,每個注意力頭的維度為 d_head
的縮放點積注意力計算如下:
查詢 (Q
)、鍵 (K
) 和值 (V
) 矩陣是輸入嵌入的投影,形狀為 (b, h, T, d_head)
。
對於因果注意力,掩碼阻止模型關注未來的 token。一旦 token 被處理,它的表示相對於未來的 token 就不會改變,這意味著和可以快取並重新用於計算最後一個 token 的表示。
在推理時,你只需要最後一個 token 的查詢來計算表示,它預測下一個 token。在每個步驟中,新的鍵和值向量都被儲存在快取中,並附加到過去的鍵和值中。
注意力在模型的每一層獨立計算,並且快取是逐層進行的。
請參考下表比較快取如何提高效率。
不使用快取 | 使用快取 |
---|---|
對於每個步驟,重新計算所有先前的 K 和 V | 對於每個步驟,僅計算當前的 K 和 V |
每個步驟的注意力成本與序列長度呈二次關係 | 每個步驟的注意力成本與序列長度呈線性關係(記憶體線性增長,但計算/token 保持較低) |
快取類
一個基本的 KV 快取介面接收當前 token 的鍵張量和值張量,並返回更新後的 K
和 V
張量。這由模型的 forward
方法內部管理。
new_K, new_V = cache.update(k_t, v_t, layer_idx) attn_output = attn_layer_idx_fn(q_t, new_K, new_V)
當你使用 Transformers 的 Cache 類時,自注意力模組執行幾個關鍵步驟來整合過去和現在的資訊。
注意力模組將當前 kv 對與快取中儲存的過去 kv 對連線起來。這會建立形狀為
(new_tokens_length, past_kv_length + new_tokens_length)
的注意力權重。當前和過去的 kv 對本質上是組合起來計算注意力分數,確保模型瞭解以前的上下文和當前輸入。當
forward
方法迭代呼叫時,注意力掩碼的形狀與過去和當前 kv 對的組合長度匹配至關重要。注意力掩碼的形狀應為(batch_size, past_kv_length + new_tokens_length)
。這通常在 generate() 中內部處理,但如果你想使用 Cache 實現自己的生成迴圈,請記住這一點!注意力掩碼應包含過去和當前 token 值。同樣重要的是要注意
cache_position
。如果你想用forward
方法重用預填充的 Cache,這很重要,因為你必須傳遞一個有效的cache_position
值。這表示序列中的輸入位置。cache_position
不受填充影響,並且它總是為每個 token 增加一個位置。例如,如果 kv 快取包含 10 個 token - 無論填充 token 如何 - 下一個 token 的快取位置應該是torch.tensor([10])
。
快取儲存實現
鍵值對的實際儲存在不同的快取實現之間有所不同。例如,考慮 DynamicCache。
在 DynamicCache 中,鍵值對作為兩個張量列表儲存。列表中的每個張量都具有形狀 [batch_size, num_heads, seq_len, head_dim]
。
key_cache
:一個張量列表,每層一個。value_cache
:一個張量列表,每層一個。
當處理新 token 時
- 對於每一層,新的鍵和值狀態與現有快取連線。
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
隨著更多 token 的處理,快取會動態增長。序列長度維度(
seq_len
)隨著每個新 token 的增加而增加。快取透過
self._seen_tokens
維護已看到的 token 計數。當第一層處理新 token 時,此計數會更新。
以下示例演示瞭如何使用 DynamicCache 建立生成迴圈。如前所述,注意力掩碼是過去和當前 token 值的連線,並且為下一個 token 的快取位置新增 1
。
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda:0")
tokenizer = AutoTokenizer.from_pretrained(model_id)
past_key_values = DynamicCache()
messages = [{"role": "user", "content": "Hello, what's your name."}]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to("cuda:0")
generated_ids = inputs.input_ids
cache_position = torch.arange(inputs.input_ids.shape[1], dtype=torch.int64, device="cuda:0")
max_new_tokens = 10
for _ in range(max_new_tokens):
outputs = model(**inputs, cache_position=cache_position, past_key_values=past_key_values, use_cache=True)
# Greedily sample one next token
next_token_ids = outputs.logits[:, -1:].argmax(-1)
generated_ids = torch.cat([generated_ids, next_token_ids], dim=-1)
# Prepare inputs for the next generation step by leaving unprocessed tokens, in our case we have only one new token
# and expanding attn mask for the new token, as explained above
attention_mask = inputs["attention_mask"]
attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
inputs = {"input_ids": next_token_ids, "attention_mask": attention_mask}
cache_position = cache_position[-1:] + 1 # add one more position for the next token
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])
"[INST] Hello, what's your name. [/INST] Hello! My name is LLaMA,"
傳統快取格式
在 Cache 類之前,快取通常以張量元組的元組形式儲存。這種格式是動態的,因為它會隨著文字的生成而增長,類似於 DynamicCache。
傳統格式本質上是相同的資料結構,但組織方式不同。
- 它是一個元組的元組,其中每個內部元組包含一層的鍵和值張量。
- 張量具有相同的形狀
[batch_size, num_heads, seq_len, head_dim]
。 - 這種格式靈活性較低,不支援量化或解除安裝等功能。
如果你的專案依賴於此傳統格式,你可以使用 from_legacy_cache() 和 DynamicCache.to_legacy_cache() 函式在 DynamicCache 和元組的元組之間進行轉換。如果你有用於以特定格式操作快取的自定義邏輯,這將很有幫助。
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto")
inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
# `return_dict_in_generate=True` is required to return the cache and `return_legacy_cache` forces the returned cache
# in the legacy format
generation_outputs = model.generate(**inputs, return_dict_in_generate=True, return_legacy_cache=True, max_new_tokens=5)
cache = DynamicCache.from_legacy_cache(generation_outputs.past_key_values)
legacy_format_cache = cache.to_legacy_cache()