使用 🤗 Transformers 微調 ViT 進行影像分類

釋出於 2022 年 2 月 11 日
在 GitHub 上更新
Open In Colab

正如基於 transformer 的模型徹底改變了自然語言處理一樣,我們現在也看到大量論文將它們應用於各種其他領域。其中最具革命性之一的是 Vision Transformer (ViT),它於 2021 年 6 月由 Google Brain 的研究團隊推出。

這篇論文探討了如何像標記句子一樣標記影像,以便將它們傳遞給 transformer 模型進行訓練。這真的非常簡單……

  1. 將影像分割成子影像塊網格
  2. 使用線性投影嵌入每個影像塊
  3. 每個嵌入的影像塊都成為一個標記,由此產生的嵌入影像塊序列就是您傳遞給模型的序列。

事實證明,完成上述操作後,您可以像在自然語言處理任務中一樣預訓練和微調 transformer。非常棒 😎。


在這篇博文中,我們將逐步介紹如何利用 🤗 datasets 下載和處理影像分類資料集,然後使用它們透過 🤗 transformers 微調預訓練的 ViT。

首先,讓我們先安裝這兩個軟體包。

pip install datasets transformers

載入資料集

讓我們從載入一個小型影像分類資料集並檢視其結構開始。

我們將使用 beans 資料集,它是健康和不健康豆葉圖片的集合。🍃

from datasets import load_dataset

ds = load_dataset('beans')
ds

讓我們看看 beans 資料集 'train' 分割中的第 400 個示例。您會注意到資料集中的每個示例都有 3 個特徵

  1. image:PIL 影像
  2. image_file_path:作為 image 載入的影像檔案的 str 路徑
  3. labels:一個 datasets.ClassLabel 特徵,它是標籤的整數表示。(稍後您會看到如何獲取字串類名,別擔心!)
ex = ds['train'][400]
ex
{
  'image': <PIL.JpegImagePlugin ...>,
  'image_file_path': '/root/.cache/.../bean_rust_train.4.jpg',
  'labels': 1
}

讓我們看看影像 👀

image = ex['image']
image

這絕對是一片葉子!但是是什麼種類呢?😅

由於此資料集的 'labels' 特徵是 datasets.features.ClassLabel,我們可以使用它來查詢此示例標籤 ID 對應的名稱。

首先,讓我們訪問 'labels' 的特徵定義。

labels = ds['train'].features['labels']
labels
ClassLabel(num_classes=3, names=['angular_leaf_spot', 'bean_rust', 'healthy'], names_file=None, id=None)

現在,讓我們打印出我們示例的類標籤。您可以使用 ClassLabelint2str 函式來完成此操作,顧名思義,它允許傳遞類的整數表示來查詢字串標籤。

labels.int2str(ex['labels'])
'bean_rust'

原來上面顯示的葉子感染了豆鏽病,這是一種嚴重的豆類植物病害。😢

讓我們編寫一個函式,它將顯示每個類別的一些示例網格,以便更好地瞭解您正在處理的內容。

import random
from PIL import ImageDraw, ImageFont, Image

def show_examples(ds, seed: int = 1234, examples_per_class: int = 3, size=(350, 350)):

    w, h = size
    labels = ds['train'].features['labels'].names
    grid = Image.new('RGB', size=(examples_per_class * w, len(labels) * h))
    draw = ImageDraw.Draw(grid)
    font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationMono-Bold.ttf", 24)

    for label_id, label in enumerate(labels):

        # Filter the dataset by a single label, shuffle it, and grab a few samples
        ds_slice = ds['train'].filter(lambda ex: ex['labels'] == label_id).shuffle(seed).select(range(examples_per_class))

        # Plot this label's examples along a row
        for i, example in enumerate(ds_slice):
            image = example['image']
            idx = examples_per_class * label_id + i
            box = (idx % examples_per_class * w, idx // examples_per_class * h)
            grid.paste(image.resize(size), box=box)
            draw.text(box, label, (255, 255, 255), font=font)

    return grid

show_examples(ds, seed=random.randint(0, 1337), examples_per_class=3)
資料集中每個類別的一些示例網格

據我所知,

  • 角斑病:有不規則的棕色斑塊
  • 豆鏽病:有圓形棕色斑點,周圍有白色-黃色環
  • 健康:……看起來很健康。🤷‍♂️

載入 ViT 影像處理器

現在我們知道影像是什麼樣子以及我們正在努力解決的問題。讓我們看看如何為模型準備這些影像!

當 ViT 模型進行訓練時,會對其輸入的影像應用特定的轉換。如果對影像使用了錯誤的轉換,模型將無法理解它所看到的內容!🖼 ➡️ 🔢

為了確保我們應用正確的轉換,我們將使用一個 ViTImageProcessor,它使用我們計劃使用的預訓練模型儲存的配置進行初始化。在本例中,我們將使用 google/vit-base-patch16-224-in21k 模型,因此讓我們從 Hugging Face Hub 載入其影像處理器。

from transformers import ViTImageProcessor

model_name_or_path = 'google/vit-base-patch16-224-in21k'
processor = ViTImageProcessor.from_pretrained(model_name_or_path)

您可以透過列印影像處理器配置來檢視它。

ViTImageProcessor {
  "do_normalize": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "size": 224
}

要處理影像,只需將其傳遞給影像處理器的呼叫函式。這將返回一個包含 pixel values 的字典,這是要傳遞給模型的數字表示。

預設情況下您會得到一個 NumPy 陣列,但如果您新增 return_tensors='pt' 引數,您將得到 torch 張量。

processor(image, return_tensors='pt')

應該會得到類似以下內容...

{
  'pixel_values': tensor([[[[ 0.2706,  0.3255,  0.3804,  ...]]]])
}

...其中張量的形狀為 (1, 3, 224, 224)

處理資料集

現在您已經知道如何讀取影像並將其轉換為輸入,讓我們編寫一個函式,將這兩者結合起來,以處理資料集中的單個示例。

def process_example(example):
    inputs = processor(example['image'], return_tensors='pt')
    inputs['labels'] = example['labels']
    return inputs
process_example(ds['train'][0])
{
  'pixel_values': tensor([[[[-0.6157, -0.6000, -0.6078,  ..., ]]]]),
  'labels': 0
}

雖然您可以呼叫 ds.map 並將其一次性應用於每個示例,但這可能會非常慢,特別是如果您使用更大的資料集。相反,您可以對資料集應用一個 *轉換*。轉換僅在您索引示例時應用。

不過,首先,您需要更新最後一個函式以接受一批資料,因為這就是 ds.with_transform 所期望的。

ds = load_dataset('beans')

def transform(example_batch):
    # Take a list of PIL images and turn them to pixel values
    inputs = processor([x for x in example_batch['image']], return_tensors='pt')

    # Don't forget to include the labels!
    inputs['labels'] = example_batch['labels']
    return inputs

您可以使用 ds.with_transform(transform) 直接將其應用於資料集。

prepared_ds = ds.with_transform(transform)

現在,每當您從資料集中獲取一個示例時,轉換將即時應用(在樣本和切片上,如下所示)

prepared_ds['train'][0:2]

這次,得到的 pixel_values 張量形狀將是 (2, 3, 224, 224)

{
  'pixel_values': tensor([[[[-0.6157, -0.6000, -0.6078,  ..., ]]]]),
  'labels': [0, 0]
}

訓練和評估

資料已處理完畢,您已準備好開始設定訓練管道。這篇博文使用 🤗 的 Trainer,但這需要我們先做幾件事

  • 定義一個 collate 函式。

  • 定義一個評估指標。在訓練期間,模型應根據其預測準確性進行評估。您應相應地定義一個 compute_metrics 函式。

  • 載入預訓練的檢查點。您需要載入預訓練的檢查點並正確配置它以進行訓練。

  • 定義訓練配置。

微調模型後,您將正確評估評估資料上的模型,並驗證它確實學會了正確分類影像。

定義我們的資料整理器

批次以字典列表的形式傳入,因此您只需將它們解包並堆疊成批次張量即可。

由於 collate_fn 將返回一個批次字典,因此您可以稍後將輸入 **解包 到模型中。✨

import torch

def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }

定義評估指標

evaluate準確度 指標可以輕鬆用於比較預測與標籤。下面,您可以看到如何在 compute_metrics 函式中使用它,該函式將由 Trainer 使用。

import numpy as np
from evaluate import load

metric = load("accuracy")
def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)

讓我們載入預訓練模型。我們將在初始化時新增 num_labels,以便模型建立一個具有正確單元數量的分類頭。我們還將包含 id2labellabel2id 對映,以便在 Hub 小部件中擁有人類可讀的標籤(如果您選擇 push_to_hub)。

from transformers import ViTForImageClassification

labels = ds['train'].features['labels'].names

model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
)

差不多準備好訓練了!在此之前,最後需要做的是透過定義 TrainingArguments 來設定訓練配置。

其中大多數都非常直觀,但其中一個非常重要的是 remove_unused_columns=False。這個引數將刪除模型呼叫函式未使用的任何特徵。預設情況下它為 True,因為通常最好刪除未使用的特徵列,這樣可以更容易地將輸入解包到模型的呼叫函式中。但是,在我們的例子中,我們需要未使用的特徵(尤其是“image”)才能建立“pixel_values”。

我想說的是,如果你忘記設定 remove_unused_columns=False,你將會很糟糕。

from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir="./vit-base-beans",
  per_device_train_batch_size=16,
  evaluation_strategy="steps",
  num_train_epochs=4,
  fp16=True,
  save_steps=100,
  eval_steps=100,
  logging_steps=10,
  learning_rate=2e-4,
  save_total_limit=2,
  remove_unused_columns=False,
  push_to_hub=False,
  report_to='tensorboard',
  load_best_model_at_end=True,
)

現在,所有例項都可以傳遞給 Trainer,我們準備開始訓練了!

from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
    tokenizer=processor,
)

訓練 🚀

train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

評估 📊

metrics = trainer.evaluate(prepared_ds['validation'])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

這是我的評估結果——好棒的豆子!抱歉,我必須說出來。

***** eval metrics *****
  epoch                   =        4.0
  eval_accuracy           =      0.985
  eval_loss               =     0.0637
  eval_runtime            = 0:00:02.13
  eval_samples_per_second =     62.356
  eval_steps_per_second   =       7.97

最後,如果您願意,可以將模型推送到 Hub。在這裡,如果您在訓練配置中指定了 push_to_hub=True,我們將將其推送到 Hub。請注意,要推送到 Hub,您必須安裝 git-lfs 並登入您的 Hugging Face 帳戶(可以透過 huggingface-cli login 完成)。

kwargs = {
    "finetuned_from": model.config._name_or_path,
    "tasks": "image-classification",
    "dataset": 'beans',
    "tags": ['image-classification'],
}

if training_args.push_to_hub:
    trainer.push_to_hub('🍻 cheers', **kwargs)
else:
    trainer.create_model_card(**kwargs)

生成的模型已共享至 nateraw/vit-base-beans。我假設您手頭沒有豆葉圖片,所以我添加了一些示例供您嘗試!🚀

社群

註冊登入 發表評論

© . This site is unofficial and not affiliated with Hugging Face, Inc.