一次失敗的實驗:Infini-Attention,以及我們為何應繼續嘗試?
TLDR:Infini-attention 的效能隨著記憶體壓縮次數的增加而變差,據我們所知,環注意力(ring attention)、YaRN 和 rope scaling 仍然是擴充套件預訓練模型到更長上下文長度的最佳方法。
第 0 節:引言
語言模型的上下文長度是其核心屬性之一,與模型效能並列。自上下文學習興起以來,向模型輸入新增相關資訊變得越來越重要。因此,上下文長度迅速從段落(BERT/GPT-1 的 512 個 token)增加到頁面(GPT-2 和 GPT-3 分別為 1024/2048 個 token),再到書籍(Claude 的 128k),直至書籍集合(Gemini 的 1-10M 個 token)。然而,將標準注意力擴充套件到如此長的上下文長度仍然具有挑戰性。
環注意力(Ring Attention)簡介:環注意力最初由加州大學伯克利分校的研究人員於 2024 年提出(據我們所知)[連結]。這項工程技術透過分塊執行自注意力和前饋網路計算,並將序列維度分佈到多個裝置上,從而實現併發計算和通訊,有助於克服記憶體限制。
即使使用環注意力,以批次大小為 1、100 萬 token 上下文長度訓練 Llama 3 8B 仍需要 512 個 GPU。正如縮放定律所表明的,模型大小與其下游效能之間存在強相關性,這意味著模型越大越好(當然,兩個模型都應該經過良好訓練)。因此,我們不僅需要 100 萬的上下文長度,還需要在最大的模型(例如 Llama 3 8B 405B)上實現 100 萬的上下文長度。而目前只有少數公司擁有這樣做的資源。
自注意力記憶體複雜性回顧:在標準注意力(非 Flash Attention)中,每個 token 都會關注序列中的所有其他 token,導致注意力矩陣的大小為 [seq_len, seq_len]。對於每對 token,我們都會計算一個注意力分數,並且隨著序列長度 (seq_len) 的增加,記憶體和計算需求呈二次方增長:注意力矩陣的記憶體複雜度為 O(seq_len^2)。例如,序列長度增加 10 倍會導致記憶體需求增加 100 倍。即使是記憶體效率高的注意力方法,如 Flash Attention,其記憶體需求也隨上下文長度線性增加,並受限於單個 GPU 記憶體,導致當今 GPU 上的典型最大上下文長度遠低於 1M token。
受此啟發,我們探索了標準注意力的另一種方法:Infini-attention。該論文由 Google 的研究人員於 2024 年 4 月釋出[連結]。Infini-attention 不計算每個詞之間的注意力分數,而是將序列分成更小的固定大小的“片段”,將較早的片段壓縮到固定緩衝區中,並允許下一個片段從較早的片段中檢索記憶體,同時將注意力分數限制在當前片段的詞中。一個關鍵的優勢是其固定的緩衝區大小限制了總記憶體使用量。它還在一個片段中使用相同的查詢來訪問其自身片段和壓縮記憶體中的資訊,這使得我們能夠廉價地擴充套件預訓練模型的上下文長度。理論上,我們可以實現無限上下文長度,因為它只為所有較早片段的記憶體保留一個緩衝區。然而,實際上,壓縮限制了可以有效儲存的資訊量,因此問題是:這種壓縮後的記憶體有多大的可用性?
雖然在紙面上理解一種新方法相對容易,但實際使其工作通常是另一回事,而且這個故事很少公開分享。受此啟發,我們決定分享我們在重現 Infini-attention 論文方面的實驗和歷程,是什麼激勵我們在除錯過程中(我們花了 90% 的時間除錯收斂問題),以及讓這些事情工作起來有多麼困難。
隨著 Llama 3 8B 的釋出(其上下文長度限制為 8k token),我們試圖將其長度擴充套件到 100 萬 token,而無需二次方增加記憶體。在這篇部落格文章中,我們將首先解釋 Infini-attention 的工作原理。然後,我們將概述我們的復現原則,並描述我們最初的小規模實驗。我們將討論我們面臨的挑戰,我們如何解決這些挑戰,並總結我們的發現和我們探索過的其他想法。如果您有興趣測試我們訓練好的檢查點[連結],您可以在以下儲存庫中找到它[連結](請注意,我們目前按原樣提供程式碼)。
第 1 節:復現原則
我們發現在實現新方法時,以下規則很有幫助,並將其作為我們許多工作的指導原則:
- 原則 1: 從能提供良好訊號的最小模型尺寸開始,一旦獲得良好訊號,再擴大實驗規模。
- 原則 2: 始終訓練一個可靠的基線模型來衡量進展。
- 原則 3: 為了確定修改是否能提高效能,訓練兩個模型,除了被測試的修改之外,其他設定均相同。
牢記這些原則,讓我們深入瞭解 Infini-attention 的實際工作原理。理解其機制對我們後續實驗至關重要。
第 2 節:Infini-attention 的工作原理
步驟 1:將輸入序列分割成更小的、固定大小的塊,稱為“片段”。
步驟 2:在每個片段內計算標準的因果點積注意力。
步驟 3:使用當前片段的查詢向量從壓縮記憶體中提取相關資訊。檢索過程的數學定義如下:
- :從記憶體中檢索到的內容,表示長期上下文。
- :查詢矩陣,其中 是查詢數量, 是每個查詢的維度。
- :來自前一個片段的記憶體矩陣,儲存鍵值對。
- :一個非線性啟用函式,具體為逐元素的指數線性單元(ELU)加 1。
- :一個歸一化項。
import torch.nn.functional as F
from torch import einsum
from einops import rearrange
def _retrieve_from_memory(query_states, prev_memory, prev_normalization):
...
sigma_query_states = F.elu(query_states) + 1
retrieved_memory = einsum(
sigma_query_states,
prev_memory,
"batch_size n_heads seq_len d_k, batch_size n_heads d_k d_v -> batch_size n_heads seq_len d_v",
)
denominator = einsum(
sigma_query_states,
prev_normalization,
"batch_size n_heads seq_len d_head, batch_size n_heads d_head -> batch_size n_heads seq_len",
)
denominator = rearrange(
denominator,
"batch_size n_heads seq_len -> batch_size n_heads seq_len 1",
)
# NOTE: because normalization is the sum of all the keys, so each word should have the same normalization
retrieved_memory = retrieved_memory / denominator
return retrieved_memory
步驟 4:將區域性上下文(來自當前片段)與長期上下文(從壓縮記憶體中檢索)結合,生成最終輸出。透過這種方式,注意力輸出可以同時考慮短期和長期上下文。
- :合併後的注意力輸出。
- :一個可學習的標量引數,用於控制長期記憶內容 與區域性上下文之間的權衡。
- :使用點積注意力從當前片段獲得的注意力輸出。
步驟 5:透過添加當前片段的鍵值狀態來更新壓縮記憶體,從而允許我們隨著時間的推移累積上下文。
- :當前片段的更新記憶體矩陣,包含了新資訊。
- :當前片段的鍵矩陣,表示要儲存的新鍵。
- :當前片段的值矩陣,表示與鍵關聯的新值。
- :鍵矩陣中的第 個鍵向量。
- :當前片段的更新歸一化項。
import torch
def _update_memory(prev_memory, prev_normalization, key_states, value_states):
...
sigma_key_states = F.elu(key_states) + 1
if prev_memory is None or prev_normalization is None:
new_value_states = value_states
else:
numerator = einsum(
sigma_key_states,
prev_memory,
"batch_size n_heads seq_len d_k, batch_size n_heads d_k d_v -> batch_size n_heads seq_len d_v",
)
denominator = einsum(
sigma_key_states,
prev_normalization,
"batch_size n_heads seq_len d_k, batch_size n_heads d_k -> batch_size n_heads seq_len",
)
denominator = rearrange(
denominator,
"batch_size n_heads seq_len -> batch_size n_heads seq_len 1",
)
prev_v = numerator / denominator
new_value_states = value_states - prev_v
memory = torch.matmul(sigma_key_states.transpose(-2, -1), new_value_states)
normalization = reduce(
sigma_key_states,
"batch_size n_heads seq_len d_head -> batch_size n_heads d_head",
reduction="sum",
...
)
memory += prev_memory if prev_memory is not None else 0
normalization += prev_normalization if prev_normalization is not None else 0
return memory, normalization
- 步驟 6:當我們從一個片段移動到下一個片段時,我們丟棄前一個片段的注意力狀態,並將更新後的壓縮記憶體傳遞給下一個片段。
def forward(...):
...
outputs = []
global_weights = F.sigmoid(self.balance_factors)
...
local_weights = 1 - global_weights
memory = None
normalization = None
for segment_hidden_state, segment_sequence_mask in zip(segment_hidden_states, segment_sequence_masks):
attn_outputs = self.forward_with_hidden_states(
hidden_states=segment_hidden_state, sequence_mask=segment_sequence_mask, return_qkv_states=True
)
local_attn_outputs = attn_outputs["attention_output"]
query_states, key_states, value_states = attn_outputs["qkv_states_without_pe"]
q_bs = query_states.shape[0]
q_length = query_states.shape[2]
...
retrieved_memory = _retrieve_from_memory(
query_states, prev_memory=memory, prev_normalization=normalization
)
attention_output = global_weights * retrieved_memory + local_weights * local_attn_outputs
...
output = o_proj(attention_output)
memory, normalization = _update_memory(memory, normalization, key_states, value_states)
outputs.append(output)
outputs = torch.cat(outputs, dim=1) # concat along sequence dimension
...
既然我們已經掌握了理論,是時候捲起袖子,進行一些實際的實驗了。讓我們從小處著手,以便快速獲得反饋並快速迭代。
第 3 節:小規模初步實驗
Llama 3 8B 相當大,因此我們決定從一個 200M 的 Llama 模型開始,使用 Nanotron [連結] 和 Fineweb 資料集 [連結] 從頭開始預訓練 Infini-attention。一旦我們獲得了 200M 模型的良好結果,我們便開始對 Llama 3 8B 進行持續預訓練。我們使用 200 萬 token 的批次大小,256 的上下文長度,梯度裁剪為 1,權重衰減為 0.1,前 5,000 次迭代為線性預熱,其餘步驟為餘弦衰減,學習率為 3e-5。
使用通行碼檢索任務進行評估
通行碼檢索任務最初由 EPFL 的研究人員引入[連結]。它旨在評估模型從長上下文(資訊位置可控)中檢索資訊的能力。提示模型的輸入格式結構如下:
有重要資訊隱藏在大量無關文字中。找到並記住它們。我將考你關於那裡的重要資訊。草是綠色的。天空是藍色的。太陽是黃色的。我們開始吧。來來回回。(重複 x 次)通行碼是 9054。記住它。9054 是通行碼。草是綠色的。天空是藍色的。太陽是黃色的。我們開始吧。來來回回。(重複 y 次)通行碼是什麼?通行碼是
如果模型的輸出包含“針”(在上述情況下為“9054”),則我們認為模型在此任務中成功;如果模型輸出不包含,則認為不成功。在我們的實驗中,我們將針放置在上下文中的不同位置,具體是總上下文長度的 0%、5%、10%、...、95% 和 100%(0% 是離生成 token 最遠的位置)。例如,如果上下文長度為 1024 個 token,將針放置在 10% 意味著它位於大約第 102 個 token 的位置。在每個深度位置,我們用 10 個不同的樣本測試模型,並計算平均成功率。
初步結果
以下是一些 200M 小模型的初步結果:
如你所見,它在一定程度上起作用。如果你檢視樣本生成,你會發現 Infini-attention 生成的內容與早期片段相關。
由於 Infini-attention 透過以第一個片段的全部內容為條件來預測第二個片段的第一個 token(它將第一個 token 生成為“_grad”),這提供了一個良好的訊號。為了驗證該訊號是否是假陽性,我們假設 Infini-attention 生成與其早期片段相關的內容,因為當給定“_grad”作為第二個片段的第一個生成 token 時,它始終生成與 PyTorch 相關的教程,而這些教程恰好與它之前的片段相關。因此,我們進行了一個健全性測試,其中唯一的輸入 token 是“_grad”,它生成了[文字在此]。這表明它確實使用了記憶體,但使用得不夠好(無法檢索精確的針或繼續其早期片段的精確內容)。生成結果如下:
_graduate_education.html
Graduate Education
The Department of Physics and Astronomy offers a program leading to the Master of Science degree in physics. The program is designed to provide students with a broad background in
根據這些結果,模型似乎確實使用了壓縮記憶體。我們決定透過持續預訓練 Llama 3 8B 來擴大實驗規模。不幸的是,當針被放置在較早的片段中時,模型未能透過針評估。
我們決定檢查所有層中的平衡因子(平衡壓縮記憶體和未壓縮記憶體的因子)。根據圖 3a 和圖 3b,我們發現大約 95% 的權重集中在 0.5 左右。回想一下,權重是否能收斂到理想範圍取決於兩個一般因素:步長和梯度的幅度。然而,Adam 將梯度歸一化為 1 的幅度,所以問題變成了:訓練超引數是否正確,以便微調能夠收斂?
第 4 節:研究收斂性?
我們決定模擬在梯度處於良好範圍(L2 範數為 0.01)時平衡權重在訓練期間會改變多少,發現根據上一次 8B LLaMA3 微調實驗的配置,權重的絕對總變化將為 0.03。由於我們將平衡因子初始化為 0(在這種情況下無關緊要),因此最終權重將在 [0 - 0.03, 0 + 0.03] = [-0.03, 0.03] 範圍內。
對於 Infini-attention 的良好工作,一個合理的猜測是全域性權重如論文中所示在 0 和 1 範圍內分散。鑑於上述權重,sigmoid([-0.03, 0.03]) = tensor([0.4992, 0.5008])(這與我們之前的實驗結果相符,即平衡因子約為 0.5)。我們決定下一步對平衡因子使用更高的學習率(所有其他引數使用 Llama 3 8B 的學習率),並增加訓練步數,以允許平衡因子至少改變 4,這樣我們可以讓全域性權重在梯度下降需要時達到理想權重(sigmoid(-4) ≈ 0, sigmoid(4) ≈ 1)。
我們還注意到,由於梯度並不總是朝同一個方向,因此會出現抵消。這意味著我們應該將學習率和訓練步數設定得顯著大於總絕對變化。回想一下,Llama 3 8B 的學習率為 3.0x10^-4,這意味著如果我們將其用作全域性學習率,門控功能將無法收斂。
結論:我們決定採用 3.0x10^-4 的全域性學習率和 0.01 的門控學習率,這應該能使門控函式收斂。
在這些超引數下,Infini-attention 中的平衡因子是可訓練的,但我們觀察到 200M llama 的損失在 20B token 後變為 NaN(我們嘗試了從 0.001 到 1.0e-6 的學習率)。我們檢查了 20B token 檢查點(10k 訓練步)的一些生成結果,您可以在圖 4a 中看到。模型現在繼續生成精確的內容並召回身份(如果記憶體被清除,它會生成垃圾)。
但它仍然無法從一個片段中召回另一個片段中的“針”(它在片段內可靠地完成)。當“針”放置在第一個片段中時,針評估完全失敗(當放置在第二個片段中時,總共兩個片段,成功率為 100%)。如圖 4b 所示,我們還觀察到平衡因子在 5,000 步後停止變化。雖然我們取得了一些進展,但我們尚未完全擺脫困境。平衡因子仍然沒有按我們希望的方式表現。我們決定深入挖掘並進行更多調整。
第 5 節:平衡因子無權重衰減
再次仔細檢查平衡因子,我們看到了一些進展:大約 95% 的頭部現在顯示全域性權重在 0.4 到 0.5 之間,並且沒有一個頭部的全域性權重大於 0.6。但是權重仍然不在理想範圍內。
我們想到了另一個潛在原因:權重衰減,它會促使平衡因子 L2 範數較小,導致 sigmoid 值收斂到接近零,並且因子集中在 0.5 附近。
另一個潛在原因是我們的“rollout”設定太小。在 200M 實驗中,我們只使用了 4 個“rollout”,而在 8B 實驗中,我們只使用了 2 個“rollout”(8192**2)。使用更大的“rollout”應該會促使模型更好地壓縮和使用記憶體。因此,我們決定將“rollout”數量增加到 16,並且不使用權重衰減。我們將上下文長度縮小到 1024,並使用 16 個“rollout”,從而獲得 64 的片段長度。
如您所見,全域性權重現在分佈在 0 到 1 的範圍內,其中 10% 的頭部全域性權重在 0.9 到 1.0 之間,儘管在 18k 步之後,大多數頭部停止了其全域性權重的變化。我們現在非常有信心,如果梯度下降的“精神”與我們同在,實驗設定將允許收斂。唯一剩下的問題是 Infini-attention 的總體方法是否能夠很好地工作。
以下評估在 1.5B token 下執行。
- 0-短:在提示 2 中,它回憶了一個人學習的地方(昨天的 8b 模型在這方面失敗了),但在針通行碼方面失敗了(尚未全面執行;將執行)。
- 1-短
- 提示 3:它識別了一個人的位置。
- 提示 4:它通過了通行碼檢測。
在這種情況下,模型會繼續生成與早期片段完全相同的內容。(在我們之前的實驗中,模型未能繼續生成早期片段的精確內容,而只是生成了大致相關的內容;因此,新模型已經好很多了。)
第 6 節:結論
不幸的是,儘管取得了這些進展,我們發現在我們的實驗中 Infini-attention 並沒有足夠的說服力,尤其是不夠可靠。在我們的復現階段,我們仍然認為環注意力(Ring Attention)[連結]、YaRN [連結] 和 rope scaling [連結] 是將預訓練模型擴充套件到更長上下文長度的更好選擇。
對於超大型模型(例如 400B 及以上),這些技術仍然需要大量的資源。因此,我們仍然認為探索壓縮技術或繼續推進我們在本部落格文章中描述的系列實驗對社群來說具有巨大的興趣,我們很高興能關注並嘗試可能開發出來的新技術,以克服當前工作的一些限制。
總結
- 訓練神經網路的意義:提供好的資料,設定好的架構和訓練以接收良好的梯度訊號,並使其收斂。
- Infini-attention 的長上下文效能隨著記憶體壓縮次數的增加而下降。
- 門控很重要;調整訓練以使門控收斂可以提高 Infini-attention 的長上下文效能(但還不夠好)。
- 始終訓練一個好的參考模型作為基線來衡量進展。
- 還有一個錯誤會搞亂注意力輸出的維度,導致即使訓練過程中損失不斷下降,模型仍然無法在其片段長度內生成連貫的文字。得到的教訓是:即使你對模型進行了糟糕的條件設定,梯度下降仍然可以找到降低損失的方法。然而,模型不會按預期工作,所以務必進行評估。
致謝
感謝 Leandro von Werra 和 Thomas Wolf 在專案中的指導,以及 Tsendsuren Munkhdalai 分享了原始實驗的更多細節。我們還要感謝 Leandro 對部落格文章的反饋,並感謝 Hugging Face 的科學叢集提供的計算資源。