Transformers 文件
自定義模型元件
並獲得增強的文件體驗
開始使用
自定義模型元件
自定義模型的另一種方法是修改其元件,而不是完全編寫一個新模型,這允許您根據特定的用例定製模型。例如,您可以新增新層或最佳化架構的注意力機制。自定義直接應用於 Transformers 模型,因此您可以繼續使用諸如 Trainer、PreTrainedModel 和 PEFT 庫等功能。
本指南將向您展示如何自定義模型的注意力機制,以便對其應用 低秩適應 (LoRA)。
clear_import_cache 工具在迭代修改和開發模型程式碼時非常有用。它會刪除所有快取的 Transformers 模組,並允許 Python 重新載入修改後的程式碼,而無需不斷重啟您的環境。
from transformers import AutoModel
from transformers.utils.import_utils import clear_import_cache
model = AutoModel.from_pretrained("bert-base-uncased")
# modifications to model code
# clear cache to reload modified code
clear_import_cache()
# re-import to use updated code
model = AutoModel.from_pretrained("bert-base-uncased")
注意力類
Segment Anything 是一個影像分割模型,它在其注意力機制中結合了查詢-鍵-值 (qkv
) 投影。為了減少可訓練引數的數量和計算開銷,您可以將 LoRA 應用於 qkv
投影。這需要拆分 qkv
投影,以便您可以單獨使用 LoRA 定位 q
和 v
。
- 透過繼承原始的
SamVisionAttention
類來建立一個自定義注意力類SamVisionAttentionSplit
。在__init__
中,刪除組合的qkv
併為q
、k
和v
建立單獨的線性層。
import torch
import torch.nn as nn
from transformers.models.sam.modeling_sam import SamVisionAttention
class SamVisionAttentionSplit(SamVisionAttention, nn.Module):
def __init__(self, config, window_size):
super().__init__(config, window_size)
# remove combined qkv
del self.qkv
# separate q, k, v projections
self.q = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
self.k = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
self.v = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
self._register_load_state_dict_pre_hook(self.split_q_k_v_load_hook)
_split_qkv_load_hook
函式在載入模型時將預訓練的qkv
權重拆分為單獨的q
、k
和v
權重,以確保與任何預訓練模型的相容性。
def split_q_k_v_load_hook(self, state_dict, prefix, *args):
keys_to_delete = []
for key in list(state_dict.keys()):
if "qkv." in key:
# split q, k, v from the combined projection
q, k, v = state_dict[key].chunk(3, dim=0)
# replace with individual q, k, v projections
state_dict[key.replace("qkv.", "q.")] = q
state_dict[key.replace("qkv.", "k.")] = k
state_dict[key.replace("qkv.", "v.")] = v
# mark the old qkv key for deletion
keys_to_delete.append(key)
# remove old qkv keys
for key in keys_to_delete:
del state_dict[key]
- 在
forward
傳遞中,q
、k
和v
分別計算,而其餘的注意力機制保持不變。
def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
batch_size, height, width, _ = hidden_states.shape
qkv_shapes = (batch_size * self.num_attention_heads, height * width, -1)
query = self.q(hidden_states).reshape((batch_size, height * width,self.num_attention_heads, -1)).permute(0,2,1,3).reshape(qkv_shapes)
key = self.k(hidden_states).reshape((batch_size, height * width,self.num_attention_heads, -1)).permute(0,2,1,3).reshape(qkv_shapes)
value = self.v(hidden_states).reshape((batch_size, height * width,self.num_attention_heads, -1)).permute(0,2,1,3).reshape(qkv_shapes)
attn_weights = (query * self.scale) @ key.transpose(-2, -1)
attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
attn_output = self.proj(attn_output)
if output_attentions:
outputs = (attn_output, attn_weights)
else:
outputs = (attn_output, None)
return outputs
將自定義的 SamVisionAttentionSplit
類分配給原始模型的 SamVisionAttention
模組以替換它。模型中所有 SamVisionAttention
的例項都替換為拆分注意力版本。
使用 from_pretrained() 載入模型。
from transformers import SamModel
# load the pretrained SAM model
model = SamModel.from_pretrained("facebook/sam-vit-base")
# replace the attention class in the vision_encoder module
for layer in model.vision_encoder.layers:
if hasattr(layer, "attn"):
layer.attn = SamVisionAttentionSplit(model.config.vision_config, model.config.vision_config.window_size)
LoRA
使用單獨的 q
、k
和 v
投影,將 LoRA 應用於 q
和 v
。
建立一個 LoraConfig 並指定秩 r
、lora_alpha
、lora_dropout
、task_type
,最重要的是要定位的模組。
from peft import LoraConfig, get_peft_model
config = LoraConfig(
r=16,
lora_alpha=32,
# apply LoRA to q and v
target_modules=["q", "v"],
lora_dropout=0.1,
task_type="FEATURE_EXTRACTION"
)
將模型和 LoraConfig 傳遞給 get_peft_model 以將 LoRA 應用於模型。
model = get_peft_model(model, config)
呼叫 print_trainable_parameters 以檢視您訓練的引數數量與總引數數量的對比。
model.print_trainable_parameters()
"trainable params: 589,824 || all params: 94,274,096 || trainable%: 0.6256"