🐯 Liger GRPO Meets TRL

Published on May 25, 2025

TL;DR: Liger supercharges TRL's Group Relative Policy Optimization (GRPO) Trainer, cutting memory usage by 40% with no degradation in model quality. We've also added support for FSDP and PEFT, making it easier than ever to scale GRPO across multiple GPUs.

Motivation

Fine-tuning language models with reinforcement learning (RL) is a crucial step in the model training lifecycle: it steers models toward more sophisticated, desirable behaviors that are hard to achieve with conventional supervised fine-tuning. Traditionally, RL for large language models (LLMs) has been done with the Proximal Policy Optimization (PPO) algorithm. This approach, usually associated with Reinforcement Learning from Human Feedback (RLHF), uses a separately trained reward model to guide fine-tuning of the main model.

However, RLHF with PPO is very resource-hungry: PPO needs several models in memory at once (the policy, value, reward, and reference models), and reaching good results typically requires multiple rounds of iterative fine-tuning of the reward model and the base model. RLHF's success also hinges on the reward model's ability to reliably distinguish desired from undesired behavior.

Group Relative Policy Optimization (GRPO) has recently attracted a lot of attention with the release of DeepSeek's R1 model. GRPO does away with the pretrained reward and value models used in RLHF and instead relies on *verifiable reward functions*, which can check the correctness of a model's output in closed form without an external reward model. This has led to big improvements over PPO when fine-tuning with GRPO in domains where outputs are easy to verify, such as teaching models to reason and to perform well on math and coding tasks.
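For instance, a verifiable reward for math answers can be written in a few lines of plain Python. This toy function is our own illustration (not part of TRL); it checks the last number in each completion against a reference answer:

```python
import re

def math_accuracy_reward(completions, answers, **kwargs):
    """Toy verifiable reward: 1.0 if the last number in a completion
    matches the reference answer, else 0.0; no reward model required."""
    rewards = []
    for completion, answer in zip(completions, answers):
        numbers = re.findall(r"-?\d+(?:\.\d+)?", completion)
        rewards.append(1.0 if numbers and numbers[-1] == str(answer) else 0.0)
    return rewards
```

Because correctness is computed in closed form, the reward is exact and cheap; there is no second network to train or keep in GPU memory.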

The figure below compares the PPO and GRPO training flows (source: Figure 4 of the paper DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models).

PPO-vs-GRPO

That said, RL training still consumes a lot of GPU memory, so there's plenty of room for optimization. In this post we discuss an optimization we recently added to TRL that reduces peak memory usage by 40% during GRPO training, and we dive into scaling GRPO across multiple GPUs and nodes without sacrificing performance or correctness.

How Liger Kernel Slashes Memory Usage for GRPO

We extended Liger's Chunked Loss approach to the GRPO loss computation, which lets us avoid storing the full logits in memory at each training step. Computing the logits through the model's output head (lm_head) is the main source of peak memory usage, especially with large vocabularies, long sequences, or large batches. We address this by chunking the input to the lm_head along the batch dimension and running the forward pass one chunk at a time.
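To get a feel for why the logits dominate peak memory, a quick back-of-the-envelope calculation helps (the vocabulary size below is roughly that of the Qwen model family and is only illustrative):

```python
def logits_memory_gb(batch, seq_len, vocab_size, bytes_per_elem=4):
    """Memory needed to materialize the full (batch, seq, vocab) logits tensor."""
    return batch * seq_len * vocab_size * bytes_per_elem / 1024**3

# With a ~151k-token vocabulary, even a modest batch costs several GB
# in logits alone (FP32, batch 8, 1024-token sequences):
full = logits_memory_gb(8, 1024, 151_936)     # several GB for the full batch
chunked = logits_memory_gb(1, 1024, 151_936)  # 1/8 of that per chunk of size 1
```

Chunking along the batch dimension means only one chunk's worth of logits ever exists at a time, so peak logits memory scales with the chunk size rather than the batch size.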

But implementing this naively wouldn't actually reduce peak memory, since you'd still need to keep all the logits in GPU memory for the backward pass. To get around that, we compute the gradients of each loss chunk (with respect to the input chunk and the lm_head weights) during the forward pass, and accumulate those gradients as each chunk is processed.
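The trick can be sketched in NumPy for a plain cross-entropy head. This is a minimal illustration of chunking plus per-chunk gradient accumulation, not the actual Liger kernel (which fuses this on-GPU and handles the GRPO-specific loss terms):

```python
import numpy as np

def chunked_ce_grads(X, W, targets, chunk_size):
    """Mean cross-entropy loss and gradients w.r.t. X (hidden states) and
    W (lm_head weights), computed without ever materializing the full
    (N, vocab) logits matrix: each chunk is backpropped immediately and
    its gradients are accumulated."""
    N = X.shape[0]
    total_loss = 0.0
    grad_X = np.zeros_like(X)
    grad_W = np.zeros_like(W)
    for start in range(0, N, chunk_size):
        x = X[start:start + chunk_size]             # (c, hidden) input chunk
        t = targets[start:start + chunk_size]
        logits = x @ W                               # (c, vocab): only a chunk
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        total_loss += -np.log(probs[np.arange(len(t)), t]).sum()
        dlogits = probs
        dlogits[np.arange(len(t)), t] -= 1.0         # softmax-CE gradient
        dlogits /= N                                 # mean reduction over batch
        grad_X[start:start + chunk_size] = dlogits @ W.T
        grad_W += x.T @ dlogits                      # accumulate weight grads
    return total_loss / N, grad_X, grad_W
```

The gradients it returns are identical to those of the unchunked computation, but only one chunk of logits is alive at any moment.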

Here's a visualization of the optimization (source: Byron Hsu):

liger-chunked-loss

Plug-and-Play Integration with TRL

We recently integrated Liger GRPO with TRL in PR #3184, so you can now use the Liger GRPO loss and enjoy the memory savings simply by setting use_liger_loss to True in your GRPOConfig!

Note: these features aren't part of the latest TRL release yet, so for now you'll need to install TRL from source:

pip install "trl[liger] @ git+https://github.com/huggingface/trl.git"

Then you can use it like this:

from trl import GRPOConfig, GRPOTrainer
from datasets import load_dataset


train_dataset = load_dataset("trl-lib/tldr", split="train")
training_args = GRPOConfig(output_dir="Qwen3-0.6B-GRPO", use_liger_loss=True)

def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

trainer = GRPOTrainer(
    model="Qwen/Qwen3-0.6B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=train_dataset,
)
trainer.train()

Benchmarks

We ran a series of GRPO experiments with and without the Liger GRPO loss to compare the two. We used Qwen3-0.6B as the policy model and tried various batch sizes. All experiments were run on the gsm8k dataset using its reward functions.

Here are plots of peak memory usage versus batch size for FP32 and BF16 training. As expected, the memory savings improve as the batch size grows, since we chunk along the batch dimension: at larger batch sizes, the Liger chunked loss ends up using far less memory than the regular (non-Liger) version, with savings of up to 40%.

Quick note: we currently only support FP32, but we're working on upstreaming BF16 support for Liger GRPO into TRL. The BF16 results shown here come from an internal patch we've been testing.

Mem-vs-batch-size-fp32

Mem-vs-batch-size-bf16

We also show that the Liger loss is exact in terms of training quality: the reward curve during training closely tracks the one observed with the standard TRL implementation.

reward-vs-step

Scaling Further with FSDP and PEFT

We also added FSDP and PEFT support for the Liger GRPO loss in PR #3260 and PR #3355 respectively, making it easy for users to scale experiments across multiple GPUs or nodes. PEFT techniques like LoRA and QLoRA reduce the number of trainable parameters by tuning only small adapter weights on top of the original model, which significantly eases memory pressure since gradients, activations, and optimizer states don't have to be kept in memory for the whole model. Additionally, using PEFT with GRPO removes the need to load a separate reference model during training: we can recover the original, unmodified model simply by disabling the LoRA adapters.
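Putting this together, here is a sketch of LoRA-based GRPO training with the Liger loss (the adapter hyperparameters below are illustrative, not tuned; this needs a GPU and the model/dataset downloads to actually run):

```python
from datasets import load_dataset
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer

def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

dataset = load_dataset("trl-lib/tldr", split="train[:1%]")
# Illustrative LoRA settings: only the adapter weights are trained,
# and disabling the adapters recovers the reference model for free.
peft_config = LoraConfig(r=16, lora_alpha=32, task_type="CAUSAL_LM")
training_args = GRPOConfig(output_dir="Qwen3-0.6B-GRPO-LoRA", use_liger_loss=True)
trainer = GRPOTrainer(
    model="Qwen/Qwen3-0.6B",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
    peft_config=peft_config,
)
trainer.train()
```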

Here we show a plot of multi-GPU GRPO training with FSDP and PEFT, comparing the maximum feasible training batch size with and without the Liger loss across different Qwen3 model sizes. We find that with Liger we can push the batch size roughly 1.5 to 1.8 times higher.

peft-batch-size-vs-model-size

Scaling Even Further with vLLM

To speed up text generation during training, the Liger loss pairs well with TRL's integrated vLLM server. This significantly accelerates rollout collection with minimal overhead and offers a seamless integration experience.

Here's how to set it up:

  1. Start the vLLM server: First, launch the vLLM server. It will handle generation requests from your training script. Open a terminal and run:

    CUDA_VISIBLE_DEVICES=1 trl vllm-serve --model "Qwen/Qwen3-0.6B"
    

    Note: we set CUDA_VISIBLE_DEVICES=1 to run the vLLM server on a specific GPU (GPU 1 in this example), leaving the other GPUs free for training.

  2. Configure and run your training script: Next, modify your training script to use the vLLM server. The key change is setting use_vllm=True in your GRPOConfig:

    from trl import GRPOConfig, GRPOTrainer
    from datasets import load_dataset
    
    
    def reward_len(completions, **kwargs):
        return [-abs(20 - len(completion)) for completion in completions]
    
    dataset = load_dataset("trl-lib/tldr", split="train[:1%]")
    training_args = GRPOConfig(
        output_dir="Qwen3-0.6B-GRPO", 
        use_liger_loss=True, 
        use_vllm=True, # Enable vLLM integration
        logging_steps=10
    )
    trainer = GRPOTrainer(
        model="Qwen/Qwen3-0.6B", # Ensure this matches the model served by vLLM
        reward_funcs=reward_len,
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()
    
  3. Launch training: Finally, run your training script with accelerate launch (or plain python if you're not using Accelerate for multi-GPU/distributed training). If your vLLM server is occupying one GPU, make sure to target a different GPU for training.

    CUDA_VISIBLE_DEVICES=0 accelerate launch train.py 
    

    (This assumes your script is named train.py and you want to train on GPU 0.)

By following these steps, you can use vLLM for faster generation turnaround while training GRPO with the Liger loss.

Conclusion

With Liger-GRPO integrated into TRL, along with FSDP and PEFT support, fine-tuning language models with GRPO is now more memory-efficient and scalable than ever. We encourage the community to try out these new features and share feedback to help us further improve RL training for LLMs.

Community

Does the Liger kernel affect training speed? Is it faster, slower, or no different from regular GRPO?


It generally depends on the setup; there may or may not be a speedup!

Article author

In our experiments, we observed no significant difference in training speed with and without Liger.


Sounds perfect!

Thanks for the great work.

By the way, I tested the Liger loss with Qwen/Qwen2.5-0.5B-Instruct in bf16 mode, combined with DeepSpeed ZeRO-3.
I ran into a shape mismatch issue, as shown below:


[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/temp.py", line 22, in <module>
[rank0]:     trainer.train()
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 2238, in train
[rank0]:     return inner_training_loop(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 2553, in _inner_training_loop
[rank0]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 3730, in training_step
[rank0]:     loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/trl/extras/profiling.py", line 87, in wrapper
[rank0]:     return func(self, *args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/trl/trainer/grpo_trainer.py", line 1187, in compute_loss
[rank0]:     return self.compute_liger_loss(model, inputs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/trl/trainer/grpo_trainer.py", line 1160, in compute_liger_loss
[rank0]:     loss, metrics = self.liger_grpo_loss(
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/liger_kernel/chunked_loss/grpo_loss.py", line 249, in forward
[rank0]:     return LigerFusedLinearGRPOFunction.apply(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/autograd/function.py", line 575, in apply
[rank0]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/liger_kernel/chunked_loss/grpo_loss.py", line 142, in forward
[rank0]:     return super().forward(
[rank0]:            ^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/liger_kernel/chunked_loss/fused_linear_ppo.py", line 219, in forward
[rank0]:     accumulate_chunk(
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/liger_kernel/chunked_loss/fused_linear_ppo.py", line 132, in accumulate_chunk
[rank0]:     (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
[rank0]:                                                                                            ^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
[rank0]:     return self._torchdynamo_orig_callable(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 1164, in __call__
[rank0]:     result = self._inner_convert(
[rank0]:              ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
[rank0]:     return _compile(
[rank0]:            ^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
[rank0]:     guarded_code = compile_inner(code, one_graph, hooks, transform)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
[rank0]:     return _compile_inner(code, one_graph, hooks, transform)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_utils_internal.py", line 95, in wrapper_function
[rank0]:     return function(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
[rank0]:     out_code = transform_code_object(code, transform)
[rank0]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
[rank0]:     transformations(instructions, code_options)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 662, in transform
[rank0]:     tracer.run()
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
[rank0]:     super().run()
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
[rank0]:     while self.step():
[rank0]:           ^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
[rank0]:     return inner_fn(self, inst)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2341, in CALL
[rank0]:     self._call(inst)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2335, in _call
[rank0]:     self.call_function(fn, args, kwargs)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
[rank0]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
[rank0]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
[rank0]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
[rank0]:     return cls.inline_call_(parent, func, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
[rank0]:     tracer.run()
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
[rank0]:     while self.step():
[rank0]:           ^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
[rank0]:     return inner_fn(self, inst)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2341, in CALL
[rank0]:     self._call(inst)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2335, in _call
[rank0]:     self.call_function(fn, args, kwargs)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
[rank0]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
[rank0]:     return super().call_function(tx, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
[rank0]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
[rank0]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
[rank0]:     return cls.inline_call_(parent, func, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
[rank0]:     tracer.run()
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
[rank0]:     while self.step():
[rank0]:           ^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
[rank0]:     return inner_fn(self, inst)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 1736, in CALL_FUNCTION_EX
[rank0]:     self.call_function(fn, argsvars.items, kwargsvars)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
[rank0]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
[rank0]:     return super().call_function(tx, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
[rank0]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
[rank0]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
[rank0]:     return cls.inline_call_(parent, func, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
[rank0]:     tracer.run()
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
[rank0]:     while self.step():
[rank0]:           ^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
[rank0]:     return inner_fn(self, inst)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 1736, in CALL_FUNCTION_EX
[rank0]:     self.call_function(fn, argsvars.items, kwargsvars)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
[rank0]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 858, in call_function
[rank0]:     return self.func.call_function(tx, merged_args, merged_kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
[rank0]:     return super().call_function(tx, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
[rank0]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
[rank0]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
[rank0]:     return cls.inline_call_(parent, func, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
[rank0]:     tracer.run()
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
[rank0]:     while self.step():
[rank0]:           ^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
[rank0]:     return inner_fn(self, inst)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2341, in CALL
[rank0]:     self._call(inst)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2335, in _call
[rank0]:     self.call_function(fn, args, kwargs)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
[rank0]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/misc.py", line 1022, in call_function
[rank0]:     return self.obj.call_method(tx, self.name, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/misc.py", line 778, in call_method
[rank0]:     .call_function(tx, args, kwargs)
[rank0]:      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
[rank0]:     return super().call_function(tx, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
[rank0]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
[rank0]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
[rank0]:     return cls.inline_call_(parent, func, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
[rank0]:     tracer.run()
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
[rank0]:     while self.step():
[rank0]:           ^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
[rank0]:     return inner_fn(self, inst)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2341, in CALL
[rank0]:     self._call(inst)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2335, in _call
[rank0]:     self.call_function(fn, args, kwargs)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
[rank0]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/torch.py", line 953, in call_function
[rank0]:     tensor_variable = wrap_fx_proxy(
[rank0]:                       ^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/builder.py", line 2153, in wrap_fx_proxy
[rank0]:     return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/builder.py", line 2219, in wrap_fx_proxy_cls
[rank0]:     return _wrap_fx_proxy(
[rank0]:            ^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/builder.py", line 2315, in _wrap_fx_proxy
[rank0]:     example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 2536, in get_fake_value
[rank0]:     raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 2471, in get_fake_value
[rank0]:     ret_val = wrap_fake_exception(
[rank0]:               ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 2017, in wrap_fake_exception
[rank0]:     return fn()
[rank0]:            ^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 2472, in <lambda>
[rank0]:     lambda: run_node(tx.output, node, args, kwargs, nnmodule)
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 2604, in run_node
[rank0]:     raise RuntimeError(make_error_message(e)).with_traceback(
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 2586, in run_node
[rank0]:     return node.target(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_prims_common/wrappers.py", line 289, in _fn
[rank0]:     result = fn(*args, is_out=(out is not None), **kwargs)  # type: ignore[arg-type]
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_decomp/decompositions.py", line 4444, in matmul
[rank0]:     return torch.ops.aten._unsafe_view(t1_folded.mv(t2), output_shape)
[rank0]:                                        ^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/utils/_stats.py", line 21, in wrapper
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
[rank0]:     return self.dispatch(func, types, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
[rank0]:     return self._cached_dispatch_impl(func, types, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_subclasses/fake_tensor.py", line 1377, in _cached_dispatch_impl
[rank0]:     output = self._dispatch_impl(func, types, args, kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_subclasses/fake_tensor.py", line 2290, in _dispatch_impl
[rank0]:     decomposition_table[func](*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_prims_common/wrappers.py", line 291, in _fn
[rank0]:     result = fn(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_decomp/decompositions.py", line 83, in inner
[rank0]:     r = f(*tree_map(increase_prec, args), **tree_map(increase_prec, kwargs))
[rank0]:         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_decomp/decompositions.py", line 4336, in mv
[rank0]:     torch._check(
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/__init__.py", line 1656, in _check
[rank0]:     _check_with(RuntimeError, cond, message)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/__init__.py", line 1638, in _check_with
[rank0]:     raise error_type(message_evaluated)
[rank0]: torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in method matmul of type object at 0x7f2e2a41ff00>(*(GradTrackingTensor(lvl=1, value=
[rank0]:     FakeTensor(..., device='cuda:0', size=(1, s0, 896), dtype=torch.bfloat16,
[rank0]:                requires_grad=True)
[rank0]: ), GradTrackingTensor(lvl=1, value=
[rank0]:     FakeTensor(..., device='cuda:0', size=(0,), dtype=torch.bfloat16,
[rank0]:                requires_grad=True)
[rank0]: )), **{}):
[rank0]: size mismatch, got input (s0x896), vec (0)

Does Liger GRPO support multi-GPU training with DeepSpeed ZeRO-3?

