不讓任何 GPU 掉隊:TRL 中 vLLM 協同部署以釋放效率

釋出日期:2025 年 6 月 3 日
在 GitHub 上更新

🚀 引言

TRL 支援使用 GRPO 訓練大型語言模型(LLM),GRPO 是一種最近在《DeepSeekMath 論文》中引入的線上學習演算法。在 GRPO 中,模型從其自身輸出中學習:它在訓練期間生成響應,接收反饋,並利用該反饋隨時間推移改進自身。

這使得生成成為訓練迴圈中的關鍵步驟,也是一個主要的瓶頸。為了加快生成速度,TRL 與 vLLM 整合。這種組合允許您在 GRPO 設定中更高效地訓練強大的模型。然而,這裡有一個問題。

🧨 問題所在

在 TRL v0.18.0 之前,vLLM 僅支援**伺服器模式**,作為獨立程序在與訓練作業不同的 GPU 上執行。它透過 HTTP 與訓練指令碼通訊,這使得設定模組化且易於使用——但也引入了 GPU 效率低下問題。

這是發生的情況:

  • 在訓練期間,模型需要頻繁生成補全。
  • 訓練器向 vLLM 伺服器傳送請求,該伺服器執行在自己的 GPU 上。
  • 當 vLLM 生成時,**訓練 GPU 處於空閒狀態**並等待。
  • 一旦生成完成,**vLLM GPU 變為空閒狀態**,訓練恢復。

訓練和生成之間的這種“乒乓效應”導致:

  • 雙方 GPU 時間浪費
  • **額外 GPU** 需求增加,僅用於執行推理
  • 整體**吞吐量降低,成本更高**

在像 GRPO 這樣的線上學習方法中——生成持續發生——這種低效率變得更加令人痛苦。您在硬體上花費更多,但卻無法獲得預期的效能。

**因此,關鍵問題是:**_我們能否將訓練和生成共享同一批 GPU,而不是將它們分開?_

💡 機遇

主要問題在於訓練和推理在不同的 GPU 上執行,導致空閒時間和資源利用不足。自然而然的解決方案是:兩者都在同一批 GPU 上執行。如果 vLLM 可以與訓練程式碼一起執行,在同一個分散式程序組中,而不是作為獨立伺服器在自己的程序和裝置中執行呢?這將允許我們啟動一個單一的分散式作業,其中訓練和推理共享相同的裝置,在任務之間高效切換而不會浪費資源。

這種方法我們稱之為**協同部署**。訓練和推理協同部署在同一批 GPU 上,並透過相同的程序組進行協調,允許它們平穩地輪流執行——無需額外的硬體。

以前,這在 TRL 中是不可能的,它依賴於 vLLM 作為外部 HTTP 伺服器。透過我們的 PR #3394,這種情況發生了改變,該 PR 添加了對 vLLM 外部啟動器和與訓練過程的真正整合的支援。

它能實現什麼

  • 統一執行:透過將 vLLM 嵌入到同一個程序組中,訓練和推理任務可以共享相同的 GPU,輪流執行而不是相互等待。這減少了空閒時間並提高了整體效率。

  • 跳過 HTTP 通訊:無需 REST API 呼叫或網路通訊——vLLM 與訓練迴圈內聯執行,避免了開銷和延遲。

  • Torchrun 相容性:與 torchrun 無縫協作,因此易於透過最少的配置更改進行跨節點擴充套件。

  • TP 和 DP 支援:與張量並行 (Tensor Parallelism) 和資料並行 (Data Parallelism) 相容,使其適用於大規模訓練執行。

  • SPMD 執行模式:使用單程式多資料 (SPMD) 模型,其中每個 GPU 同步執行其自身的引擎例項。適用於分散式多 GPU、多節點設定。

  • 簡化部署:您不再需要維護單獨的伺服器指令碼——vLLM 直接在您的訓練作業中啟動和控制。

  • 提高吞吐量:透過避免 GPU 空閒和消除程序間通訊,系統提供更快的訓練和生成速度,這在 GRPO 等線上學習設定中尤為重要。

  • 健壯的程序間通訊:這更健壯,因為它避免了像伺服器模式中那樣在獨立程序之間設定分散式程序組的複雜性。

得益於此功能,協同訓練和推理不再是權宜之計——它現在是**一流的、可擴充套件的、生產就緒的**。

🧩 設計:從獨立伺服器到共享 GPU

從伺服器 TRL 到協同部署 TRL 的轉變完全是為了更智慧地利用 GPU。下圖顯示了差異:

gpus-design

伺服器 TRL 設定(上排)

在伺服器 TRL 設定中,訓練和推理在不同的 GPU 上執行。例如:

  • GPU 0 到 2 用於訓練。
  • GPU 3 完全用於執行 vLLM 作為獨立伺服器。

在訓練步驟中,**GPU 3 處於空閒狀態**。在生成步驟(推理)中,當 GPU 3 生成輸出時,**GPU 0-2 處於空閒狀態**。

這導致:

  • GPU 使用效率低下,裝置經常相互等待
  • 額外配置 GPU 僅用於推理
  • 增加成本和複雜性

協同部署 TRL 設定(下排)

相反,協同部署 TRL 設定在**相同的 GPU** 上執行訓練和 vLLM。每個 GPU:

  • 執行訓練迴圈
  • 在**同一個程序**中啟動 vLLM 引擎

訓練和推理**輪流**使用 GPU 的資源——無需專用裝置或獨立程序。

此設計:

  • 減少空閒時間
  • 最小化程序間和 HTTP 通訊
  • 充分利用可用的 GPU 記憶體和計算資源
  • 在不增加硬體需求的情況下提供**更快的吞吐量**

🛠️ 實施說明

現在,訓練器啟動 vLLM **程序內**,使用外部啟動器,而不是將 vLLM 作為伺服器啟動,如下圖所示:

self.llm = LLM(
    model=model.name_or_path,
    tensor_parallel_size=args.vllm_tensor_parallel_size,
    gpu_memory_utilization=self.vllm_gpu_memory_utilization,
    max_num_seqs=self.args.per_device_train_batch_size
        * self.vllm_tensor_parallel_size
        * self.args.gradient_accumulation_steps,
    max_model_len=self.max_prompt_length + self.max_completion_length,
    distributed_executor_backend="external_launcher",
    # Feed identical seed for tp groups to ensure sampling results are the same across workers
    seed=self.accelerator.process_index // self.vllm_tensor_parallel_size,
)

協同部署的 vLLM 遵循 `torch.distributed` 程序組和秩結構。這使得 vLLM 可以在訓練的同時初始化而不會發生衝突,並使 TP/DP 設定無縫執行。

if self.vllm_tensor_parallel_size > 1:
    # Create subgroups of ranks for TP, each group with `vllm_tensor_parallel_size` ranks.
    self.tp_group, _ = torch.distributed.new_subgroups_by_enumeration(
        [
            list(range(i * self.vllm_tensor_parallel_size, (i + 1) * self.vllm_tensor_parallel_size))
            for i in range(self.accelerator.num_processes // self.vllm_tensor_parallel_size)
        ]
    )

協同部署的 vLLM 不再依賴 REST API——它直接在記憶體中執行並透過原生 Python 呼叫進行通訊。

if self.vllm_tensor_parallel_size > 1:
    orig_size = len(prompts_text)
    gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)]
    torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group)
    all_prompts_text = [p for sublist in gathered_prompts for p in sublist]
else:
    all_prompts_text = prompts_text

with profiling_context(self, "vLLM.generate"):
    all_outputs = self.llm.generate(all_prompts_text, sampling_params=sampling_params, use_tqdm=False)

completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]

if self.vllm_tensor_parallel_size > 1:
    local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
    tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
    completion_ids = completion_ids[tp_slice]

要使用此設定,只需在 GRPO 配置中將 `vllm_mode="colocate"`。

training_args = GRPOConfig(
    ...,
    use_vllm=True,
    vllm_mode="colocate",
)

注意:根據模型大小和訓練所需的總 GPU 記憶體,您可能需要調整 `GRPOConfig` 中的 `vllm_gpu_memory_utilization` 引數,以避免資源利用不足或記憶體不足錯誤。

📊 展示:協同部署與普通 TRL 效能對比

為了衡量協同部署的影響,我們進行了一系列實驗,比較了傳統的**伺服器模式**(vLLM 作為獨立伺服器在單獨的 GPU 上執行)與新的**協同部署模式**(訓練和推理共享相同的 GPU)。

在**伺服器模式**下,僅使用 7 個 GPU 進行訓練,因為 1 個 GPU 完全專用於 vLLM 推理伺服器。

在**協同部署模式**下,所有 8 個 GPU 都用於訓練——預設情況下增加了有效批次大小。

為了確保公平比較,我們**將伺服器模式下的吞吐量標準化為 8/7**。此調整考慮了協同部署模式下更大的訓練容量,並允許我們在相同的訓練條件下比較兩種設定。

實驗 1:1.5B 模型 — 不同批次大小

  • 隨著批次大小的增加,兩種設定的吞吐量都有所改善。
  • **協同部署設定在最大批次大小下達到 1.43 倍加速。**
  • 更大的批次可以更好地利用協同部署模式下共享的 GPU 記憶體。small-b

實驗 2:1.5B 模型 — 不同張量並行度 (TP)

  • 在協同部署設定中,增加 TP 會**降低效能**。
  • 更多的分片會引入更多的通訊開銷——這**不適合小型模型**。
  • **啟示**:對於小型模型,在協同部署模式下避免過度分片。small-tp

實驗 3:7B 模型 — 不同批次大小

  • 同樣,協同部署模式**隨著批次大小的增加而擴充套件性更好**。
  • 在測試的最大批次下,增益達到**1.35 倍加速**。med-b

實驗 4:7B 模型 — 不同張量並行度 (TP)

  • 與 1.5B 模型相反的趨勢。
  • 對於 7B 模型,**增加 TP 可以提高吞吐量**,最高可達**1.73 倍加速**。
  • **協同部署設定中,大型模型從分片中受益。** med-tp

📊 擴充套件到 72B 模型

在訓練像 **Qwen2.5-Math-72B** 這樣的大型模型時,採用正確的策略以確保在多 GPU 和多節點上實現高效、可擴充套件和穩定的訓練至關重要。在我們的設定中,我們將**協同部署的 vLLM** 與多個關鍵最佳化相結合,以實現高效執行。

vLLM 中的休眠模式

在使用協同訓練時,管理 GPU 記憶體至關重要,以便訓練和推理都能在同一裝置上平穩執行。為支援此功能,我們已將 vLLM 的 `sleep()` API 新增到 GRPO 訓練迴圈中。

`sleep()` 函式暫時暫停 vLLM 引擎並釋放 GPU 記憶體。它支援兩個級別:

  • **級別 1**:從 GPU 解除安裝模型權重(保留在 CPU 記憶體中)並清除 KV 快取。當同一模型即將被重複使用時很有用。

  • **級別 2**:完全解除安裝模型權重和 KV 快取。最適合模型將更改或不會立即重複使用的情況。

在 GRPO 中,模型在每一步之後都會更新——因此我們使用**級別 2 休眠**。

級別 2 休眠的優勢:

  • **最大化訓練的空閒 GPU 記憶體**
  • **避免訓練和生成之間的記憶體爭用**
  • 即使對於像 Qwen2.5-72B 這樣的大型模型,也能保持協同部署的效率

這個小小的改動在實現平穩、可擴充套件的協同訓練方面發揮了**巨大作用**。

DeepSpeed 最佳化

為了訓練像 Qwen2.5-72B 這樣的大型模型,我們依賴於 **DeepSpeed ZeRO Stage 3**,這與普通 TRL 中使用的設定相同。

ZeRO 透過在 GPU 之間分配記憶體來幫助擴充套件大型模型。Stage 3 更進一步,透過分割槽:

  • 模型權重
  • 梯度
  • 最佳化器狀態

這對於無法放入單個 GPU 的模型至關重要。使用 ZeRO Stage 3,每個 GPU 只處理模型的一部分。

我們啟用的其他選項:

  • "offload_optimizer": {"device": "cpu"} 將最佳化器狀態移動到 CPU 以釋放 GPU 記憶體——這在協同部署設定中至關重要。

  • "overlap_comm": true 啟用通訊與計算重疊,加速訓練。

  • "contiguous_gradients": true 在單個記憶體塊中分配梯度,改善記憶體訪問並減少碎片化。

這些最佳化有助於**高效訓練 72B 模型**,並確保在嚴格的記憶體限制下協同部署保持穩定。

Accelerate 整合

正如 TRL 中推薦的那樣,我們使用 **Accelerate**,一個輕量級庫,它簡化了分散式訓練。它處理:

  • 多 GPU 和多節點作業啟動
  • 資料並行
  • 梯度累積
  • 分散式資料載入

這使得設定簡潔、可擴充套件且易於維護。

實驗 5:Qwen2.5-Math-72B — 吞吐量、準確性和基準測試結果

吞吐量

即使**減少 4 個 GPU**,**協同部署設定仍比普通 TRL 快約 1.26 倍**。這突出了更智慧的 GPU 共享和使用 `sleep()` 進行記憶體清理的有效性。72b-tput

獎勵曲線

協同部署和普通設定的訓練獎勵圖**幾乎相同**,這表明:

  • 協同部署訓練保持了準確性
  • **模型學習效能沒有退步** blogpost_72b_rewards

Math500 基準測試

我們評估了三個模型:**基礎模型**、**協同訓練模型**和**普通訓練模型**在 Math500 基準測試中的表現。兩個訓練模型都**優於基礎模型**,並且**協同部署模型與普通訓練模型表現相當**——證實了協同部署不會影響下游效能。blogpost_72b_math500

🎓 挑戰、經驗教訓和後續步驟

透過我們利用協同部署 vLLM 擴充套件 GRPO 訓練的工作,我們面臨了幾個關鍵挑戰,並就大型模型訓練的效率、靈活性和系統設計汲取了重要的經驗教訓。

挑戰

  • vLLM ≥ 0.8.0 中的張量並行度 Bug。vLLM 0.8.0 及更高版本中的張量並行度 (TP) 與 external_launcher 停止工作。這在問題 #15895 中進行了跟蹤。為了確定破壞點,我們遵循了這篇 vLLM 開發者部落格文章中描述的方法,該文章提供了每個提交的輪子。經過二分法查詢,我們確定破壞性提交為 cc10281。根本原因是確定性——新版本需要明確設定隨機種子。一旦設定了種子,問題就消失了。

  • **二級休眠緩衝區 Bug。**最初,當我們嘗試使用 `load_weights` 重新載入權重時,二級休眠無法正常工作。這個問題在 Issue #16564 中進行了跟蹤。問題是模型緩衝區(例如 BatchNorm 中的執行均值/方差)在從休眠中喚醒後沒有恢復。修復方法是 PR #16889,它添加了在從二級休眠喚醒時明確恢復緩衝區的邏輯。我們現在保留原始緩衝區的副本,並在載入新權重後手動重新應用它們。

  • **退出時發生段錯誤。**vLLM 休眠在訓練結束時關閉程序時仍存在一個未解決的問題,會導致段錯誤。這在問題 #16993 中報告。此崩潰發生在關機期間,但不會中斷訓練本身,因此我們能夠完成本部落格中分享的所有演示和實驗。但是,我們正在等待官方修復,然後才能將 sleep() 完全整合到 TRL 上游。

這些挑戰並非阻礙,但它們需要仔細的除錯、版本控制,以及對 vLLM 如何管理記憶體和並行性的更深入理解。

經驗教訓

  • 協同部署推理顯著提高了 GPU 利用率。透過允許訓練和生成共享相同的 GPU,我們消除了空閒時間並降低了硬體需求——即使使用更少的 GPU 也能實現更高的吞吐量。

  • vLLM 的 `sleep()` 功能對於大規模協同部署至關重要。它實現了對記憶體使用的細粒度控制,允許訓練在生成步驟之間完全回收 GPU 記憶體——這是像 Qwen2.5-72B 這樣的模型實現的關鍵。

  • DeepSpeed ZeRO Stage 3 對於訓練大型模型至關重要。它透過在多個 GPU 上分配模型權重、梯度和最佳化器狀態,使超大型網路能夠適應記憶體。根據我們的經驗,啟用 `contiguous_gradients` 有助於減少記憶體碎片,而將最佳化器解除安裝到 CPU 則釋放了關鍵的 GPU 記憶體——這兩者在協同部署設定中都特別有用。

  • 協同部署功能強大,但也伴隨著權衡。它在仔細管理 GPU 記憶體時效果最佳,通常需要手動調整記憶體使用引數,例如 `vllm_gpu_memory_utilization`。雖然它提供了明顯的吞吐量優勢並減少了 GPU 空閒時間,但協同部署可能不適合記憶體預算緊張或記憶體碎片控制不佳的模型。但是,如果做得好,它會帶來顯著的效率提升。

  • TP/DP 相容性、Accelerate 和 torchrun 支援使部署無縫。儘管底層架構複雜,但整個系統可以使用標準分散式工具啟動和擴充套件。

  • 協同訓練保持模型質量。在多個基準測試(Math500、AIME24)中,協同部署和普通設定產生了可比較的結果,驗證了效能不會因效率而犧牲。

✅ 結論

這篇部落格文章探討了將 vLLM 與 GRPO 訓練協同部署如何在大語言模型訓練(包括 Qwen2.5-72B 等大型模型)中實現顯著的效率提升。

傳統上,TRL 僅支援伺服器模式下的 vLLM,這需要獨立的推理程序和 GPU,導致計算浪費和空閒時間。隨著 vLLM 外部啟動器和 TRL 中協同部署 PR PR #3394 的引入,我們現在可以在同一分散式程序組、同一 GPU 上執行訓練和推理,並完全支援 TP、DP 和 Accelerate。

儘管仍存在挑戰——例如特定版本 vLLM bug 和 `sleep()` 等邊緣情況——但總體結果表明,協同部署 GRPO 是高效訓練大型模型的一種實用、可擴充套件的解決方案。我們很高興繼續完善此設定,整合 FSDP 等功能,並突破大型模型訓練的極限——使其更快、更便宜、更易於所有人構建下一代 LLM。

✅ 試一試!

下面是一個嘗試使用協同部署 vLLM 進行 GRPO 訓練的示例。

📄 train_grpo_colocate.py

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

# Load dataset
dataset = load_dataset("trl-lib/tldr", split="train")

# Define the reward function
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

# Define training arguments
training_args = GRPOConfig(
    output_dir="Qwen2-0.5B-GRPO",
    logging_steps=1,
    use_vllm=True,
    vllm_mode="colocate",
    vllm_tensor_parallel_size=1,
    vllm_gpu_memory_utilization=0.3,
    max_prompt_length=512,
    max_completion_length=1024,
    max_steps=2,
    num_generations=4,
    num_train_epochs=1,
    per_device_train_batch_size=4,
    push_to_hub=False,
    report_to=None
)

# Create and run the trainer
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)

trainer.train()

社群

這項工作很棒,感謝詳細的撰寫。根據我們的經驗,這種方法對於大規模多節點訓練非常有效。我們已經看到訓練 32B 模型時,訓練速度提高了 3 倍。

·
文章作者

太棒了!感謝分享!

示例程式碼 `train_grpo_colocate.py` 需要使用 accelerate 啟動嗎?僅僅使用 `python3 train_grpo_colocate.py` 執行會丟擲關於缺少環境變數("RANK", "LOCAL_RANK"...)的異常。

·
文章作者

是的!

`vllm_mode="colocate"` 能與 PEFT 配合使用嗎?

·
文章作者

@lhkhiem28 實際上我們沒有嘗試過這個,但是它沒有理由不工作,因為 LoRA 與模型訓練相關,而我們的更改與生成相關。但是,似乎下面的 @ajinkya-tejankar 已經嘗試過並且看起來是可行的。

很棒的文章!協同部署模式是否計劃支援資料並行?

·
文章作者

DP是支援的。
例如,如果 GPU 數量 = 8 且 vllm_tensor_parallel_size = 2 → 組:[0,1], [2,3], [4,5], [6,7] -> 使 DP=4

DeepSpeed 是否計劃成為未來支援 TRL 多 GPU 和多節點設定的主要引擎?我嘗試了 FSDP,但它與許多 DeepSpeed 可用的配置不相容。例如,我無法讓 GRPO + FSDP + LoRA + VLLM colocate 協同工作,但將 FSDP 替換為 DeepSpeed 就可以。DeepSpeed 比 PyTorch 的普通 FSDP 更可靠嗎?

附言:很棒的部落格!非常感謝您的努力 :)

·
文章作者

@ajinkya-tejankar 在我們的內部實驗中,我們嘗試將 FSDP2 整合到 accelerate 中,並用 colocate 進行了測試。我認為仍然存在一些問題。1. TRL 的權重載入程式碼我認為只適用於 FSDP1。2. FSDP1 存在 NaN 問題,我之前提交了一個 bug 報告 https://github.com/vllm-project/vllm/issues/14443

請參閱之前的討論
https://github.com/huggingface/trl/pull/3317#issuecomment-2842576427

非常感謝這篇精彩的文章。
您的文章對於在協同部署模式下訓練 GRPO 幫助巨大。

順便問一下,您是否曾使用 LoRA 訓練過模型?
您提到訓練了一個 72B 模型,但我無法訪問 32 個 GPU,因此無法進行完全微調。

當使用 `DeepSpeed ZeRO-3` + `vLLM colocate` + `LoRA` + `GRPO` 的組合訓練模型,並在 LoRA 配置中配置 `modules_to_save=["embed_tokens", "lm_head"]`(如下所示)時,我遇到了底部的錯誤。
如果您有任何用於訓練 72B 模型的解決方案或技巧,我將不勝感激。

我使用的庫版本是:

trl==0.18.2
peft==0.15.2
transformers==4.52.4
deepspeed==0.17.1 

LoRA 配置

lora_config = LoraConfig(
    r=training_config["rank"],
    lora_alpha=training_config["alpha"],
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj", 
        "up_proj",
        "down_proj",
    ],
    lora_dropout=training_config["dropout"],
    bias="none",
    task_type="CAUSAL_LM",
    modules_to_save=["embed_tokens", "lm_head"],
)

錯誤

AttributeError: 'Linear' object has no attribute 'ds_grads_remaining'

完整的錯誤日誌如下:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/LLMTrainFlow/./src/train/rl_gemma3.py", line 180, in <module>
[rank0]:     trainer.train()
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 2240, in train
[rank0]:     return inner_training_loop(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 2555, in _inner_training_loop
[rank0]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 3745, in training_step
[rank0]:     loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/trl/extras/profiling.py", line 96, in wrapper
[rank0]:     return func(self, *args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/trl/trainer/grpo_trainer.py", line 1330, in compute_loss
[rank0]:     return self._compute_loss(model, inputs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/trl/trainer/grpo_trainer.py", line 1340, in _compute_loss
[rank0]:     per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/trl/extras/profiling.py", line 96, in wrapper
[rank0]:     return func(self, *args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/trl/trainer/grpo_trainer.py", line 852, in _get_per_token_logps
[rank0]:     logits = model(
[rank0]:              ^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
[rank0]:     ret_val = func(*args, **kwargs)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/deepspeed/runtime/engine.py", line 2087, in forward
[rank0]:     loss = self.module(*inputs, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1845, in _call_impl
[rank0]:     return inner()
[rank0]:            ^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1793, in inner
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/peft/peft_model.py", line 1757, in forward
[rank0]:     return self.base_model(
[rank0]:            ^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1845, in _call_impl
[rank0]:     return inner()
[rank0]:            ^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1793, in inner
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/peft/tuners/tuners_utils.py", line 193, in forward
[rank0]:     return self.model.forward(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/utils/generic.py", line 969, in wrapper
[rank0]:     output = func(self, *args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/models/gemma3/modeling_gemma3.py", line 880, in forward
[rank0]:     logits = self.lm_head(hidden_states[:, slice_indices, :])
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1845, in _call_impl
[rank0]:     return inner()
[rank0]:            ^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1782, in inner
[rank0]:     args_result = hook(self, args)
[rank0]:                   ^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/deepspeed/runtime/zero/parameter_offload.py", line 378, in _post_backward_module_hook
[rank0]:     return apply_to_tensors_only(module.post_bwd_fn.apply,
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/deepspeed/runtime/zero/utils.py", line 133, in apply_to_tensors_only
[rank0]:     touched_output = apply_to_tensors_only(function, elem)
[rank0]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/deepspeed/runtime/zero/utils.py", line 149, in apply_to_tensors_only
[rank0]:     touched_output = function(value)
[rank0]:                      ^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/autograd/function.py", line 575, in apply
[rank0]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/deepspeed/runtime/zero/parameter_offload.py", line 446, in forward
[rank0]:     module.ds_grads_remaining += 1
[rank0]:     ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1928, in __getattr__
[rank0]:     raise AttributeError(
[rank0]: AttributeError: 'Linear' object has no attribute 'ds_grads_remaining'
[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/LLMTrainFlow/./src/train/rl_gemma3.py", line 180, in <module>
[rank0]:     trainer.train()
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 2240, in train
[rank0]:     return inner_training_loop(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 2555, in _inner_training_loop
[rank0]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 3745, in training_step
[rank0]:     loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/trl/extras/profiling.py", line 96, in wrapper
[rank0]:     return func(self, *args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/trl/trainer/grpo_trainer.py", line 1330, in compute_loss
[rank0]:     return self._compute_loss(model, inputs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/trl/trainer/grpo_trainer.py", line 1340, in _compute_loss
[rank0]:     per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/trl/extras/profiling.py", line 96, in wrapper
[rank0]:     return func(self, *args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/trl/trainer/grpo_trainer.py", line 852, in _get_per_token_logps
[rank0]:     logits = model(
[rank0]:              ^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
[rank0]:     ret_val = func(*args, **kwargs)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/deepspeed/runtime/engine.py", line 2087, in forward
[rank0]:     loss = self.module(*inputs, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1845, in _call_impl
[rank0]:     return inner()
[rank0]:            ^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1793, in inner
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/peft/peft_model.py", line 1757, in forward
[rank0]:     return self.base_model(
[rank0]:            ^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1845, in _call_impl
[rank0]:     return inner()
[rank0]:            ^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1793, in inner
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/peft/tuners/tuners_utils.py", line 193, in forward
[rank0]:     return self.model.forward(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/utils/generic.py", line 969, in wrapper
[rank0]:     output = func(self, *args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/models/gemma3/modeling_gemma3.py", line 880, in forward
[rank0]:     logits = self.lm_head(hidden_states[:, slice_indices, :])
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1845, in _call_impl
[rank0]:     return inner()
[rank0]:            ^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1782, in inner
[rank0]:     args_result = hook(self, args)
[rank0]:                   ^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/deepspeed/runtime/zero/parameter_offload.py", line 378, in _post_backward_module_hook
[rank0]:     return apply_to_tensors_only(module.post_bwd_fn.apply,
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/deepspeed/runtime/zero/utils.py", line 133, in apply_to_tensors_only
[rank0]:     touched_output = apply_to_tensors_only(function, elem)
[rank0]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/deepspeed/runtime/zero/utils.py", line 149, in apply_to_tensors_only
[rank0]:     touched_output = function(value)
[rank0]:                      ^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/autograd/function.py", line 575, in apply
[rank0]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/deepspeed/runtime/zero/parameter_offload.py", line 446, in forward
[rank0]:     module.ds_grads_remaining += 1
[rank0]:     ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1928, in __getattr__
[rank0]:     raise AttributeError(
[rank0]: AttributeError: 'Linear' object has no attribute 'ds_grads_remaining'

我注意到 vLLM 的休眠功能並未整合到 TRL 中,這是為什麼?

·
文章作者

原因在 https://huggingface.co/blog/vllm-colocate#challenges 中的“段錯誤”討論中有所說明。基本上,我們正在等待 bug (https://github.com/vllm-project/vllm/issues/16993) 的修復,然後才能將 sleep() 完全整合到 TRL 上游。

你們在 Qwen 72B 實驗中是如何分配權重的?是僅在一個節點上以 TP=8 執行,還是每個節點都有自己的 Qwen 72B 副本?

·
文章作者

是的,我們設定了 TP=8,這意味著每個節點都擁有 72B 模型分片的副本。

很棒的文章!
我正在 Slurm 叢集中使用 VLLM 協同部署進行 GRPO,我收到一個 TCP 異常:
TCP 客戶端連線/驗證主機 10.0.1.163:35345 失敗
雖然我以為它與訓練迴圈是內聯執行的。這正常嗎?:D

·

這不正常。請確保您設定了 `vllm_mode="colocate"`。

這很棒;文件應該更新嗎?

註冊登入以評論

© . This site is unofficial and not affiliated with Hugging Face, Inc.