Timm ❤️ Transformers:在 Transformers 中使用任何 timm 模型
在友好的 🤗
transformers
生態系統中,為**任何** timm
模型實現閃電般的推理速度、快速量化、torch.compile
加速和輕鬆微調。
隆重介紹 TimmWrapper
——一個簡單而強大的工具,釋放了這一潛力。
在這篇文章中,我們將涵蓋
- timm 整合的工作原理及其為何能改變遊戲規則。
- 如何將
timm
模型與 🤗transformers
整合。 - 實踐示例:pipeline、量化、微調等。
要跟隨本部落格文章,請執行以下命令安裝最新版本的
transformers
和timm
:pip install -Uq transformers timm
檢視包含所有程式碼示例和 notebook 的完整程式碼庫:🔗 TimmWrapper 示例
什麼是 timm?
PyTorch Image Models (timm
) 庫提供了豐富的最先進的計算機視覺模型,以及有用的層、工具、最佳化器和資料增強。截至本文撰寫時,它在 GitHub 上擁有超過 32K 顆星,每日下載量超過 200K,是影像分類和特徵提取(用於目標檢測、分割、影像搜尋等下游任務)的首選資源。
timm
擁有涵蓋各種架構的預訓練模型,簡化了計算機視覺從業者的工作流程。
為何使用 timm 整合?
雖然 🤗 transformers
支援多種視覺模型,但 timm
提供了更廣泛的集合,包括許多在 transformers 中不可用的移動端友好和高效的模型。
timm
整合彌補了這一差距,帶來了兩全其美的優勢
- ✅ Pipeline API 支援:輕鬆將任何
timm
模型插入到高階transformers
pipeline 中,以實現流線型推理。 - 🧩 與 Auto 類相容:雖然
timm
模型本身與transformers
不相容,但此整合使其能夠與Auto
類 API 無縫協作。 - ⚡ 快速量化:只需約 5 行程式碼,您就可以量化**任何**
timm
模型以進行高效推理。 - 🎯 使用 Trainer API 進行微調:使用
Trainer
API 微調timm
模型,甚至可以與低秩自適應 (LoRA) 等介面卡整合。 - 🔁 返回 timm:在
timm
中再次使用微調後的模型。 - 🚀 Torch Compile 加速:利用
torch.compile
最佳化推理時間。
Pipeline API:使用 timm 模型進行影像分類
timm
整合的一個突出特點是它允許您利用 🤗 pipeline
API。pipeline
API 抽象了許多複雜性,使得載入預訓練模型、執行推理和檢視結果變得非常簡單,只需幾行程式碼即可完成。
讓我們看看如何將 transformers pipeline 與 MobileNetV4 一起使用。該架構沒有原生的 transformers
實現,但可以輕鬆地從 timm
中使用
from transformers import pipeline
import requests
# Load the image classification pipeline with a timm model
image_classifier = pipeline(model="timm/mobilenetv4_conv_medium.e500_r256_in1k")
# URL of the image to classify
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/timm/cat.jpg"
# Perform inference
outputs = image_classifier(url)
# Print the results
for output in outputs:
print(f"Label: {output['label'] :20} Score: {output['score'] :0.2f}")
輸出:
Device set to use cpu
Label: tabby, tabby cat Score: 0.69
Label: tiger cat Score: 0.21
Label: Egyptian cat Score: 0.02
Label: bee Score: 0.00
Label: marmoset Score: 0.00
Gradio 整合:構建食物分類器演示 🍣
想要快速建立一個用於影像分類的互動式 Web 應用嗎?Gradio 使您能夠用最少的程式碼構建一個使用者友好的介面。讓我們將 Gradio 與 pipeline
API 結合起來,使用一個微調過的 timm ViT 模型來分類食物影像(我們將在後面的章節中介紹微調)。
以下是如何使用 timm
模型快速設定一個演示
import gradio as gr
from transformers import pipeline
# Load the image classification pipeline using a timm model
pipe = pipeline(
"image-classification",
model="ariG23498/vit_base_patch16_224.augreg2_in21k_ft_in1k.ft_food101"
)
def classify(image):
return pipe(image)[0]["label"]
demo = gr.Interface(
fn=classify,
inputs=gr.Image(type="pil"),
outputs="text",
examples=[["./sushi.png", "sushi"]]
)
demo.launch()
這是一個託管在 Hugging Face Spaces 上的即時示例。您可以直接在瀏覽器中測試!
Auto 類:簡化模型載入
🤗 transformers
庫提供了 Auto 類 來抽象化載入模型和處理器的複雜性。透過 TimmWrapper
,您可以使用 AutoModelForImageClassification
和 AutoImageProcessor
輕鬆載入任何 timm
模型。
這是一個快速示例
from transformers import (
AutoModelForImageClassification,
AutoImageProcessor,
)
from transformers.image_utils import load_image
image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/timm/cat.jpg"
image = load_image(image_url)
# Use Auto classes to load a timm model
checkpoint = "timm/mobilenetv4_conv_medium.e500_r256_in1k"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)
model = AutoModelForImageClassification.from_pretrained(checkpoint).eval()
# Check the types
print(type(image_processor)) # TimmWrapperImageProcessor
print(type(model)) # TimmWrapperForImageClassification
執行量化的 timm 模型
量化是一種強大的技術,可以減小模型大小並加速推理,尤其適用於資源受限的裝置。透過 timm
整合,您可以使用 bitsandbytes
中的 BitsAndBytesConfig
,只需幾行程式碼即可即時量化任何 timm
模型。
以下是量化一個 timm
模型是多麼簡單
from transformers import TimmWrapperForImageClassification, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
checkpoint = "timm/vit_base_patch16_224.augreg2_in21k_ft_in1k"
model = TimmWrapperForImageClassification.from_pretrained(checkpoint).to("cuda")
model_8bit = TimmWrapperForImageClassification.from_pretrained(
checkpoint,
quantization_config=quantization_config,
low_cpu_mem_usage=True,
)
original_footprint = model.get_memory_footprint()
quantized_footprint = model_8bit.get_memory_footprint()
print(f"Original model size: {original_footprint / 1e6:.2f} MB")
print(f"Quantized model size: {quantized_footprint / 1e6:.2f} MB")
print(f"Reduction: {(original_footprint - quantized_footprint) / original_footprint * 100:.2f}%")
輸出
Original model size: 346.27 MB
Quantized model size: 88.20 MB
Reduction: 74.53%
量化模型在推理時的效能與全精度模型幾乎完全相同
模型 | 標籤 | 準確率 |
---|---|---|
原始模型 | 遙控器,遙控 | 0.35% |
量化模型 | 遙控器,遙控 | 0.33% |
timm
模型的監督式微調
使用 🤗 transformers
的 Trainer
API 微調 timm
模型是直接且高度靈活的。您可以使用 Trainer
類在自定義資料集上微調您的模型,該類處理訓練迴圈、日誌記錄和評估。此外,您可以使用 LoRA (低秩自適應) 進行微調,以更少的引數高效地進行訓練。
本節對標準微調和 LoRA 微調進行了簡要概述,並提供了完整程式碼的連結。
使用 Trainer
API 進行標準微調
Trainer
API 使得用最少的程式碼設定訓練變得容易。以下是微調設定的概要
from transformers import TrainingArguments, Trainer
# Define training arguments
training_args = TrainingArguments(
output_dir="my_model_output",
evaluation_strategy="epoch",
save_strategy="epoch",
learning_rate=5e-5,
per_device_train_batch_size=16,
num_train_epochs=3,
load_best_model_at_end=True,
push_to_hub=True,
)
# Initialize the Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_ds,
eval_dataset=val_ds,
data_collator=data_collator,
compute_metrics=compute_metrics,
)
# Start training
trainer.train()
這種方法的顯著之處在於,它反映了用於原生 transformers
模型的完全相同的工作流程,從而在不同模型型別之間保持了一致性。
這意味著您可以使用熟悉的 Trainer
API 不僅微調 Transformers 模型,還可以微調**任何 timm
模型**——將 timm
庫中強大的模型引入 Hugging Face 生態系統,只需進行最少的調整。這極大地拓寬了您可以使用相同可信賴的工具和工作流程進行微調的模型範圍。
模型示例
在 Food-101 上微調的 ViT:vit_base_patch16_224.augreg2_in21k_ft_in1k.ft_food101
LoRA 微調以實現高效訓練
LoRA (低秩自適應) 允許您透過僅訓練少量額外引數而不是完整的模型權重來高效地微調大型模型。這使得微調更快,並允許使用消費級硬體。您可以使用 PEFT 庫透過 LoRA 微調一個 timm
模型。
以下是您可以如何設定它
from peft import LoraConfig, get_peft_model
model = AutoModelForImageClassification.from_pretrained(checkpoint, num_labels=num_labels)
lora_config = LoraConfig(
r=16,
lora_alpha=16,
target_modules=["qkv"],
lora_dropout=0.1,
bias="none",
modules_to_save=["head"],
)
# Wrap the model with PEFT
lora_model = get_peft_model(model, lora_config)
lora_model.print_trainable_parameters()
使用 LoRA 的可訓練引數
trainable params: 667,493 || all params: 86,543,818 || trainable%: 0.77%
模型示例
在 Food-101 上進行 LoRA 微調的 ViT:vit_base_patch16_224.augreg2_in21k_ft_in1k.lora_ft_food101
LoRA 只是您可以應用於 timm
模型的眾多高效介面卡微調方法中的一個例子。timm
與 🤗 生態系統的整合為您開啟了各種**引數高效微調 (PEFT)** 技術的大門,讓您可以選擇最適合您應用場景的方法。
使用 LoRA 微調模型進行推理
一旦模型經過 LoRA 微調,我們僅將介面卡權重推送到 Hugging Face Hub。本節將幫助您下載介面卡權重,將介面卡權重與基礎模型合併,然後進行推理。
from peft import PeftModel, PeftConfig
repo_name = "ariG23498/vit_base_patch16_224.augreg2_in21k_ft_in1k.lora_ft_food101"
config = PeftConfig.from_pretrained(repo_name)
model = AutoModelForImageClassification.from_pretrained(
config.base_model_name_or_path,
label2id=label2id,
num_labels=num_labels,
id2label=id2label,
ignore_mismatched_sizes=True,
)
inference_model = PeftModel.from_pretrained(model, repo_name)
# Make prediction with the model
雙向整合
Ross (timm
的建立者) 最喜歡的一個功能是,這種整合保持了完整的“雙向”相容性。也就是說,使用包裝器,人們可以使用 transformer
的 Trainer
在新資料集上微調 timm 模型,將結果模型釋出到 Hugging Face hub,然後再次使用 timm.create_model('hf-hub:my_org/my_fine_tuned_model', pretrained=True)
在 timm
中載入微調後的模型。
讓我們看看如何用 timm
載入我們微調過的模型 ariG23498/vit_base_patch16_224.augreg2_in21k_ft_in1k.ft_food101
checkpoint = "ariG23498/vit_base_patch16_224.augreg2_in21k_ft_in1k.ft_food101"
config = AutoConfig.from_pretrained(checkpoint)
model = timm.create_model(f"hf_hub:{checkpoint}", pretrained=True) # Load the model with timm
model = model.eval()
image = load_image("https://cdn.britannica.com/52/128652-050-14AD19CA/Maki-zushi.jpg")
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)
output = model(transforms(image).unsqueeze(0))
top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)
for prob, idx in zip(top5_probabilities[0], top5_class_indices[0]):
print(f"Label: {config.id2label[idx.item()] :20} Score: {prob/100 :0.2f}%")
輸出
Label: sushi Score: 0.98%
Label: spring_rolls Score: 0.01%
Label: sashimi Score: 0.00%
Label: club_sandwich Score: 0.00%
Label: cannoli Score: 0.00%
Torch Compile:即時加速
在 PyTorch 2.0 中使用 torch.compile
,您只需一行程式碼即可編譯模型,從而實現更快的推理速度。timm
整合完全相容 torch.compile
。以下是一個快速基準測試,用於比較使用 TimmWrapper
時有無 torch.compile
的推理時間。
# Load the model and input
model = TimmWrapperForImageClassification.from_pretrained(checkpoint).to(device)
processed_input = image_processor(image, return_tensors="pt").to(device)
# Benchmark function
def run_benchmark(model, input_data, warmup_runs=5, benchmark_runs=300):
# Warm-up phase
model.eval()
with torch.no_grad():
for _ in range(warmup_runs):
_ = model(**input_data)
# Benchmark phase
times = []
with torch.no_grad():
for _ in range(benchmark_runs):
start_time = time.perf_counter()
_ = model(**input_data)
if device.type == "cuda":
torch.cuda.synchronize(device=device) # Ensure synchronization for CUDA
times.append(time.perf_counter() - start_time)
avg_time = sum(times) / benchmark_runs
return avg_time
# Run benchmarks
time_no_compile = run_benchmark(model, processed_input)
compiled_model = torch.compile(model).to(device)
time_compile = run_benchmark(compiled_model, processed_input)
# Results
print(f"Without torch.compile: {time_no_compile:.4f} s")
print(f"With torch.compile: {time_compile:.4f} s")
總結
timm
與 transformers 的整合為利用最先進的視覺模型開闢了新的大門,且只需最少的努力。無論您是想進行微調、量化,還是僅僅執行推理,這種整合都提供了一個統一的 API 來簡化您的工作流程。
立即開始探索,解鎖計算機視覺的新可能!
致謝
我們要向在 Transformers PR #34564 中促成此次整合的各位同仁表示衷心的感謝。排名不分先後,衷心感謝 Pavel Iakubovskii、Ross Wightman、Lysandre Debut、Pablo Montalvo、Arthur Zucker 和 Amy Roberts 所做的傑出工作。你們的共同努力使這個想法變成了現實,讓每個人都能享受到這個功能!