Transformers 文件

計算機視覺知識蒸餾

Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

計算機視覺知識蒸餾

知識蒸餾是一種將知識從大型複雜模型(教師模型)轉移到小型簡單模型(學生模型)的技術。為了將知識從一個模型蒸餾到另一個模型,我們首先使用一個預訓練的教師模型,該模型已針對特定任務(本例中為影像分類)進行了訓練,然後隨機初始化一個學生模型,用於影像分類訓練。接下來,我們訓練學生模型,使其輸出與教師模型輸出之間的差異最小化,從而使其模仿教師模型的行為。它最初由 Hinton 等人在 Distilling the Knowledge in a Neural Network 中提出。在本指南中,我們將進行特定任務的知識蒸餾。我們將使用 beans 資料集 進行此操作。

本指南演示瞭如何使用 🤗 Transformers 的 Trainer API微調的 ViT 模型(教師模型)蒸餾到 MobileNet(學生模型)。

讓我們安裝蒸餾和評估過程所需的庫。

pip install transformers datasets accelerate tensorboard evaluate --upgrade

在本示例中,我們使用 `merve/beans-vit-224` 模型作為教師模型。它是一個影像分類模型,基於 `google/vit-base-patch16-224-in21k` 在 beans 資料集上進行了微調。我們將把這個模型蒸餾到一個隨機初始化的 MobileNetV2。

現在我們來載入資料集。

from datasets import load_dataset

dataset = load_dataset("beans")

我們可以使用其中一個模型的影像處理器,因為在這種情況下它們返回相同解析度的相同輸出。我們將使用 `dataset` 的 `map()` 方法將預處理應用於資料集的每個拆分。

from transformers import AutoImageProcessor
teacher_processor = AutoImageProcessor.from_pretrained("merve/beans-vit-224")

def process(examples):
    processed_inputs = teacher_processor(examples["image"])
    return processed_inputs

processed_datasets = dataset.map(process, batched=True)

本質上,我們希望學生模型(一個隨機初始化的 MobileNet)模仿教師模型(微調後的視覺 Transformer)。為了實現這一點,我們首先從教師模型和學生模型中獲取 logits 輸出。然後,我們將它們各自除以引數 `temperature`,該引數控制每個軟目標的重要性。一個名為 `lambda` 的引數衡量蒸餾損失的重要性。在本示例中,我們將使用 `temperature=5` 和 `lambda=0.5`。我們將使用 Kullback-Leibler 散度損失來計算學生模型和教師模型之間的散度。給定兩個資料 P 和 Q,KL 散度解釋了我們需要多少額外資訊來使用 Q 表示 P。如果兩者相同,它們的 KL 散度為零,因為不需要其他資訊來從 Q 解釋 P。因此,在知識蒸餾的背景下,KL 散度非常有用。

from transformers import TrainingArguments, Trainer
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate.test_utils.testing import get_backend

class ImageDistilTrainer(Trainer):
    def __init__(self, teacher_model=None, student_model=None, temperature=None, lambda_param=None,  *args, **kwargs):
        super().__init__(model=student_model, *args, **kwargs)
        self.teacher = teacher_model
        self.student = student_model
        self.loss_function = nn.KLDivLoss(reduction="batchmean")
        device, _, _ = get_backend() # automatically detects the underlying device type (CUDA, CPU, XPU, MPS, etc.)
        self.teacher.to(device)
        self.teacher.eval()
        self.temperature = temperature
        self.lambda_param = lambda_param

    def compute_loss(self, student, inputs, return_outputs=False):
        student_output = self.student(**inputs)

        with torch.no_grad():
          teacher_output = self.teacher(**inputs)

        # Compute soft targets for teacher and student
        soft_teacher = F.softmax(teacher_output.logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_output.logits / self.temperature, dim=-1)

        # Compute the loss
        distillation_loss = self.loss_function(soft_student, soft_teacher) * (self.temperature ** 2)

        # Compute the true label loss
        student_target_loss = student_output.loss

        # Calculate final loss
        loss = (1. - self.lambda_param) * student_target_loss + self.lambda_param * distillation_loss
        return (loss, student_output) if return_outputs else loss

現在我們將登入到 Hugging Face Hub,這樣我們就可以透過 `Trainer` 將模型推送到 Hugging Face Hub。

from huggingface_hub import notebook_login

notebook_login()

我們來設定 `TrainingArguments`、教師模型和學生模型。

from transformers import AutoModelForImageClassification, MobileNetV2Config, MobileNetV2ForImageClassification

training_args = TrainingArguments(
    output_dir="my-awesome-model",
    num_train_epochs=30,
    fp16=True,
    logging_dir=f"{repo_name}/logs",
    logging_strategy="epoch",
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    report_to="tensorboard",
    push_to_hub=True,
    hub_strategy="every_save",
    hub_model_id=repo_name,
    )

num_labels = len(processed_datasets["train"].features["labels"].names)

# initialize models
teacher_model = AutoModelForImageClassification.from_pretrained(
    "merve/beans-vit-224",
    num_labels=num_labels,
    ignore_mismatched_sizes=True
)

# training MobileNetV2 from scratch
student_config = MobileNetV2Config()
student_config.num_labels = num_labels
student_model = MobileNetV2ForImageClassification(student_config)

我們可以使用 `compute_metrics` 函式在測試集上評估我們的模型。此函式將在訓練過程中用於計算模型的 `accuracy` 和 `f1`。

import evaluate
import numpy as np

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    acc = accuracy.compute(references=labels, predictions=np.argmax(predictions, axis=1))
    return {"accuracy": acc["accuracy"]}

讓我們使用我們定義的訓練引數初始化 `Trainer`。我們還將初始化資料整理器。

from transformers import DefaultDataCollator

data_collator = DefaultDataCollator()
trainer = ImageDistilTrainer(
    student_model=student_model,
    teacher_model=teacher_model,
    training_args=training_args,
    train_dataset=processed_datasets["train"],
    eval_dataset=processed_datasets["validation"],
    data_collator=data_collator,
    processing_class=teacher_processor,
    compute_metrics=compute_metrics,
    temperature=5,
    lambda_param=0.5
)

我們現在可以訓練模型了。

trainer.train()

我們可以在測試集上評估模型。

trainer.evaluate(processed_datasets["test"])

在測試集上,我們的模型達到了 72% 的準確率。為了驗證蒸餾的效率,我們還使用相同的超引數從頭開始在 beans 資料集上訓練了 MobileNet,並在測試集上觀察到 63% 的準確率。我們邀請讀者嘗試不同的預訓練教師模型、學生架構、蒸餾引數並報告他們的發現。蒸餾模型的訓練日誌和檢查點可以在 此儲存庫 中找到,從頭開始訓練的 MobileNetV2 可以在 此儲存庫 中找到。

< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.