Transformers 文件
Attention Interface
並獲得增強的文件體驗
開始使用
注意力介面
本頁描述如何使用 `AttentionInterface` 來註冊自定義注意力函式,以用於支援的模型。
自定義注意力函式
得益於一個簡單的對映,大多數最新模型現在可以從注意力層中使用的注意力函式切換到另一個注意力函式。預設情況下,我們提供了 `sdpa`、`flash_attention_2` 和 `flex_attention` 的實現,以及 `eager`,它是一個簡單的矩陣乘法,沒有任何最佳化。
這是您在例項化模型時通常可以選擇的設定
from transformers import AutoModelForCausalLM
model_id = "meta-llama/Llama-3.2-1B"
# Here, using flash attention as an example
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="flash_attention_2")
但是,如果您想建立自己的注意力函式呢?或者只是嘗試現有的函式,在其中新增一些語句?現在,您可以使用 `AttentionInterface` 來實現!這是一個示例
from transformers import AutoModelForCausalLM, AttentionInterface
from transformers.integrations.sdpa_attention import sdpa_attention_forward
import torch
model_id = "meta-llama/Llama-3.2-1B"
def my_new_sdpa(*args, **kwargs):
print("I just entered the attention computation")
return sdpa_attention_forward(*args, **kwargs)
AttentionInterface.register("my_new_sdpa", my_new_sdpa)
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="my_new_sdpa")
# Try running the forward with the new attention function
model(torch.ones(1, 5, dtype=int))
您將看到它列印“I just entered the attention computation”,列印次數與模型中的層數一樣多(在此示例中為 16 次)。
動態切換注意力函式
您還可以透過覆蓋 `config._attn_implementation` 欄位來動態更改模型的注意力函式
# Back to use original sdpa implementation
model.config._attn_implementation = "sdpa"
model(torch.ones(1, 5, dtype=int))
它將停止列印語句,因為它現在使用 `sdpa` 注意力。
這允許快速更改注意力函式,而無需重新載入模型!
我的自定義注意力函式中需要新引數怎麼辦?
但確實,如果新函式需要新引數才能正常使用怎麼辦?這不是問題!支援 `AttentionInterface` 的模型會將 kwargs 一直傳播到注意力層和所使用的注意力函式。這樣,您只需在模型的 forward 中傳遞引數(作為 kwargs,即您需要限定引數名稱),它就會在注意力中正確使用。但是,自定義注意力函式有一些限制。特別是,它必須遵循其他注意力函式的簽名和返回格式,即
from transformers import AutoModelForCausalLM, AttentionInterface
from transformers.integrations.sdpa_attention import sdpa_attention_forward
import torch
def custom_attention(
module: torch.nn.Module, # required arg
query: torch.Tensor, # required arg
key: torch.Tensor, # required arg
value: torch.Tensor, # required arg
attention_mask: Optional[torch.Tensor], # required arg
a_new_kwargs = None, # You can now add as many kwargs as you need
another_new_kwargs = None, # You can now add as many kwargs as you need
**kwargs, # You need to accept **kwargs as models will pass other args
) -> tuple[torch.Tensor, Optional[torch.Tensor]]
... # do your magic!
return attn_output, attn_weights # attn_weights are optional here
AttentionInterface.register("custom", custom_attention)
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="custom")
# Forward pass with the new kwargs
model(torch.ones(1, 5, dtype=int), a_new_kwargs=..., another_new_kwargs=...)
如果對給定模型傳送給注意力函式的引數/關鍵字引數有疑問,只需檢視該模型在 GitHub 上的建模程式碼!
訪問當前可用實現
大多數情況下,您只需 `註冊` 一個新函式。但是,如果您需要訪問現有函式,和/或執行一些檢查,首選的方式是使用全域性 `ALL_ATTENTION_FUNCTIONS`。它的行為方式與您期望的普通 Python 字典相同
>>> from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
>>> list(ALL_ATTENTION_FUNCTIONS.keys())
>>> ['flash_attention_2', 'flex_attention', 'sdpa']
>>> ALL_ATTENTION_FUNCTIONS["sdpa"]
>>> <function transformers.integrations.sdpa_attention.sdpa_attention_forward>
>>> ALL_ATTENTION_FUNCTIONS.get("sdpa", None)
>>> <function transformers.integrations.sdpa_attention.sdpa_attention_forward>
# You can also globally `register` a new function directly on it
>>> ALL_ATTENTION_FUNCTIONS.register("new_func", new_func)
注意力掩碼介面
擁有一個新的注意力函式可能意味著您需要一種新的注意力掩碼格式來決定查詢令牌應該關注哪些鍵和值令牌。現在,透過 `AttentionMaskInterface` 可以實現這一點!它的工作方式與 `AttentionInterface` 相同。
from transformers import AttentionMaskInterface
from transformers.masking_utils import sdpa_mask
import torch
def my_new_sdpa_mask(*args, **kwargs):
print("I just entered the attention mask computation")
return sdpa_mask(*args, **kwargs)
AttentionMaskInterface.register("my_new_sdpa_mask", my_new_sdpa_mask)
您必須註冊它的原因是,我們需要根據注意力實現自動更正您的掩碼格式(例如,flex attention 使用 BlockMask 格式,而 sdpa 使用 4D 張量)。預設情況下,如果您沒有註冊注意力掩碼函式以及您的注意力函式,將跳過掩碼建立,並且 `attention_mask=None` 將傳遞給注意力層。
注意力掩碼函式的預設簽名如下:
def custom_attention_mask(
batch_size: int, # required arg
cache_position: torch.Tensor, # required arg
kv_length: int, # required arg
kv_offset: int = 0, # required arg
mask_function: Callable = causal_mask_function, # required arg
attention_mask: Optional[torch.Tensor] = None, # required arg
**kwargs, # a few additional args may be passed as kwargs, especially the model's config is always passed
) -> Optional[torch.Tensor]:
它主要透過 `mask_function` 實現,這是一個 `Callable`,形式類似於 torch 的 `mask_mod` 函式,它接受 4 個索引作為輸入,並返回一個布林值以指示該位置是否應參與注意力計算。
如果由於某種原因無法使用 `mask_function` 建立掩碼,您可以嘗試透過類似於我們的 torch 匯出變通方法 來解決。
< > 在 GitHub 上更新