Nyströmformer:透過 Nyström 方法線上性時間和記憶體中逼近自注意力機制
引言
Transformer 在各種自然語言處理和計算機視覺任務上都表現出了卓越的效能。其成功可歸功於自注意力機制,該機制能捕捉輸入中所有詞元(token)之間的成對互動。然而,標準的自注意力機制的時間和記憶體複雜度為 (其中 是輸入序列的長度),這使得在長輸入序列上進行訓練的成本很高。
Nyströmformer 是眾多高效 Transformer 模型之一,它以 的複雜度逼近了標準自注意力機制。Nyströmformer 在各種下游 NLP 和 CV 任務上展現出有競爭力的效能,同時提升了標準自注意力機制的效率。這篇博文旨在向讀者概述 Nyström 方法以及如何將其應用於逼近自注意力機制。
用於矩陣逼近的 Nyström 方法
Nyströmformer 的核心是用於矩陣逼近的 Nyström 方法。它允許我們透過取樣矩陣的部分行和列來逼近整個矩陣。讓我們考慮一個矩陣 ,完整計算該矩陣的成本很高。因此,我們轉而使用 Nyström 方法對其進行逼近。我們首先從 中取樣 行和 列。然後我們可以將取樣出的行和列排列如下:
現在我們有四個子矩陣: 和 ,它們的尺寸分別為 和 。取樣的 列包含在 和 中,而取樣的 行包含在 和 中。因此, 和 的元素是已知的,我們將估計 。根據 Nyström 方法, 由下式給出:
這裡, 表示 Moore-Penrose 逆(或偽逆)。因此, 的 Nyström 逼近 可以寫為:
如第二行所示, 可以表示為三個矩陣的乘積。這樣做的原因稍後會變得清晰。
我們能用 Nyström 方法逼近自注意力機制嗎?
我們的最終目標是逼近標準自注意力機制中的 softmax 矩陣:S = softmax
這裡, 和 分別表示查詢(queries)和鍵(keys)。按照上面討論的程式,我們會從 中取樣 行和 列,形成四個子矩陣,並得到 。
但是,從 中取樣一列意味著什麼呢?這意味著我們從每一行中選擇一個元素。回想一下 S 是如何計算的:最後一步是逐行進行 softmax。要計算一行中的單個元素,我們必須訪問該行的所有其他元素(用於 softmax 的分母)。因此,取樣一列需要我們知道矩陣中的所有其他列。所以,我們無法直接應用 Nyström 方法來逼近 softmax 矩陣。
如何調整 Nyström 方法以逼近自注意力機制?
作者提出,不從 中取樣,而是從查詢和鍵中取樣地標點(landmarks)(或 Nyström 點)。我們將查詢地標點和鍵地標點表示為 和 。 和 可用於構建三個矩陣,這些矩陣對應於 的 Nyström 逼近中的矩陣。我們定義以下矩陣:
、 和 和 。我們用定義的新矩陣替換 的 Nyström 逼近中的三個矩陣,得到一個替代的 Nyström 逼近:
這就是自注意力機制中 softmax 矩陣的 Nyström 逼近。我們將這個矩陣與值()相乘,得到自注意力的線性逼近。請注意,我們從未計算乘積 ,從而避免了 的複雜度。
如何選擇地標點(landmarks)?
作者提出,不從 和 中取樣 行,而是使用分段均值來構建 和 。在這個過程中, 個詞元被分成 個段,並計算每個段的均值。理想情況下, 遠小於 。根據論文中的實驗,即使對於長序列( 或 ),僅選擇 或 個地標點就能產生與標準自注意力機制及其他高效注意力機制相媲美的效能。
論文中的下圖總結了整個演算法:
上圖中的三個橙色矩陣對應於我們使用鍵和查詢地標點構建的三個矩陣。另外,請注意有一個 DConv 框。這對應於使用一維深度卷積向值(values)新增的跳躍連線(skip connection)。
Nyströmformer 是如何實現的?
Nyströmformer 的原始實現可以在這裡找到,HuggingFace 的實現可以在這裡找到。讓我們看一下 HuggingFace 實現中的幾行程式碼(添加了一些註釋)。請注意,為了簡化,省略了一些細節,如歸一化、注意力掩碼和深度卷積。
key_layer = self.transpose_for_scores(self.key(hidden_states)) # K
value_layer = self.transpose_for_scores(self.value(hidden_states)) # V
query_layer = self.transpose_for_scores(mixed_query_layer) # Q
q_landmarks = query_layer.reshape(
-1,
self.num_attention_heads,
self.num_landmarks,
self.seq_len // self.num_landmarks,
self.attention_head_size,
).mean(dim=-2) # \tilde{Q}
k_landmarks = key_layer.reshape(
-1,
self.num_attention_heads,
self.num_landmarks,
self.seq_len // self.num_landmarks,
self.attention_head_size,
).mean(dim=-2) # \tilde{K}
kernel_1 = torch.nn.functional.softmax(torch.matmul(query_layer, k_landmarks.transpose(-1, -2)), dim=-1) # \tilde{F}
kernel_2 = torch.nn.functional.softmax(torch.matmul(q_landmarks, k_landmarks.transpose(-1, -2)), dim=-1) # \tilde{A} before pseudo-inverse
attention_scores = torch.matmul(q_landmarks, key_layer.transpose(-1, -2)) # \tilde{B} before softmax
kernel_3 = nn.functional.softmax(attention_scores, dim=-1) # \tilde{B}
attention_probs = torch.matmul(kernel_1, self.iterative_inv(kernel_2)) # \tilde{F} * \tilde{A}
new_value_layer = torch.matmul(kernel_3, value_layer) # \tilde{B} * V
context_layer = torch.matmul(attention_probs, new_value_layer) # \tilde{F} * \tilde{A} * \tilde{B} * V
在 HuggingFace 中使用 Nyströmformer
用於掩碼語言建模(MLM)的 Nyströmformer 已在 HuggingFace 上提供。目前有 4 個檢查點,對應不同的序列長度:nystromformer-512
、nystromformer-1024
、nystromformer-2048
和 nystromformer-4096
。地標點的數量 可以透過 NystromformerConfig
中的 num_landmarks
引數來控制。讓我們看一個 Nyströmformer 用於 MLM 的最小示例:
from transformers import AutoTokenizer, NystromformerForMaskedLM
import torch
tokenizer = AutoTokenizer.from_pretrained("uw-madison/nystromformer-512")
model = NystromformerForMaskedLM.from_pretrained("uw-madison/nystromformer-512")
inputs = tokenizer("Paris is the [MASK] of France.", return_tensors="pt")
with torch.no_grad():
logits = model(**inputs).logits
# retrieve index of [MASK]
mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
tokenizer.decode(predicted_token_id)
Output:
----------------------------------------------------------------------------------------------------
capital
另外,我們也可以使用 pipeline API(它為我們處理了所有複雜性):
from transformers import pipeline
unmasker = pipeline('fill-mask', model='uw-madison/nystromformer-512')
unmasker("Paris is the [MASK] of France.")
Output:
----------------------------------------------------------------------------------------------------
[{'score': 0.829957902431488,
'token': 1030,
'token_str': 'capital',
'sequence': 'paris is the capital of france.'},
{'score': 0.022157637402415276,
'token': 16081,
'token_str': 'birthplace',
'sequence': 'paris is the birthplace of france.'},
{'score': 0.01904447190463543,
'token': 197,
'token_str': 'name',
'sequence': 'paris is the name of france.'},
{'score': 0.017583081498742104,
'token': 1107,
'token_str': 'kingdom',
'sequence': 'paris is the kingdom of france.'},
{'score': 0.005948934704065323,
'token': 148,
'token_str': 'city',
'sequence': 'paris is the city of france.'}]
結論
Nyströmformer 為標準自注意力機制提供了一種高效的逼近方法,同時其效能優於其他線性自注意力方案。在這篇博文中,我們概要地介紹了 Nyström 方法以及如何將其用於自注意力機制。有興趣在下游任務中部署或微調 Nyströmformer 的讀者可以在這裡找到 HuggingFace 的文件。