TRL 中的視覺語言模型對齊 ⚡️

釋出於 2025 年 8 月 7 日
在 GitHub 上更新

引言

視覺語言模型 (VLM) 越來越強大,但將其與人類偏好進行*對齊*仍然很重要。在 TRL 中,我們已經展示瞭如何透過有監督微調 (SFT)直接偏好最佳化 (DPO) 對 VLM 進行後期訓練。這一次,我們更進一步。

tl;dr 我們在 TRL 中增加了兩種新的多模態對齊方法:組相對策略最佳化 (GRPO)、其變體組序列策略最佳化 (GSPO)混合偏好最佳化 (MPO)。所有這些方法都允許您超越成對 DPO,從偏好資料中提取更多訊號,並更好地與現代 VLM 配合使用。我們釋出了訓練指令碼和演示筆記本,以便輕鬆上手!

目錄

視覺語言模型對齊

傳統上,你會使用一個基礎模型,應用 SFT 來遵循指令,然後應用 DPO 將其與偏好資料對齊。此前,我們已將此方法應用於視覺語言模型 (VLM) 並在 IDEFICS2 上進行了驗證,顯示模型響應有所改進。

DPO 透過使用對比損失最佳化模型響應對之間的偏好來工作:您有一個已選擇和已拒絕的答案,並根據您想要和不想要的內容最佳化您的偏好。

但在過去一年中,新的多模態對齊方法 GRPO 和 MPO 越來越受歡迎,它們可以進一步提升 VLM 效能。在部落格文章末尾,您可以找到一個表格,展示模型響應之間的差異。

混合偏好最佳化 (MPO)

使用 SFT 對齊多模態模型以執行推理任務會因分佈偏移而不足。同時,使用 DPO 對齊的模型無法生成連貫的推理,並且可能會生成重複的響應。為了解決這個問題,有一種專門為多模態模型設計的名為混合偏好最佳化 (MPO) 的新技術。該方法本質上是 DPO 的擴充套件,具有多個損失:來自 DPO 的偏好損失(Sigmoid)、來自二分類器最佳化 (BCO) 的質量損失和來自 SFT 的生成損失。根據論文,僅僅切換到這種組合損失就能在 MathVista 中將效能提高 6.2 分!

MPO

由於這隻修改了損失,我們為 TRL 的 DPOTrainer 類添加了組合損失支援。要使用它,您可以按如下方式初始化 DPOConfig

mpo_config = DPOConfig(
    output_dir=tmp_dir,
    per_device_train_batch_size=2,
    learning_rate=9e-1,
    loss_type=["sigmoid", "bco_pair", "sft"], # Loss types to combine, as used in the MPO paper
    loss_weights=[0.8, 0.2, 1.0], # Corresponding weights, as used in the MPO paper
    report_to="none",
    bf16=False,
    fp16=False,
    use_cpu=True,
    max_steps=1,
)

然後初始化 DPOTrainer

mpo_trainer = DPOTrainer(
    model=model_id,
    args=mpo_config,
    processing_class=tokenizer,
    train_dataset=dataset,
)
mpo_trainer.train()

就是這樣!如果您想進一步探索,可以在此處找到一個完整的筆記本示例。

多模態組相對策略最佳化 (GRPO)

組相對策略最佳化 (GRPO) 是一種尖端對齊方法,最初在DeepSeek Math 論文中引入,後來整合到開創性的 LLM DeepSeek R1 中。它是 PPO 的一個補充,其中策略更新在組(表示對話如何展開的軌跡批次)上完成。此功能使其對獎勵噪聲更加魯棒,因為噪聲在組內平均。由於模型學習的是對良好響應的更廣泛理解,而不是單一的高獎勵樣本,因此該方法也使模型具有高效能。

GRPO

在 TRL 中,我們現在為視覺語言模型引入了 GRPO 支援。我們不會提供完整的訓練指令碼示例,因為您可以在筆記本中找到它。相反,我們將重點突出關鍵元件和概念。

為了使訓練指令碼有效工作,我們需要驗證答案的格式是否正確以及解決方案本身是否接近已完成的部分,因此我們編寫了兩個獎勵函式。為了真正看到後一個獎勵的改進,您需要一個相當最大化的設定,即您擁有相對較大的模型、大量的生成以及高質量、多樣化的資料集。

import re
from math_verify import LatexExtractionConfig, parse, verify

def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
    matches = [re.match(pattern, content) for content in completions]
    rewards_list = [1.0 if match else 0.0 for match in matches]
    rewards = [1.0 if match else 0.0 for match in matches]
    print(completions)
    print(rewards)
    return rewards

def accuracy_reward(completions, **kwargs):
    """Reward function that checks if the completion is the same as the ground truth."""
    solutions = kwargs['solution']
    completion_contents = [completion[0]["content"] for completion in completions]
    rewards = []
    for content, solution in zip(completion_contents, solutions):
        gold_parsed = parse(solution, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
        answer_parsed = parse(content, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
        if len(gold_parsed) != 0:
            try:
                rewards.append(float(verify(answer_parsed, gold_parsed)))
            except Exception:
                rewards.append(0.0)
        else:
            rewards.append(1.0)
    return rewards

然後,您可以初始化 GRPOConfig 和 GRPOTrainer,傳入我們上面定義的獎勵函式,並呼叫 train() 開始訓練。

from trl import GRPOConfig

training_args = GRPOConfig(
    learning_rate=1e-5,
    remove_unused_columns=False,
    max_prompt_length=None,
    .. # setup other params of choice here
)
trainer = GRPOTrainer(
    model=model,
    reward_funcs=[format_reward, accuracy_reward],
    args=training_args,
    train_dataset=train_dataset,
    processing_class=processor
)
trainer.train()

此處探索完整的筆記本示例。

組序列策略最佳化 (GSPO)

組序列策略最佳化 (GSPO) 是 Qwen 最近釋出的一種 RL 對齊演算法,它克服了 GRPO 的一些侷限性。它透過在序列級別而不是每標記計算重要性取樣權重來實現更穩定的訓練。它的優點在 MoE 風格模型中更具相關性

最新版本的 TRL 也引入了對 GSPO 的支援,由於它是 GRPO 損失的一種變體,因此它支援多模態。建立訓練器的方法與 GRPO 相同,但需要新增以下額外引數(值從論文中提取)。

from trl import GRPOConfig

training_args = GRPOConfig(
        ...
    importance_sampling_level="sequence",
    epsilon=3e-4,
    epsilon_high=4e-4,
    beta=0.0,
    loss_type="grpo",
    steps_per_generation=1,
    steps_per_generation=4
)

比較

下表總結了使用上述技術微調後的 Qwen2.5VL-3B 模型輸出。請注意,我們僅對資料集子集進行了少量執行,並且模型在不同資料集上進行了微調,因此此比較僅供參考。

在此處檢視答案

圖片: 圖片

問題: 在給定的圖中,如果角 APO 為 25.0 度,則角 AOB 的度數是多少?選項:A:140° B:130° C:120° D:110°

資料集中答案: 為了求角 AOB 的度數,我們可以利用四邊形內角和為 360 度的性質。應用此性質,我們可以計算角 AOB 為 360 - 角 OAP - 角 OBP - 角 P。由於角 OAP 和角 OBP 均為 90 度,且角 P 已知為 50 度,我們將這些值代入方程得到角 AOB = 360 - 90 - 90 - 50 = 130 度。因此,角 AOB 的度數為 130 度。因此,答案是選項 B。答案:B

基礎模型輸出: 為了找到角 AOB 的度數,我們需要使用圓中同一弧所對的角屬性。具體來說,弧在圓心所對的角(角 AOB)是同一弧在圓周上任意一點所對的角(角 APO)的兩倍。已知角 APO = 25°,我們可以如下計算角 AOB:由於角 AOB 是角 APO 的兩倍,我們有:[ 角 AOB = 2 * 角 APO ][ 角 AOB = 2 * 25° ][ 角 AOB = 50° ] 然而,根據提供的選項,此計算似乎不正確。讓我們使用正確的屬性重新評估。正確的屬性指出,弧在圓心所對的角是同一弧在圓周上任意一點所對的角的兩倍。因此,正確的計算應該是:角 AOB = 2 * 角 APO = 2 * 25° = 50° 但由於這與給定的選項不符,讓我們考慮另一種方法。如果角 APO 是 25°,並且考慮到圓內接四邊形的屬性或圓內接四邊形對角和為 180° 的事實,我們可能需要檢視弧和絃形成的補角。假設正確的解釋涉及找到給定角 APO 的補角,這將給出角 AOB 的度數。角 AOB = 180° - 角 APO = 180° - 25° = 155° 這仍然與選項不符。鑑於提供的選項,似乎可能存在誤解或需要不同的解釋。然而,根據標準幾何原理,正確答案應該是:角 AOB = 2 * 角 APO = 2 * 25° = 50° 如果情況並非如此,請提供額外背景資訊或澄清問題陳述。

MPO 輸出: 為了找到角 AOB 的度數,我們需要使用圓中同一弧所對的角屬性。具體來說,弧在圓心所對的角(角 AOB)是同一弧在圓周上任意一點所對的角(角 APO)的兩倍。已知角 APO = 25°,我們可以如下計算角 AOB:由於角 AOB 是角 APO 的兩倍,我們有:[ 角 AOB = 2 * 角 APO ][ 角 AOB = 2 * 25° ][ 角 AOB = 50° ] 然而,根據提供的選項,此計算似乎不正確。讓我們使用正確的屬性重新評估。正確的屬性指出,弧在圓心所對的角是同一弧在圓周上任意一點所對的角的兩倍。因此,正確的計算應該是:角 AOB = 2 * 角 APO = 2 * 25° = 50° 但由於這與給定的選項不符,讓我們考慮另一種方法。如果我們看三角形 OAP,角 AOP 是 90°,因為它是一個直角(因為它內接於半圓)。現在,角 APO 給定為 25°,所以角 AOA 將是 65°(因為它與角 APO 互補)。因此,角 AOB 將是其一半,即 32.5°。然而,這也不符合選項。讓我們考慮問題可能涉及不同的解釋或不同的幾何配置的可能性。鑑於這些選項,我們假設問題陳述或選項本身可能存在錯誤。如果我們要從給定選項中選擇,根據此類問題的標準幾何原理,似乎沒有一個選項是正確的。

GRPO 輸出:為了求角 AOB 的度數,我們需要理解角 AOB 是一個內接角,它擷取與圓心角 AOP 相同的弧。根據內接角定理,內接角的度數是擷取相同弧的圓心角度數的一半。已知:- 角 APO = 25.0 度 由於角 AOB 是一個內接角,它擷取與圓心角 AOP 相同的弧,我們可以使用內接角定理來求角 AOB 的度數:角 AOB 的度數是角 AOP 度數的一半。因此,角 AOB 的度數為 130 度。

GSPO 輸出: 為了求角 AOB 的度數,我們需要理解角 AOB 是一個內接角,它擷取與圓心角 AOP 相同的弧。根據內接角定理,內接角的度數是擷取相同弧的圓心角度數的一半。已知:- 角 APO = 25.0 度 由於角 AOB 是一個內接角,它擷取與圓心角 AOP 相同的弧,我們可以使用內接角定理來求角 AOB 的度數:角 AOB 的度數是角 AOP 度數的一半。因此,答案是 B:130°。

TRL 中的 vLLM 整合

vLLM 已整合到 TRL 中,以支援需要在訓練期間生成樣本的線上對齊方法。執行以下示例指令碼即可啟用 vLLM

CUDA_VISIBLE_DEVICES=1,2 python3 examples/scripts/grpo_vlm.py     --model_name_or_path   Qwen/Qwen2.5-VL-3B-Instruct    …   --log_completions --use_vllm --vllm_mode colocate

主要有兩種模式:colocateservercolocate 在與訓練迴圈相同的程序中執行 vLLM,在訓練和生成之間共享同一 GPU,在 GRPOTrainer 中建立一個 vLLM LLM 例項。而 server 則要求您在不同的程序中單獨執行 vLLM,您可以在其中訪問伺服器。您可以使用以下命令啟動此伺服器

trl vllm-serve --model Qwen/Qwen2.5-VL-3B-Instruct --tensor-parallel-size 1 

然後您可以按如下方式執行指令碼。

CUDA_VISIBLE_DEVICES=1,2 python3 examples/scripts/grpo_vlm.py     --model_name_or_path   Qwen/Qwen2.5-VL-3B-Instruct    …   --log_completions --use_vllm --vllm_mode server

另一個提示:我們已添加了在 TRL 中使用 transformers 後端與 vLLM 的支援。您可以在使用 colocate 執行指令碼或提供模型時透過傳遞 --vllm_model_impl transformers 標誌來啟用它。

您可以在此處閱讀有關 TRL 中 vLLM 整合的更多資訊。

有用資源

以下是探索 VLM 對齊的詳細資源彙編。祝您閱讀愉快!

社群

註冊登入發表評論

© . This site is unofficial and not affiliated with Hugging Face, Inc.