TRL 文件

使用 SFT 微調多模態模型(單圖或多圖資料集)

Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

使用 SFT 微調多模態模型(單圖或多圖資料集)

VLM SFT training procedure

概覽

本指南將引導你完成使用監督式微調 (Supervised Fine-Tuning, SFT) 來微調多模態語言模型(例如 Gemma 3)的過程。我們涵蓋兩種情況:

  • 單張圖片 + 文字
  • 多張圖片 + 文字

本指南是對現有 VLM SFT 指令碼詳細解讀和補充。如果你已熟悉這些概念,可以直接使用該指令碼。

我們使用兩個資料集來演示微調過程,但這些原則同樣適用於其他視覺語言模型 (Vision-Language Models, VLMs) 和資料集。

理解資料集

為了同時處理單張圖片 + 文字多張圖片 + 文字這兩種場景,我們使用了兩個非常適合此任務的資料集。

HuggingFaceH4/llava-instruct-mix-vsft 資料集(圖片 + 文字)

此資料集是 LLaVA Instruct Mix 的重新格式化版本。它由對話組成,其中使用者同時提供文字單張圖片作為輸入。

模型(被稱為“助手”)會根據使用者分享的視覺和文字資訊進行回應。該資料集對於訓練多模態模型以理解並生成基於圖片和文字的響應特別有用。

FanqingM/MMIU-Benchmark 資料集(多張圖片 + 文字)

FanqingM/MMIU-Benchmark 資料集包含:

  • 上下文:包含在系統提示中。
  • 問題:作為使用者輸入的一部分提供。
  • 一系列圖片:與問題相關的多張圖片。
  • 答案:模型的預期響應。

此資料集專為需要模型對多張圖片進行推理,並根據視覺和文字輸入生成明智回應的任務而設計。

為多模態 SFT 開發微調指令碼

在本節中,我們將構建一個用於微調多模態模型的指令碼,該指令碼適用於單張圖片 + 文字多張圖片 + 文字兩種用例。

設定環境

在微調之前,我們需要安裝所需的依賴項。讓我們從設定環境開始:

# Install the required libraries. Further details: https://huggingface.co/docs/trl/installation 
pip install -U -q trl bitsandbytes peft hf_xet tensorboard

所有依賴項安裝完畢後,我們需要登入到 Hugging Face Hub。由於 Gemma 3 是一個受限模型,因此需要訪問許可權。

如果你尚未申請訪問許可權,請訪問模型卡片並提交申請。

要登入,你需要從你的 Hugging Face 賬戶生成一個訪問令牌

huggingface-cli login

載入資料

如前所述,我們將涵蓋兩種可能的用例。雖然具體流程可能因資料集而異,但核心原則保持一致。

本指南支援兩種用例,請根據你的具體場景參考單張圖片 + 文字多張圖片 + 文字部分。

單張圖片 + 文字

Single Image + Text

在這種情況下,批次中的每個樣本都包含一張圖片與文字配對。由於資料集已經格式化為監督式微調 (SFT) 格式,我們可以直接使用 load_dataset 載入它。

from datasets import load_dataset

dataset_name = "HuggingFaceH4/llava-instruct-mix-vsft"

# Load Dataset
dataset = load_dataset(dataset_name)

多張圖片 + 文字(或交錯)

Multi-Image + Text

Gemma 3 也支援多張圖片 + 文字的場景,其中:

  • 模型接收一個圖片列表以及一條使用者訊息。
  • 模型處理對話中交錯的圖片和文字

對於這個資料集,在訓練前需要進行一些預處理。

from datasets import load_dataset

dataset_name = "FanqingM/MMIU-Benchmark"

# Load Dataset
dataset = load_dataset(dataset_name)

載入資料集後,我們需要將其預處理並格式化為對話結構。以下是資料可能的樣子示例:

{"role": "system", "content": [{"type": "text", "text": "You are a judge in a photography competition, and now you are given the four images. Please examine the details and tell which one of them is most likely to be a real photograph.\nSelect from the following choices.\nA: the first image\nB: the second image\nC: the third image\nD: the fourth image"}]},
{"role": "user", "content": images_list + [{"type": "text", "text": "Which image is most likely to be a real photograph?"}]},
{"role": "assistant", "content": [{"type": "text", "text": "A: the first image\nB: the second image\nC: the third image\nD: the fourth image"}]},

這裡,images_list 是一個圖片列表。

images_list = [
  {"type": "image", "image": <class 'PIL.Image.Image'>},
  {"type": "image", "image": <class 'PIL.Image.Image'>},
  {"type": "image", "image": <class 'PIL.Image.Image'>},
  {"type": "image", "image": <class 'PIL.Image.Image'>},
  {"type": "image", "image": <class 'PIL.Image.Image'>},
]

這種結構可以像這樣轉換成程式碼:

import os
import zipfile
import io
from datasets import DatasetDict
from huggingface_hub import hf_hub_download, list_repo_files
from PIL import Image

dataset_train_split = "test"

def format_data(samples: dict[str, any]) -> dict[str, list]:
    formatted_samples = {"messages": []}
    for cont in range(len(samples["question"])):
        images = []
        for img_path in samples["input_image_path"][cont]:
            try:
                with open(img_path, "rb") as f:
                    img_bytes = f.read()
                image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
                images.append({"type": "image", "image": image})
            except Exception as e:
                print(f"Error processing image {img_path}: {e}")
                continue

        formatted_samples["messages"].append(
            [
                {"role": "system", "content": [{"type": "text", "text": samples["context"][cont]}]},
                {"role": "user", "content": images + [{"type": "text", "text": samples["question"][cont]}]},
                {"role": "assistant", "content": [{"type": "text", "text": samples["output"][cont]}]},
            ]
        )
    return formatted_samples

# For multi-image example
def prepare_dataset(dataset: DatasetDict, dataset_name: str, dataset_train_split: str) -> DatasetDict:
    all_files = list_repo_files(dataset_name, repo_type="dataset")
    zip_files = [f for f in all_files if f.endswith(".zip")]

    for zip_filename in zip_files:
        zip_path = hf_hub_download(repo_id=dataset_name, filename=zip_filename, repo_type="dataset")
        extract_folder = zip_filename.replace(".zip", "")
        os.makedirs(extract_folder, exist_ok=True)

        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(extract_folder)

    dataset = dataset.map(format_data, batched=True, batch_size=4, num_proc=16)
    return dataset

dataset = prepare_dataset(dataset, dataset_name, dataset_train_split)

至此,你的多張圖片 + 文字資料集已經準備好用於訓練了。

準備訓練

我們首先載入模型和處理器。在本例中,我們使用 google/gemma-3-4b-it,但同樣的過程也適用於其其他變體和類似模型。

為了最佳化記憶體使用,我們配置 BitsAndBytes 來載入模型的量化版本。

import torch
from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig

model_id = "google/gemma-3-4b-it"

# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_storage=torch.bfloat16,
)

# Load model and tokenizer
model = AutoModelForImageTextToText.from_pretrained(
    model_id, 
    device_map="auto", 
    torch_dtype=torch.bfloat16,
    attn_implementation="eager", # Important (Ref: https://github.com/huggingface/transformers/blob/c15a7adb283fa984a40558c7fe7bed30ae975cdd/src/transformers/models/gemma3/modeling_gemma3.py#L934)
    quantization_config=bnb_config
)
processor = AutoProcessor.from_pretrained(model_id)
processor.tokenizer.padding_side = "right"

接下來,我們設定量化低秩適配 (Quantized Low-Rank Adaptation, QLoRA),這是一種針對大型語言模型 (LLMs) 和視覺語言模型 (VLMs) 的高效微調技術。

from peft import LoraConfig, get_peft_model

# Configure QLoRA
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=16,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    modules_to_save=[
        "lm_head",
        "embed_tokens",
    ],
)

QLoRA 設定完成後,我們需要為 SFT 定義訓練引數。SFTConfig 類簡化了這一過程,提供了一種根據我們的具體需求輕鬆調整引數的方法。

from trl import SFTConfig

training_args = SFTConfig(
    output_dir="gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft",     # Directory to save the model and push to the Hub. Use a specific repository id (e.g., gemma-3-4b-it-trl-sft-MMIU-Benchmark for multi-image datasets).
    num_train_epochs=1,                                             # Set the number of epochs to train the model.
    per_device_train_batch_size=8,                                  # Batch size for each device (e.g., GPU) during training. multi-image -> per_device_train_batch_size=1
    gradient_accumulation_steps=4,                                  # Number of steps before performing a backward/update pass to accumulate gradients. multi-image -> gradient_accumulation_steps=1
    gradient_checkpointing=True,                                    # Enable gradient checkpointing to reduce memory usage during training.
    optim="adamw_torch_fused",                                      # Use the fused AdamW optimizer for better performance.
    save_strategy="epoch",                                          # Save checkpoints at the end of each epoch.
    learning_rate=2e-05,                                            # Learning rate for training.
    bf16=True,                                                      # Enable bfloat16 precision for training to save memory and speed up computations.
    push_to_hub=True,                                               # Automatically push the fine-tuned model to Hugging Face Hub after training.
    report_to="tensorboard",                                        # Automatically report metrics to tensorboard.
    gradient_checkpointing_kwargs={"use_reentrant": False},         # Set gradient checkpointing to non-reentrant to avoid issues.
    dataset_kwargs={"skip_prepare_dataset": True},                  # Skip dataset preparation to handle preprocessing manually.
    remove_unused_columns=False,                                    # Ensure unused columns are not removed in the collator (important for batch processing).
)

collate_fn 負責處理和準備單個樣本以形成一個批次。

批次中的每個樣本都會經歷以下步驟:

  1. 聊天模板應用於文字。
  2. 處理器textsimages 進行分詞,將它們編碼成張量。
  3. 用於訓練的標籤被設定為樣本的 input_ids
  4. 在損失計算過程中,某些特殊標記掩碼(忽略)
    • pad_token_id
    • <image_token_id>
    • <image_soft_token>(對應 ID 262144

這個過程在不同型別的資料集中是相似的,只是在處理圖片的方式上略有不同:

  • 單張圖片 + 文字 → 一個圖片列表被直接處理。
  • 多張圖片 + 文字 → 使用一個由圖片列表組成的列表,其中每個批次元素包含多張圖片。
from PIL import Image

# For multi-image cases
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
    image_inputs = []
    for msg in messages:
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]

        for element in content:
            if isinstance(element, dict) and ("image" in element or element.get("type") == "image"):
                if "image" in element:
                    image = element["image"]
                else:
                    image = element
                if image is not None:
                    image = Image.open(io.BytesIO(image["bytes"]))
                    image_inputs.append(image.convert("RGB"))
    return image_inputs

def collate_fn(examples):
    texts = [processor.apply_chat_template(example["messages"], tokenize=False, add_generation_prompt=False).strip() for example in examples]
    if "images" in examples[0]:  # single-image
        images = [
            [img.convert("RGB") for img in example["images"]]
            for example in examples
        ]
    else:  # multi-image
        images = [process_vision_info(example["messages"]) for example in examples]

    # Tokenize the texts and process the images
    batch = processor(
        images=images, text=texts, return_tensors="pt", padding=True
    )  # Encode texts and images into tensors

    # The labels are the input_ids, and we mask the padding tokens in the loss computation
    labels = batch["input_ids"].clone()  # Clone input IDs for labels
    # Mask image tokens
    image_token_id = [
        processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.special_tokens_map["boi_token"])
    ]
    # Mask tokens for not being used in the loss computation
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == image_token_id] = -100
    labels[labels == 262144] = -100

    batch["labels"] = labels
    return batch  # Return the prepared batch

訓練模型

所有元件都設定好後,我們現在使用先前定義的設定來配置 SFTTrainer,並開始訓練過程。

# Training
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=dataset["train"], # multi-image -> train_dataset=dataset["test"],
    processing_class=processor,
    peft_config=peft_config,
)

trainer.train()

# Save the final model
trainer.save_model()

我們將微調後的模型儲存到 Hub,使其易於將來使用。此外,TRL 會根據所選配置,自動將訓練結果記錄到 Weights & Biases (Wandb)TensorBoard

結果

在訓練期間和之後,我們可以使用 Weights & Biases (Wandb)TensorBoard 來檢查結果。例如:

侷限性

目前,微調 Gemma 存在一些已知的侷限性。我們建議遵循本指南中概述的步驟以確保最佳結果。

參考文獻

如需進一步閱讀和補充資源,請檢視以下內容:

< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.