Transformers 文件

自定義模型元件

Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

自定義模型元件

自定義模型的另一種方法是修改其元件,而不是完全編寫一個新模型,這允許您根據特定的用例定製模型。例如,您可以新增新層或最佳化架構的注意力機制。自定義直接應用於 Transformers 模型,因此您可以繼續使用諸如 TrainerPreTrainedModelPEFT 庫等功能。

本指南將向您展示如何自定義模型的注意力機制,以便對其應用 低秩適應 (LoRA)

clear_import_cache 工具在迭代修改和開發模型程式碼時非常有用。它會刪除所有快取的 Transformers 模組,並允許 Python 重新載入修改後的程式碼,而無需不斷重啟您的環境。

from transformers import AutoModel
from transformers.utils.import_utils import clear_import_cache

model = AutoModel.from_pretrained("bert-base-uncased")
# modifications to model code
# clear cache to reload modified code
clear_import_cache()
# re-import to use updated code
model = AutoModel.from_pretrained("bert-base-uncased")

注意力類

Segment Anything 是一個影像分割模型,它在其注意力機制中結合了查詢-鍵-值 (qkv) 投影。為了減少可訓練引數的數量和計算開銷,您可以將 LoRA 應用於 qkv 投影。這需要拆分 qkv 投影,以便您可以單獨使用 LoRA 定位 qv

  1. 透過繼承原始的 SamVisionAttention 類來建立一個自定義注意力類 SamVisionAttentionSplit。在 __init__ 中,刪除組合的 qkv 併為 qkv 建立單獨的線性層。
import torch
import torch.nn as nn
from transformers.models.sam.modeling_sam import SamVisionAttention

class SamVisionAttentionSplit(SamVisionAttention, nn.Module):
    def __init__(self, config, window_size):
        super().__init__(config, window_size)
        # remove combined qkv
        del self.qkv
        # separate q, k, v projections
        self.q = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.k = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.v = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self._register_load_state_dict_pre_hook(self.split_q_k_v_load_hook)
  1. _split_qkv_load_hook 函式在載入模型時將預訓練的 qkv 權重拆分為單獨的 qkv 權重,以確保與任何預訓練模型的相容性。
    def split_q_k_v_load_hook(self, state_dict, prefix, *args):
        keys_to_delete = []
        for key in list(state_dict.keys()):
            if "qkv." in key:
                # split q, k, v from the combined projection
                q, k, v = state_dict[key].chunk(3, dim=0)
                # replace with individual q, k, v projections
                state_dict[key.replace("qkv.", "q.")] = q
                state_dict[key.replace("qkv.", "k.")] = k
                state_dict[key.replace("qkv.", "v.")] = v
                # mark the old qkv key for deletion
                keys_to_delete.append(key)
        
        # remove old qkv keys
        for key in keys_to_delete:
            del state_dict[key]
  1. forward 傳遞中,qkv 分別計算,而其餘的注意力機制保持不變。
    def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
        batch_size, height, width, _ = hidden_states.shape
        qkv_shapes = (batch_size *  self.num_attention_heads,  height * width, -1)
        query = self.q(hidden_states).reshape((batch_size,  height * width,self.num_attention_heads, -1)).permute(0,2,1,3).reshape(qkv_shapes)
        key = self.k(hidden_states).reshape((batch_size,  height * width,self.num_attention_heads, -1)).permute(0,2,1,3).reshape(qkv_shapes)
        value = self.v(hidden_states).reshape((batch_size,  height * width,self.num_attention_heads, -1)).permute(0,2,1,3).reshape(qkv_shapes)

        attn_weights = (query * self.scale) @ key.transpose(-2, -1)

        attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
        attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
        attn_output = self.proj(attn_output)

        if output_attentions:
            outputs = (attn_output, attn_weights)
        else:
            outputs = (attn_output, None)
        return outputs

將自定義的 SamVisionAttentionSplit 類分配給原始模型的 SamVisionAttention 模組以替換它。模型中所有 SamVisionAttention 的例項都替換為拆分注意力版本。

使用 from_pretrained() 載入模型。

from transformers import SamModel

# load the pretrained SAM model
model = SamModel.from_pretrained("facebook/sam-vit-base")

# replace the attention class in the vision_encoder module
for layer in model.vision_encoder.layers:
    if hasattr(layer, "attn"):
        layer.attn = SamVisionAttentionSplit(model.config.vision_config, model.config.vision_config.window_size)

LoRA

使用單獨的 qkv 投影,將 LoRA 應用於 qv

建立一個 LoraConfig 並指定秩 rlora_alphalora_dropouttask_type,最重要的是要定位的模組。

from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=16,
    lora_alpha=32,
    # apply LoRA to q and v
    target_modules=["q", "v"],
    lora_dropout=0.1,
    task_type="FEATURE_EXTRACTION"
)

將模型和 LoraConfig 傳遞給 get_peft_model 以將 LoRA 應用於模型。

model = get_peft_model(model, config)

呼叫 print_trainable_parameters 以檢視您訓練的引數數量與總引數數量的對比。

model.print_trainable_parameters()
"trainable params: 589,824 || all params: 94,274,096 || trainable%: 0.6256"
< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.