Accelerate ND-Parallel:高效多 GPU 訓練指南

釋出時間:2025 年 8 月 8 日
在 GitHub 上更新

由於不同並行策略的複雜性,在多個 GPU 上訓練大型模型可能充滿挑戰。在 Accelerate 中,我們與 Axolotl 合作,集成了一種快速簡便的方法,可以在您的訓練指令碼中使用任何並行策略組合!

以下是如何將其新增到您的訓練指令碼中

from transformers import AutoModelForCausalLM
from accelerate import Accelerator
from accelerate.parallelism_config import ParallelismConfig
from accelerate.utils import FullyShardedDataParallelPlugin

# configure your desired parallelisms here - this particular configuration requires at least 2 nodes with 8 GPUs each. 
# setting any parallelism degree to 1 disables it i.e. dp_replicate_size=1 disables DP.
pc = ParallelismConfig(
    dp_shard_size=2, # Fully Sharded Data Parallel degree
    dp_replicate_size=2, # Data Parallel degree
    cp_size=2, # Context Parallel degree
    tp_size=2, # Tensor Parallel degree
)

fsdp_plugin = FullyShardedDataParallelPlugin(
    fsdp_version=2,
    auto_wrap_policy="transformer_based_wrap",
    transformer_cls_names_to_wrap=["LlamaDecoderLayer"],
    state_dict_type="SHARDED_STATE_DICT",
)

accelerator = Accelerator(
    parallelism_config=pc,
    fsdp_plugin=fsdp_plugin
)

model = AutoModelForCausalLM.from_pretrained(
    "NousResearch/Hermes-3-Llama-3.1-8B", 
    device_mesh=accelerator.torch_device_mesh
)

model = accelerator.prepare(model)

我們還在 Accelerate 儲存庫中包含了一個更全面的端到端訓練指令碼,其中演示瞭如何設定資料載入器、最佳化器和訓練迴圈,以及如何在訓練後儲存模型。

為了進一步簡化大規模模型微調並結合並行策略與各種微調技術,我們還將此技術整合到 Axolotl 中。為了幫助您立即上手,我們測試了一些示例配置,您可以根據自己的需求進行修改 - 嘗試使用以下命令:

# note: this requires a minimum world size of 16 
axolotl train examples/distributed-parallel/llama-3_1-8b-hsdp-tp.yaml

您還可以檢視 Axolotl ND-Parallelism 文件以獲取更多詳細資訊——將 ND 並行技術新增到您現有配置中就像在您的 Axolotl 配置檔案中新增一個或多個以下欄位一樣簡單:

# Fully Sharded Data Parallel degree (note: also requires the fsdp_config field) 
# see https://docs.axolotl.ai/docs/multi-gpu.html#sec-fsdp for more details
dp_shard_size: 2
# Data Parallel degree
dp_replicate_size: 2
# Context Parallel Degree
context_parallel_size: 2
# Tensor Parallel Degree
tensor_parallel_size: 2

我們已經透過 Accelerate 中的 ParallelismConfig 類或 Axolotl 中的配置欄位,使得配置不同並行策略的程度以及它們如何組合變得容易。但是我們如何知道哪種配置最適合我們的用例呢?當我們擴充套件到訓練具有數百億甚至數千億引數的模型時,主要的挑戰來自於理解不同的並行策略以及它們如何相互作用以最小化裝置間的通訊開銷。在這篇文章中,我們將詳細介紹不同的並行策略如何工作,以及何時以及如何組合它們。

目錄

資料並行

Diagram for Data Parallel
分散式資料並行(Distributed Data Parallel)在每個裝置上覆制整個模型,並將資料均勻地分成子批次分配給每個裝置。(來源:Martynas Šubonis)。

資料並行 (DP) 是在多個 GPU 上訓練模型最常用的技術,它涉及在每個裝置上覆制模型、梯度和最佳化器狀態,同時在 GPU 之間均勻分配資料批次,並在更新引數之前同步裝置間的梯度。與單裝置訓練相比,這可以顯著提高吞吐量,但要求您的模型能夠適應單個裝置。

我們可以透過 Accelerate 的 `ParallelismConfig` 中的 `dp_replicate_size` 引數或 Axolotl 中的配置欄位來控制模型的副本數量。值得注意的是,DP 是一種*最頂層*的並行策略,這意味著如果我們使用 `dp_replicate_size=2` 並將其與其他並行策略組合,將會有 2 個模型副本,每個副本也會受到其他並行策略的影響。例如,如果我們將 `dp_replicate_size=2` 和 `tp_size=2` 結合使用,我們將擁有 2 個模型副本,每個副本都有 2 個張量並行分片。

我們使用術語 *shard* 來描述單個裝置上的資料,它是較大資料的分割槽。

完全分片資料並行

Diagram for Fully Sharded Data Parallel
完全分片資料並行(Fully Sharded Data Parallel)將模型的每個引數均勻地劃分到每個裝置上,並且像 DDP 一樣,將資料均勻地分成子批次分配給每個裝置。為了完成前向和反向傳播,FSDP 必須在每次前向/反向傳播之前*聚合*每個引數的權重,以便每個裝置獲得引數的完整副本。(來源:Martynas Šubonis)。

如果我們的模型太大,無法適應單個裝置,該怎麼辦?完全分片資料並行 (FSDP) 透過在 GPU 之間分片(均勻分佈)模型的權重、梯度和最佳化器狀態來解決此問題(這受 DeepSpeed 的 ZeRO-3 啟發),同時每個裝置仍接收其完整資料批次的一部分。如您從上圖中所注意到的,我們不是在每個裝置上都需要整個模型的完整副本,而是在前向傳播之前一次只收集一個層的權重,之後權重可以再次分片。

透過這種方式,我們以記憶體使用量換取了在每次前向和後向傳播之前收集分片引數以及進行 reduce-scatter 本地梯度的通訊開銷。我們可以透過調整引數收集的粒度來控制 FSDP 中的這種權衡。在一個極端情況下,我們可以對模型的每一層進行收集和重新分片,這將導致最低的峰值記憶體使用量,但會產生最高的通訊成本。在實踐中,一種常見的方法是一次性收集整個 Transformer 解碼器塊的權重。

雖然我們可以進一步進行記憶體-計算的權衡,並將模型引數和梯度解除安裝到 CPU 以訓練更大的模型,但這可能會非常慢。相反,讓我們考慮如何有效利用更多裝置來訓練更大的模型,同時保持高資料吞吐量。

我們使用術語 *節點* 指代託管多個 GPU(最多 8 個)的單臺機器,其中 GPU 之間使用 NVLink 等實現快速節點內通訊。在多節點訓練中,我們依靠 Infiniband 等相對較慢的節點間通訊通道。我們還將程序池中的裝置總數稱為世界大小——例如,一臺擁有 8 個 GPU 的單節點表示世界大小為 8,而 4 個節點則表示世界大小為 32。

當在多個節點上使用 FSDP 時,我們將跨節點的所有裝置視為在單個節點上進行訓練。例如,對於 4 個節點,每個節點包含 8 個 GPU,我們跨 32 個裝置執行分片,並使用節點內和節點間通訊後端執行集體 all-reduce 和 reduce-scatter 操作。透過這種方式,FSDP 單獨就可以擴充套件到大量 GPU,並具有較大的全域性批次大小以提高資料吞吐量。然而,在某些情況下會出現一些挑戰,可能需要將 FSDP 與其他並行技術結合使用。我們通常會盡量避免在超過一個完整節點的情況下使用 FSDP,因為通訊開銷可能會變得過高,我們將在混合分片資料並行部分討論如何解決這個問題。

您可以使用 Accelerate 的 `ParallelismConfig` 中的 `dp_shard_size` 引數,結合已準備好的 FullyShardedDataParallelPlugin,或者在 Axolotl 中設定 `dp_shard_size` 配置欄位來設定應用於模型的 FSDP 程度。

張量並行

Diagram for Tensor Parallel
張量並行將大型線性層拆分到不同裝置上,通常第一層採用列式分片,後續層採用行式分片。這種方法僅需要一次 AllReduce 通訊操作來組合分片輸出,從而在節點內將記憶體和計算分佈到不同裝置上,同時最大程度地減少通訊開銷。

張量並行 (TP) 是一種模型並行技術,其中模型的分片永久儲存在不同的裝置上,與資料並行技術相反,每個裝置接收相同批次的資料。TP 透過在裝置之間分配線性層的計算來工作,因此每個裝置只計算矩陣乘法的一部分。這種技術最適用於大型線性層,例如 transformer 模型中的前饋層,這些層可以跨裝置進行拆分。我們還可以在注意力層中的每個查詢、鍵、值和輸出投影上使用 TP,幾乎沒有額外的通訊成本。

為了達到最佳效能,連續層的引數可以以特定方式分佈,從而最大限度地減少所需的通訊。當處理成對的線性層時,我們可以對第一層進行列式拆分,對後續層進行行式拆分,從而只需一次 all-reduce 操作即可組合分片輸出。

與 FSDP 的動態分片行為不同,TP 建立靜態記憶體分割槽,從而導致記憶體使用量隨著 TP 組大小的增加而恆定減少。這對於大型模型至關重要,因為即使是單個解碼器層也太大,無法在 FSDP all-gather 期間放入記憶體(回想一下 FSDP 的常見做法是同時收集整個解碼器層的權重)。然而,與 FSDP 在節點間相對線性地擴充套件(在同構叢集上最多可達約 512 個 GPU,在低頻寬連線上則顯著減少)不同,TP 僅在一個節點範圍內有效。TP 在計算過程中需要裝置之間頻繁進行啟用同步,因為每個裝置只計算輸出的一部分,需要與其他裝置的輸出進行通訊才能繼續前向傳播。因此,如果要在多節點設定中使用 TP,我們必須考慮將 TP 與其他並行技術結合使用,同時將 TP 僅限於單個節點。由於其巨大的通訊開銷,不建議將 TP 用於 PCIe 連線的 GPU。

在 Accelerate 中,TP 大小透過 `ParallelismConfig` 中的 `tp_size` 進行配置,而在 Axolotl 中,您可以使用 `tensor_parallel_size` 配置欄位。

上下文並行

最近,大型語言模型 (LLM) 的推理能力導致序列長度急劇增加,因為模型使用越來越多的 token 來解決複雜任務。為了透過微調實現這種行為,我們需要一種方法來訓練模型處理非常長的序列長度——有時甚至可以達到一百萬個 token!

由於 transformer 中的注意力操作與上下文長度呈平方關係,這使得在單個 GPU 上進行操作變得不可能。例如,在微調相對較小的模型(如 Mistral-7B,使用 32 個注意力頭)時,如果序列長度為 128k,單個注意力矩陣將佔用 128k * 128k * 2 位元組 * `num_heads=32` = ~32GB * 32 = ~1TB 的啟用記憶體!儘管在使用 FlashAttention 等最佳化注意力實現時這個例子不現實,但它有助於說明上下文長度增加所導致的記憶體需求增長。

透過上下文並行 (CP),我們可以沿序列維度對輸入進行分片,從而使每個裝置只處理完整上下文的一部分,並計算完整且非常大的注意力矩陣的較小部分。為了瞭解其工作原理,請回憶注意力計算由以下方程描述: Attention(Q,K,V)=softmax(QKT)V \text{Attention}(Q, K, V) = \text{softmax}(QK^T)V

其中 Q Q K K V V 分別是查詢、鍵和值矩陣。Q Q 的每個查詢向量(行或輸入嵌入)必須計算與整個序列中 K K 的*每個*鍵向量的注意力得分,以正確應用 softmax 歸一化。然後,這些注意力得分將與 V V 中的*所有*值向量進行加權。

這裡最關鍵的細節在於,Q Q 中的每一行都可以獨立計算其注意力分數,但每個查詢向量仍然需要完整的 K K V V 矩陣。換句話說,給定一個序列長度為 $n$ 的輸入,我們可以將上述注意力方程擴充套件為

Attention(Q,K,V)1=softmax(Q1KT)VAttention(Q,K,V)2=softmax(Q2KT)VAttention(Q,K,V)n=softmax(QnKT)V \begin{align} \text{Attention}(Q, K, V)_1 &= \text{softmax}(Q_1 K^T) V \\ \text{Attention}(Q, K, V)_2 &= \text{softmax}(Q_2 K^T) V \\ &\vdots \\ \text{Attention}(Q, K, V)_n &= \text{softmax}(Q_n K^T) V \end{align}

其中我們把查詢矩陣的每一行表示為 Q1,Q2,...,Qn Q_1, Q_2, ..., Q_n 。這可以推廣為:Attention(Q,K,V)i=softmax(QiKT)Vi{1,2,...,n} \text{Attention}(Q, K, V)_i = \text{softmax}(Q_i K^T) V \quad \forall i \in \{1, 2, ..., n\}

當我們跨裝置對輸入進行分片時,由此產生的 Q Q K K V V 矩陣(由這些輸入分片計算得出)也會沿序列維度自動分片——每個 GPU 僅為其序列部分計算查詢、鍵和值。例如,如果世界大小為 W W 個 GPU,序列長度為 n n

  • GPU 0 計算 Q1:n/W Q_{1:n/W} K1:n/W K_{1:n/W} V1:n/W V_{1:n/W}
  • GPU 1 計算 Qn/W+1:2n/W Q_{n/W+1:2n/W} Kn/W+1:2n/W K_{n/W+1:2n/W} Vn/W+1:2n/W V_{n/W+1:2n/W}
  • ...
  • GPU (W1) (W-1) 計算 Q(W1)n/W+1:n Q_{(W-1)n/W+1:n} K(W1)n/W+1:n K_{(W-1)n/W+1:n} V(W1)n/W+1:n V_{(W-1)n/W+1:n}

我們如何確保注意力計算正確?如上所述,每個裝置只需要自己的 Q Q 分片,但需要完整的 K K V V 矩陣才能正確計算注意力。我們可以透過使用一種稱為 環注意力(RingAttention)的技術來實現這一點,其工作原理如下:

  1. 最初,每個 GPU 都擁有其分片的 Q Q K K V V (例如,GPU 0 擁有 Q1:n/W Q_{1:n/W} K1:n/W K_{1:n/W} V1:n/W V_{1:n/W} )。
  2. 每個 GPU 然後為其 Qi Q_i 分片及其本地 Kj K_j Vj V_j 分片計算一個部分注意力矩陣 Ai,j A_{i,j}
  3. 每個 GPU 將其 K K V V 分片傳送到環中的下一個 GPU。
  4. 每個 GPU 都會從環中的上一個 GPU 接收到不同的 K 和 V 分片。
  5. 每個 GPU 使用接收到的 K K V V 分片計算額外的部分注意力矩陣 Ai,j+1 A_{i,j+1} Ai,j+2 A_{i,j+2} 等。
  6. 每個 GPU 重複此過程,直到所有 K K V V 分片都已接收,並且所有部分注意力矩陣 Ai, A_{i,*} 都已計算。
Diagram for Context Parallel
上下文並行將輸入序列分片到各個 GPU 上,每個裝置儲存其分配段的查詢和鍵值對。環注意力在 GPU 之間迴圈 K、V 分片(由箭頭表示),允許每個查詢計算與整個序列中的鍵和值相關的注意力分數。最終的注意力輸出結合了所有序列位置的資訊,同時在裝置之間分配了記憶體和計算。

Accelerate 透過 accelerator.maybe_context_parallel 裝飾器實現此功能,該裝飾器也在 Accelerate 示例指令碼中展示。您還可以在我們的 CP 概念指南中瞭解其工作原理和限制。

與 TP 類似,在 Accelerate 中,CP 大小透過 `ParallelismConfig` 中的 `cp_size` 配置,而在 Axolotl 中,您可以使用 `context_parallel_size` 配置欄位。

ND 並行

在多節點設定中,FSDP 等資料並行技術將整個網路拓撲視為沿單個維度存在。您可能會發現這種方法在多種原因下受到限制:

  • 當擴充套件到更多節點時,FSDP 的集體操作會受到節點間延遲的瓶頸,導致訓練速度過慢。
  • 如前所述,大型模型的解碼器層可能無法適應 GPU 記憶體,或者即使處於分片狀態,也可能太大而無法執行前向傳播。
  • 可能無法達到理想的批處理大小——批處理可能太大,純資料並行無法有效處理,或者由於模型大小的記憶體限制而太小。

為了解決其中一些問題,我們可以將多節點叢集視為具有二維拓撲:裝置之間沿一個軸進行快速節點內通訊,而沿另一個軸進行相對較慢的節點間通訊。讓我們考慮如何組合我們迄今為止介紹的並行技術來利用這一點。

混合分片資料並行

Diagram for Hybrid Sharded Data Parallel
混合分片資料並行在每個副本組內執行 FSDP,並透過 AllReduce 同步副本組之間的梯度,從而將 FSDP 的記憶體效率與跨節點的 DP 通訊效率相結合。

混合分片資料並行(HSDP)是一種二維並行,它在節點內執行 FSDP,並在節點間執行 DP——也就是說,模型在每個節點之間複製,並在每個節點內使用 FSDP 進行分片。這使得 FSDP 較高的通訊開銷可以利用更快的節點內鏈路,而 DP 將較慢的節點間通訊開銷最小化到單個梯度同步步驟。如果您遇到問題 1,並希望以增加記憶體使用為代價來加速訓練,您可能會考慮這種方法。

重要的是要注意,我們可以自由配置我們的 2D 網路拓撲的形狀,因為我們不受限於維度與物理節點邊界對齊——您可能會在 2 個節點之間應用 FSDP,同時在 2 個節點的組之間複製,這將導致較低的記憶體使用但較慢的吞吐量,但仍將節點內 FSDP 通訊開銷減少一半。這是一個我們鼓勵您根據您的特定硬體設定和微調需求進行調整的引數。

您可以透過在 Accelerate 的 `ParallelismConfig` 或 Axolotl 的配置欄位中同時定義 `dp_shard_size` 和 `dp_replicate_size` 來啟用 HSDP。

完全分片資料並行 + 張量並行

正如我們之前提到的,TP 應該在節點內部應用以利用高頻寬的節點內通訊。因此,將 TP 和 FSDP 結合起來涉及使用 FSDP 在節點間對模型進行分片,並在節點內部使用 TP。在一定程度上,這可能為上述所有三個問題提供一個簡潔的解決方案:FSDP 的延遲成本可以減少 8 倍,太大無法在單個裝置上容納的層現在均勻分佈在裝置上,並且由於每個 TP 組接收相同批次的資料,我們還可以將全域性批次大小減少 8 倍。然而,如果這仍然不足,我們將無法增加跨節點的 TP 大小,並且必須考慮替代方法。

在 Accelerate 中,您可以透過在 `ParallelismConfig` 中同時定義 `dp_shard_size` 和 `tp_size` 來結合 TP 和 FSDP,而在 Axolotl 中,您可以新增 `dp_shard_size` 和 `tensor_parallel_size` 這兩個配置欄位。

完全分片資料並行 + 上下文並行

這是一種結合 FSDP 和 CP 的二維並行策略,雖然它並不常用,因為 CP 已經與 FSDP 結合(關於原因請參見 accelerate 概念指南),但在某些情況下它可能很有用,例如需要大序列長度,因此需要大 `cp_size`。如果這仍然不符合您的記憶體預算,您可以在此之上應用 FSDP,進一步減少記憶體使用。

在 Accelerate 中,您可以透過在 `ParallelismConfig` 中同時定義 `dp_shard_size` 和 `cp_size` 來結合 CP 和 FSDP,而在 Axolotl 中,您可以新增 `dp_shard_size` 和 `context_parallel_size` 這兩個配置欄位。

混合分片資料並行 + 張量並行

在足夠大的世界大小下(注意:3D 並行的最小世界大小為 8,但在更大規模下最有效),我們可以考慮將 HSDP 與 TP 結合起來,這樣就建立了一個層級結構:DP 首先在節點組之間複製模型,然後 FSDP 在每個組內分片模型,最後 TP 在每個節點內拆分單個層。當您面臨上述所有擴充套件限制時,您可能會考慮這種方法,因為它透過在記憶體使用和吞吐量之間進行權衡,提供了最大的靈活性來適應您的特定訓練設定。

在 Accelerate 中,您可以透過在 `ParallelismConfig` 中同時定義 `dp_shard_size`、`dp_replicate_size` 和 `tp_size` 來結合 HSDP 和 TP。類似地,在 Axolotl 中,您可以新增 `dp_shard_size`、`dp_replicate_size` 和 `tensor_parallel_size` 這三個配置欄位。

使用注意事項

我們沒有涵蓋其他並行組合方式,例如使用 HSDP + TP + CP 的 4D 並行,但它們與我們已經涵蓋的技術操作方式非常相似。最重要的是,我們鼓勵您嘗試不同的技術和配置——這是您掌握不同記憶體/吞吐量權衡方式的最佳途徑。

以下是一些您在分散式設定中可能會覺得有用的額外提示:

  • 當使用 FSDP 並處理單個裝置無法容納的過大模型時,啟用 CPU RAM 高效載入和分片狀態字典檢查點技術至關重要。您可以透過 Accelerate 的 FullyShardedDataParallelPlugin 中的 cpu_ram_efficient_loadingstate_dict_type 引數來啟用此功能,

    fsdp2_plugin = FullyShardedDataParallelPlugin(
        fsdp_version=2,
        auto_wrap_policy="transformer_based_wrap",
        transformer_cls_names_to_wrap=["LlamaDecoderLayer"],
        state_dict_type="SHARDED_STATE_DICT", 
        cpu_ram_efficient_loading=True
    )
    

    或者透過 Axolotl 中 fsdp_config 裡的 cpu_ram_efficient_loadingstate_dict_type 配置欄位。

    fsdp_version: 2
    fsdp_config:
      auto_wrap_policy: TRANSFORMER_BASED_WRAP
      transformer_layer_cls_to_wrap: LlamaDecoderLayer
      state_dict_type: SHARDED_STATE_DICT
      cpu_ram_efficient_loading: True
    
  • 訓練期間使用的總批次大小對訓練穩定性、記憶體使用和資料吞吐量起著重要作用。當使用 DP 和/或 FSDP 時,有效批次大小計算如下:

    effective_batch_size = micro_batch_size * gradient_accumulation_steps * dp_world_size.

    其中 dp_world_size = (dp_shard_size * dp_replicate_size) / tp_size。您可以透過增加訓練迴圈中的總微批次大小或梯度累積步數,或在 Axolotl 中設定 micro_batch_sizegradient_accumulation_steps 配置欄位,或透過增加更多 GPU 來增加總 dp_world_size 來增大批次大小。如前所述,這將施加一個 dp_world_size 的*最小*總批次大小——當使用純 DP/FSDP 時,這將是您的總世界大小,如果這過高,減少總批次大小的唯一方法是引入張量並行。最後,在 GPU 數量固定且記憶體受限的情況下,我們建議增加 gradient_accumulation_steps 而不是 micro_batch_size 以實現更大的有效批次大小,反之亦然。

  • 相應地,當您的有效批次大小因引入資料並行而增加時,您應該縮放學習率以保持訓練穩定性。常見的方法包括線性縮放 scaled_lr = base_lr * (effective_batch_size / base_batch_size) 或平方根縮放 scaled_lr = base_lr * sqrt(effective_batch_size / base_batch_size)

  • 即使使用並行策略,如果記憶體限制仍然存在,梯度檢查點可以透過計算換記憶體的方式提供額外的記憶體節省。在前向傳播期間,只有一部分啟用被儲存在記憶體中(通常在 Transformer 塊邊界),並且中間啟用在反向傳播期間重新計算。此技術與上述所有並行策略無縫協作。在 Accelerate 中,您可以透過在 FullyShardedDataParallelPlugin 中設定 activation_checkpointing=true 來啟用它。

    fsdp2_plugin = FullyShardedDataParallelPlugin(
        fsdp_version=2,
        auto_wrap_policy="transformer_based_wrap",
        transformer_cls_names_to_wrap=["LlamaDecoderLayer"],
        state_dict_type="SHARDED_STATE_DICT", 
        cpu_ram_efficient_loading=True,
        activation_checkpointing=True
    )
    

    在 Axolotl 中也類似。

    fsdp_version: 2
    fsdp_config:
      auto_wrap_policy: TRANSFORMER_BASED_WRAP
      transformer_layer_cls_to_wrap: LlamaDecoderLayer
      state_dict_type: SHARDED_STATE_DICT
      cpu_ram_efficient_loading: True
      activation_checkpointing: True
    

    請注意,梯度檢查點通常會因啟用重新計算而使訓練時間增加約 20-30%,但可以將啟用記憶體減少 60-80%,這使得它在訓練非常大的模型或使用長序列長度時特別有價值。

社群

註冊登入評論

© . This site is unofficial and not affiliated with Hugging Face, Inc.