Timm ❤️ Transformers:在 Transformers 中使用任何 timm 模型

釋出於 2025 年 1 月 16 日
在 GitHub 上更新

在友好的 🤗 transformers 生態系統中,為**任何** timm 模型實現閃電般的推理速度、快速量化、torch.compile 加速和輕鬆微調。

隆重介紹 TimmWrapper——一個簡單而強大的工具,釋放了這一潛力。

在這篇文章中,我們將涵蓋

  • timm 整合的工作原理及其為何能改變遊戲規則。
  • 如何將 timm 模型與 🤗 transformers 整合。
  • 實踐示例:pipeline、量化、微調等。

要跟隨本部落格文章,請執行以下命令安裝最新版本的 transformerstimm

pip install -Uq transformers timm

檢視包含所有程式碼示例和 notebook 的完整程式碼庫:🔗 TimmWrapper 示例

什麼是 timm?

PyTorch Image Models (timm) 庫提供了豐富的最先進的計算機視覺模型,以及有用的層、工具、最佳化器和資料增強。截至本文撰寫時,它在 GitHub 上擁有超過 32K 顆星,每日下載量超過 200K,是影像分類和特徵提取(用於目標檢測、分割、影像搜尋等下游任務)的首選資源。

timm 擁有涵蓋各種架構的預訓練模型,簡化了計算機視覺從業者的工作流程。

為何使用 timm 整合?

雖然 🤗 transformers 支援多種視覺模型,但 timm 提供了更廣泛的集合,包括許多在 transformers 中不可用的移動端友好和高效的模型。

timm 整合彌補了這一差距,帶來了兩全其美的優勢

  • Pipeline API 支援:輕鬆將任何 timm 模型插入到高階 transformers pipeline 中,以實現流線型推理。
  • 🧩 與 Auto 類相容:雖然 timm 模型本身與 transformers 不相容,但此整合使其能夠與 Auto 類 API 無縫協作。
  • 快速量化:只需約 5 行程式碼,您就可以量化**任何** timm 模型以進行高效推理。
  • 🎯 使用 Trainer API 進行微調:使用 Trainer API 微調 timm 模型,甚至可以與低秩自適應 (LoRA) 等介面卡整合。
  • 🔁 返回 timm:在 timm 中再次使用微調後的模型。
  • 🚀 Torch Compile 加速:利用 torch.compile 最佳化推理時間。

Pipeline API:使用 timm 模型進行影像分類

timm 整合的一個突出特點是它允許您利用 🤗 pipeline APIpipeline API 抽象了許多複雜性,使得載入預訓練模型、執行推理和檢視結果變得非常簡單,只需幾行程式碼即可完成。

讓我們看看如何將 transformers pipeline 與 MobileNetV4 一起使用。該架構沒有原生的 transformers 實現,但可以輕鬆地從 timm 中使用

from transformers import pipeline
import requests

# Load the image classification pipeline with a timm model
image_classifier = pipeline(model="timm/mobilenetv4_conv_medium.e500_r256_in1k")

# URL of the image to classify
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/timm/cat.jpg"

# Perform inference
outputs = image_classifier(url)

# Print the results
for output in outputs:
    print(f"Label: {output['label'] :20} Score: {output['score'] :0.2f}")

輸出:

Device set to use cpu
Label: tabby, tabby cat     Score: 0.69
Label: tiger cat            Score: 0.21
Label: Egyptian cat         Score: 0.02
Label: bee                  Score: 0.00
Label: marmoset             Score: 0.00

Gradio 整合:構建食物分類器演示 🍣

想要快速建立一個用於影像分類的互動式 Web 應用嗎?Gradio 使您能夠用最少的程式碼構建一個使用者友好的介面。讓我們將 Gradiopipeline API 結合起來,使用一個微調過的 timm ViT 模型來分類食物影像(我們將在後面的章節中介紹微調)。

以下是如何使用 timm 模型快速設定一個演示

import gradio as gr
from transformers import pipeline

# Load the image classification pipeline using a timm model
pipe = pipeline(
    "image-classification",
    model="ariG23498/vit_base_patch16_224.augreg2_in21k_ft_in1k.ft_food101"
)

def classify(image):
    return pipe(image)[0]["label"]

demo = gr.Interface(
    fn=classify,
    inputs=gr.Image(type="pil"),
    outputs="text",
    examples=[["./sushi.png", "sushi"]]
)

demo.launch()

這是一個託管在 Hugging Face Spaces 上的即時示例。您可以直接在瀏覽器中測試!

Auto 類:簡化模型載入

🤗 transformers 庫提供了 Auto 類 來抽象化載入模型和處理器的複雜性。透過 TimmWrapper,您可以使用 AutoModelForImageClassificationAutoImageProcessor 輕鬆載入任何 timm 模型。

這是一個快速示例

from transformers import (
    AutoModelForImageClassification,
    AutoImageProcessor,
)
from transformers.image_utils import load_image

image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/timm/cat.jpg"
image = load_image(image_url)

# Use Auto classes to load a timm model
checkpoint = "timm/mobilenetv4_conv_medium.e500_r256_in1k"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)
model = AutoModelForImageClassification.from_pretrained(checkpoint).eval()

# Check the types
print(type(image_processor))  # TimmWrapperImageProcessor
print(type(model))            # TimmWrapperForImageClassification

執行量化的 timm 模型

量化是一種強大的技術,可以減小模型大小並加速推理,尤其適用於資源受限的裝置。透過 timm 整合,您可以使用 bitsandbytes 中的 BitsAndBytesConfig,只需幾行程式碼即可即時量化任何 timm 模型。

以下是量化一個 timm 模型是多麼簡單

from transformers import TimmWrapperForImageClassification, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_8bit=True)
checkpoint = "timm/vit_base_patch16_224.augreg2_in21k_ft_in1k"

model = TimmWrapperForImageClassification.from_pretrained(checkpoint).to("cuda")
model_8bit = TimmWrapperForImageClassification.from_pretrained(
    checkpoint,
    quantization_config=quantization_config,
    low_cpu_mem_usage=True,
)
original_footprint = model.get_memory_footprint()
quantized_footprint = model_8bit.get_memory_footprint()

print(f"Original model size: {original_footprint / 1e6:.2f} MB")
print(f"Quantized model size: {quantized_footprint / 1e6:.2f} MB")
print(f"Reduction: {(original_footprint - quantized_footprint) / original_footprint * 100:.2f}%")

輸出

Original model size: 346.27 MB  
Quantized model size: 88.20 MB  
Reduction: 74.53%  

量化模型在推理時的效能與全精度模型幾乎完全相同

模型 標籤 準確率
原始模型 遙控器,遙控 0.35%
量化模型 遙控器,遙控 0.33%

timm 模型的監督式微調

使用 🤗 transformersTrainer API 微調 timm 模型是直接且高度靈活的。您可以使用 Trainer 類在自定義資料集上微調您的模型,該類處理訓練迴圈、日誌記錄和評估。此外,您可以使用 LoRA (低秩自適應) 進行微調,以更少的引數高效地進行訓練。

本節對標準微調和 LoRA 微調進行了簡要概述,並提供了完整程式碼的連結。

使用 Trainer API 進行標準微調

Trainer API 使得用最少的程式碼設定訓練變得容易。以下是微調設定的概要

from transformers import TrainingArguments, Trainer

# Define training arguments
training_args = TrainingArguments(
    output_dir="my_model_output",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    num_train_epochs=3,
    load_best_model_at_end=True,
    push_to_hub=True,
)

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# Start training
trainer.train()

這種方法的顯著之處在於,它反映了用於原生 transformers 模型的完全相同的工作流程,從而在不同模型型別之間保持了一致性。

這意味著您可以使用熟悉的 Trainer API 不僅微調 Transformers 模型,還可以微調**任何 timm 模型**——將 timm 庫中強大的模型引入 Hugging Face 生態系統,只需進行最少的調整。這極大地拓寬了您可以使用相同可信賴的工具和工作流程進行微調的模型範圍。

模型示例
Food-101 上微調的 ViT:vit_base_patch16_224.augreg2_in21k_ft_in1k.ft_food101

LoRA 微調以實現高效訓練

LoRA (低秩自適應) 允許您透過僅訓練少量額外引數而不是完整的模型權重來高效地微調大型模型。這使得微調更快,並允許使用消費級硬體。您可以使用 PEFT 庫透過 LoRA 微調一個 timm 模型。

以下是您可以如何設定它

from peft import LoraConfig, get_peft_model

model = AutoModelForImageClassification.from_pretrained(checkpoint, num_labels=num_labels)
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["qkv"],
    lora_dropout=0.1,
    bias="none",
    modules_to_save=["head"],
)

# Wrap the model with PEFT
lora_model = get_peft_model(model, lora_config)

lora_model.print_trainable_parameters()

使用 LoRA 的可訓練引數

trainable params: 667,493 || all params: 86,543,818 || trainable%: 0.77%

模型示例
Food-101 上進行 LoRA 微調的 ViT:vit_base_patch16_224.augreg2_in21k_ft_in1k.lora_ft_food101

LoRA 只是您可以應用於 timm 模型的眾多高效介面卡微調方法中的一個例子。timm 與 🤗 生態系統的整合為您開啟了各種**引數高效微調 (PEFT)** 技術的大門,讓您可以選擇最適合您應用場景的方法。

使用 LoRA 微調模型進行推理

一旦模型經過 LoRA 微調,我們僅將介面卡權重推送到 Hugging Face Hub。本節將幫助您下載介面卡權重,將介面卡權重與基礎模型合併,然後進行推理。

from peft import PeftModel, PeftConfig

repo_name = "ariG23498/vit_base_patch16_224.augreg2_in21k_ft_in1k.lora_ft_food101"
config = PeftConfig.from_pretrained(repo_name)

model = AutoModelForImageClassification.from_pretrained(
    config.base_model_name_or_path,
    label2id=label2id,
    num_labels=num_labels,
    id2label=id2label,
    ignore_mismatched_sizes=True,
)
inference_model = PeftModel.from_pretrained(model, repo_name)

# Make prediction with the model

image of sushi with prediction from a fine tuned model

雙向整合

Ross (timm的建立者) 最喜歡的一個功能是,這種整合保持了完整的“雙向”相容性。也就是說,使用包裝器,人們可以使用 transformerTrainer 在新資料集上微調 timm 模型,將結果模型釋出到 Hugging Face hub,然後再次使用 timm.create_model('hf-hub:my_org/my_fine_tuned_model', pretrained=True)timm 中載入微調後的模型。

讓我們看看如何用 timm 載入我們微調過的模型 ariG23498/vit_base_patch16_224.augreg2_in21k_ft_in1k.ft_food101

checkpoint = "ariG23498/vit_base_patch16_224.augreg2_in21k_ft_in1k.ft_food101"

config = AutoConfig.from_pretrained(checkpoint)

model = timm.create_model(f"hf_hub:{checkpoint}", pretrained=True) # Load the model with timm
model = model.eval()

image = load_image("https://cdn.britannica.com/52/128652-050-14AD19CA/Maki-zushi.jpg")

data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(image).unsqueeze(0))

top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)

for prob, idx in zip(top5_probabilities[0], top5_class_indices[0]):
    print(f"Label: {config.id2label[idx.item()] :20} Score: {prob/100 :0.2f}%")

輸出

Label: sushi                Score: 0.98%
Label: spring_rolls         Score: 0.01%
Label: sashimi              Score: 0.00%
Label: club_sandwich        Score: 0.00%
Label: cannoli              Score: 0.00%

Torch Compile:即時加速

在 PyTorch 2.0 中使用 torch.compile,您只需一行程式碼即可編譯模型,從而實現更快的推理速度timm 整合完全相容 torch.compile。以下是一個快速基準測試,用於比較使用 TimmWrapper 時有無 torch.compile 的推理時間。

# Load the model and input
model = TimmWrapperForImageClassification.from_pretrained(checkpoint).to(device)
processed_input = image_processor(image, return_tensors="pt").to(device)

# Benchmark function
def run_benchmark(model, input_data, warmup_runs=5, benchmark_runs=300):
    # Warm-up phase
    model.eval()
    with torch.no_grad():
        for _ in range(warmup_runs):
            _ = model(**input_data)

    # Benchmark phase
    times = []
    with torch.no_grad():
        for _ in range(benchmark_runs):
            start_time = time.perf_counter()
            _ = model(**input_data)
            if device.type == "cuda":
                torch.cuda.synchronize(device=device)  # Ensure synchronization for CUDA
            times.append(time.perf_counter() - start_time)

    avg_time = sum(times) / benchmark_runs
    return avg_time

# Run benchmarks
time_no_compile = run_benchmark(model, processed_input)
compiled_model = torch.compile(model).to(device)
time_compile = run_benchmark(compiled_model, processed_input)

# Results
print(f"Without torch.compile: {time_no_compile:.4f} s")
print(f"With torch.compile: {time_compile:.4f} s")

compile timing

總結

timm 與 transformers 的整合為利用最先進的視覺模型開闢了新的大門,且只需最少的努力。無論您是想進行微調、量化,還是僅僅執行推理,這種整合都提供了一個統一的 API 來簡化您的工作流程。

立即開始探索,解鎖計算機視覺的新可能!

致謝

我們要向在 Transformers PR #34564 中促成此次整合的各位同仁表示衷心的感謝。排名不分先後,衷心感謝 Pavel Iakubovskii、Ross Wightman、Lysandre Debut、Pablo Montalvo、Arthur Zucker 和 Amy Roberts 所做的傑出工作。你們的共同努力使這個想法變成了現實,讓每個人都能享受到這個功能!

社群

文章作者
此評論已被隱藏

對此非常興奮,謝謝!我們正準備切換到 timm,這使得它變得更容易了!

也許一個簡單天真的問題,我正在嘗試編寫一個演示訓練指令碼,從以下位置載入基礎模型
TimmWrapperForImageClassification.from_pretrained("timm/mobilenetv4_conv_medium.e500_r256_in1k").to("cuda")

但隨後在 food101 資料集上進行訓練,只是為了說明在新自定義資料集上進行訓練。訓練正常,但推理返回的是動物名稱作為標籤。

當我載入微調模型時,是否應該設定 label2id, num_labels, id2label, 等引數?它似乎在訓練期間將資料儲存在某個地方,但 TrainingArguments 不允許我設定 TypeError: TrainingArguments.__init__() got an unexpected keyword argument 'label2id'

@davidrs 嗯,可能是標籤處理有問題,在釋出前對整合做了一個改動,以保持標籤與 timm 的使用相容(保留 label_names 欄位而不是 id2label/label2id),實際上這兩種情況混合在一起了,而且 @ariG23498 的許多微調都有 id2label,儘管在釋出時我被告知它應該生成 label_names...

你有一個我可以看的公開模型嗎?你是在 Transformers 中使用示例影像分類指令碼 (https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-classification) 還是像上面的例子一樣在自定義指令碼/notebook 中直接使用 Trainer?

我的錯,我錯過了部落格頂部的連結,那裡有更完整的程式碼示例,也許這能幫到我 https://github.com/ariG23498/timm-wrapper-examples/blob/main/%2304_sft.ipynb

在自定義指令碼中直接使用 Trainer,我做了一個 Colab notebook 來說明我目前正在嘗試的東西,在最後的預測測試中,標籤不是食物。
https://colab.research.google.com/drive/14jTpetYR61B6EVoJ6o8_B8gi6-SiizCA?usp=sharing

@davidrs 示例中用目標資料集重置分類器/標籤的這部分很重要。

model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

如果你不這樣做,我相信它會用原始標籤、一個 imagenet-1k (1000) 類分類器等被推送。不過,我猜新資料集中最低的 'n' 個類別會被目標微調(如果類別更少,如果新資料集有更多類別,它會崩潰)。

順便說一下,我剛在 pipeline 中發現一個 bug,由於一個 bug (https://github.com/huggingface/transformers/pull/35848),它預設應用了 sigmoid 而不是 softmax,所以如果你想要 softmax 機率,請新增 function_to_apply='softmax'……這不特定於 timm 整合,而且看起來已經存在一段時間了。我確認瞭如果你像上面那樣設定標籤,微調後的 timm 模型將用正確的標籤進行預測,並且也應該用這些標籤推送到 hub...

如何與 Optimum 整合並載入模型的 onnx 版本

註冊登入 以發表評論

© . This site is unofficial and not affiliated with Hugging Face, Inc.