TRL 文件

減少記憶體使用

Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

減少記憶體使用

此部分正在建設中。歡迎貢獻!

截斷

資料集中的序列長度差異可能很大。當資料分批處理時,序列會被填充以匹配批次中最長的序列,這可能會導致高記憶體使用,即使大多數序列相對較短。

Truncation prompt-completion

為了減少記憶體使用,將序列截斷到合理的長度非常重要。雖然 TRL 訓練器預設會截斷序列,但您可能需要調整預設的截斷長度,以更好地適應您的特定用例。

DPO
SFT

DPO 截斷首先透過 max_prompt_lengthmax_completion_length 引數應用於提示和補全。然後使用 max_length 引數來截斷最終生成的序列。

Truncation prompt-completion

要設定截斷引數,請使用以下程式碼片段

from trl import DPOConfig

training_args = DPOConfig(..., max_prompt_length=..., max_length=...)

您也可以使用 max_completion_length 引數來截斷補全,但這不太常見,因為目標通常是儘可能保留補全的完整長度。

from trl import DPOConfig

training_args = DPOConfig(..., max_completion_length=...)

如何選擇 max_length 值?

如果 max_length 太小,大部分詞元將被丟棄,無法對訓練做出貢獻。如果太大,記憶體使用量可能會激增,可能導致 OOM(記憶體不足)錯誤。如果沒有打包或無填充,大的 max_length 也可能導致訓練效率低下,因為許多詞元將是填充詞元。

為了幫助您選擇一個合適的值,我們提供了一個工具來視覺化資料集中序列長度的分佈。

打包

此技術僅適用於 SFT。

截斷有幾個缺點

  1. 資訊丟失:序列末尾的關鍵資料可能會被丟棄。
  2. 選擇截斷長度:太短會丟失資料;太長會影響效率。

打包(Packing)由 Raffel 等人於 2020 年提出,它透過組合序列而不是截斷來解決這些問題。它將資料集序列連線並分割成所需的長度。

Packing

打包透過在可能的情況下將多個序列合併到一行中來減少填充。我們使用一種先進的方法,以近乎最優的方式打包資料集。要啟用打包,請在 SFTConfig 中設定 packing=True

在 TRL 0.18 及更早版本中,打包使用了一種更激進的方法,將填充減少到幾乎為零,但缺點是會破壞資料集中大部分序列的連續性。要恢復到此策略,請在 `SFTConfig` 中使用 `packing_strategy="wrapped"`。

from trl import SFTConfig

training_args = SFTConfig(..., packing=True, max_length=512)

打包可能會導致批次汙染,即相鄰序列相互影響。這對於某些應用可能是有問題的。更多詳情,請參閱 #1230

使用 Liger 減少峰值記憶體使用

Liger Kernel 是一個專門為 LLM 訓練設計的 Triton 核心集合。它可以有效地將多 GPU 訓練吞吐量提高 20%,並減少 60% 的記憶體使用。

更多資訊,請參閱 Liger Kernel 整合

DPO
GRPO
KTO

要使用 Liger 減少峰值記憶體使用,請使用以下程式碼片段

from trl import DPOConfig

training_args = DPOConfig(..., use_liger_loss=True)

無填充

無填充批處理是另一種減少記憶體使用的方法。在這種方法中,首先對一個批次進行取樣,然後將其展平為單個序列,從而避免了填充。與打包(packing)可能透過組合不同樣本的部分而導致序列不完整不同,無填充批處理確保所有序列保持完整。

Padding-free batching

強烈建議將無填充批處理與 FlashAttention 2FlashAttention 3 結合使用。否則,您可能會遇到批次汙染問題。

DPO
SFT
from trl import DPOConfig

training_args = DPOConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "flash_attention_2"})

啟用解除安裝

啟用解除安裝是一種記憶體效率技術,它透過在前向傳播期間將啟用張量臨時移動到 CPU RAM,並在反向傳播需要時才將其移回,從而減少 GPU VRAM 的使用。這以略微增加訓練時間為代價,顯著減少了峰值記憶體使用。

要在您的 SFT 訓練配置中啟用啟用解除安裝:

<hfoptions> <hfoption id="SFT">
from trl import SFTConfig

training_args = SFTConfig(..., activation_offloading=True)
</hfoption> </hfoptions>

當將啟用解除安裝與使用 Liger 核心的模型一起使用時,由於相容性問題,您必須停用 Liger 交叉熵。該問題特別發生在 use_liger_kernel=True 的情況下,因為 Liger 交叉熵執行原地操作,這與啟用解除安裝衝突。預設設定 (use_liger_kernel=False) 可以正常工作。

# When using activation offloading with a model that uses Liger kernels:
from trl import SFTConfig

training_args = SFTConfig(
    activation_offloading=True,
    use_liger_kernel=False,  # Disable Liger cross entropy
    # Other parameters...
)

在底層,啟用解除安裝實現了 PyTorch 的 saved_tensors_hooks,以在前向傳播期間攔截啟用。它根據大小和上下文智慧地管理要解除安裝的張量,避免解除安裝輸出張量,因為這樣做效率低下。為了最佳化效能,它可以選擇使用 CUDA 流來將計算與 CPU-GPU 傳輸重疊。

在線上方法中停用模型聚合以進行生成

當使用 DeepSpeed ZeRO-3 時,模型權重會分片到多個 GPU 上。線上方法涉及在訓練過程中從模型生成補全。在此步驟中,模型權重會臨時聚合到單個 GPU 上進行生成。對於非常大的模型,這種聚合可能會導致記憶體不足(OOM)錯誤,如這個問題所述:#2250

如果遇到此問題,您可以透過設定以下引數來停用模型權重的生成聚合:

GRPO
線上 DPO
PPO
RLOO
from trl import GRPOConfig

training_args = GRPOConfig(..., ds3_gather_for_generation=False)

此調整可防止模型權重被聚合,從而避免 OOM 錯誤,但可能會導致生成速度變慢。

< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.