Accelerate 文件

🤗 accelerate 中的上下文並行

Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

🤗 accelerate 中的上下文並行

本指南將介紹在 🤗accelerate 中使用上下文並行(context parallelism)的基礎知識。對於更好奇的讀者,我們將在後面的章節中介紹一些技術細節。

為何需要上下文並行?

隨著大型語言模型和最近推理模型的出現,序列長度迅速增長。這與注意力機制的二次方記憶體複雜度相結合,導致需要更有效的方法來訓練具有長序列的模型。對於 128k 的序列長度,使用 `bf16` 精度和 vanilla attention 實現,注意力矩陣的記憶體需求為 `128k * 128k * 2 位元組 * num_heads = ~32 GB * num_heads`。當然,使用不例項化這些注意力權重的 `flash attention` 或 `SDPA`,這個數值會大幅下降,但記憶體需求的增長仍然相當可觀。

上下文並行允許我們沿序列維度對注意力計算的輸入進行分片,並在多個 GPU 上平行計算注意力。這樣,我們就可以訓練具有長序列的模型,並有可能擴充套件到 1M+ 序列長度。

如何使用上下文並行?

from accelerate.utils import ParallelismConfig, TorchContextParallelConfig

+ cp_config = TorchContextParallelConfig(
+       cp_comm_strategy="alltoall", # no need to use cp_config at all, if you want to use the default "allgather"
+ )

+ parallelism_config = ParallelismConfig(
+     cp_size=8,
+     cp_handler=cp_config,  # or just cp_size=8, if you want to use the default "allgather"
+ )

accelerator = Accelerator(
    ...,
    parallelism_config=parallelism_config,
)

與 🤗accelerate 中的任何其他功能一樣,您也可以透過向 `accelerate launch` 傳遞相應的標誌來啟用上下文並行。在這種情況下,沒有區別

accelerate launch --parallelism-config-cp-size 8 --parallelism-config-cp-comm-strategy [allgather|alltoall] ...

你也可以在 `accelerate config` 命令中設定 `cp_size` 和 `cp_comm_strategy`,這會把它們儲存在你的 `accelerate` 配置檔案中,這樣你就不用每次啟動指令碼時都傳遞它們了。

上下文並行與其他並行策略相容,例如資料並行、張量並行和 FSDP2。你可以簡單地透過將並行大小設定為期望的值來組合它們,例如 `--parallelism-config-dp-size 8 --parallelism-config-tp-size 2 --parallelism-config-cp-size 8`。或者,你可以使用 `ParallelismConfig` 類以程式設計方式設定它們。

上下文並行與 `FSDP2` 緊密耦合,你可以在 FSDP2 簡介中瞭解更多。這意味著,上下文並行僅在您為程式使用 `FullyShardedDataParallelPlugin` 或將版本設定為 2 的 `--use-fsdp` 時才有效。如果不使用 `FSDP2`,將會引發錯誤。

上下文並行僅適用於SDPA,並且僅在沒有掩碼或使用因果掩碼的情況下工作。我們無法為您正確檢測這一點,因此您有責任確保您使用的是沒有掩碼或帶有因果掩碼的 `SDPA`。如果您使用任何其他注意力實現,它將引發錯誤。

透過上述方法啟用上下文並行後,你可以將其應用於你的訓練迴圈。我們提供了一個圍繞 `torch.distributed.tensor.experimental.context_parallel` 的薄包裝器,你可以在你的訓練迴圈中使用它,它抽象了一些使用它的複雜性(稍後會詳細介紹)。為了最小化你對訓練迴圈的修改,我們提供了一個上下文管理器,如果上下文並行未啟用,它就是一個 `noop`(空操作),如果啟用了,它就應用上下文並行。這樣,你可以在你的訓練迴圈中使用它,而無需根據你的並行配置更改任何程式碼。你可以如下使用它

for batch in dataloader:
    with accelerator.maybe_context_parallel(
        buffers=[batch["input_ids"], batch["attention_mask"]],
        buffer_seq_dims=[1, 1],
        no_restore_buffers={batch["input_ids"], batch["labels"]},
    ):
        outputs = model(**batch)
        ...

這個上下文管理器必須在每個訓練步驟中重新建立,如上例所示。這樣做至關重要。

這有可能將您的上下文大小擴充套件到 1M+ 的序列長度。下面,我們展示了上下文並行在高達 256k 上下文大小下的速度和記憶體使用情況。我們可以看到,當我們加倍上下文大小和 GPU 數量時,我們可以實現一致的記憶體使用,從而可能實現無限的上下文長度擴充套件。

context parallelism memory usage
圖 1:上下文並行在高達 256k 上下文大小下的記憶體使用和速度。

這些示例是使用您可以在示例資料夾中找到的指令碼建立的。要在 8 個 H100 GPU(128k 序列長度)上執行該示例,您可以使用以下命令

accelerate launch --use-fsdp --fsdp-activation-checkpointing=TRUE examples/fsdp2/nd_parallel.py --cp-size=8 --sequence-length=128000

Accelerate 的介面

上下文管理器接受幾個引數,用於配置上下文並行。

  • `buffers`:這是一個張量列表,它們將在序列維度上進行分片。這些張量通常是輸入 ID、標籤和注意力掩碼。
  • `buffer_seq_dims`:這是一個整數列表,按 `buffers` 列表的順序指定了緩衝區的序列維度。如果你傳遞 `buffers=[input_ids, shift_labels]`,兩者形狀都為 `[batch_size, sequence_length]`,那麼你應該傳遞 `buffer_seq_dims=[1, 1]`,因為序列維度是張量的第二個維度。這對於正確計算模型輸出是必需的。
  • `no_restore_buffers`:上下文並行的實現會原地修改緩衝區,將它們轉換為 `torch.distributed.tensor.Dtensor`。在上下文管理器退出後,需要啟動一個通訊核心來將緩衝區恢復到其原始狀態(通常是 all-gather)。這需要一些時間,所以建議傳遞與 `buffers` 引數中相同的張量,以避免不必要的通訊,除非你確定在上下文管理器退出後需要使用這些緩衝區。

上下文並行與 `labels` 是 `input_ids` 的副本不相容,因為 🤗 transformers 的模型可能會自行移動 `labels` 以啟用因果語言建模。想象這種情況:labels = [l1, l2, l3, l4, … li],如果我們應用上下文並行,每個 rank 會得到一部分 labels,例如:labels_rank_0 = [l1, l2], labels_rank_1 = [l3, l4], … 在 transformers 的建模程式碼移動 labels 後,會變成:labels_rank_0 = [l2, PAD], labels_rank_1 = [l3, PAD], … 其中 `PAD` 是一個填充標記。這會導致損失計算不正確,因為 labels 不再與輸入對齊。因此,你需要在將 labels 傳入模型之前手動移動它們。

可配置選項

Accelerate 僅提供一個選項來配置上下文並行(除了 `cp_size`)

  • `cp_comm_strategy`:用於分片輪換的方法。我們強烈建議將其保持為 `"allgather"`,因為它很可能在大多數情況下優於 `"alltoall"`。

上下文並行大小相當不言自明,它是輸入被分片的 rank 數量。上下文並行分片輪換定義了輸入分片如何在 rank 之間輪換。我們將在下一節更詳細地介紹這兩種選項。

您可以在 ND 並行示例檔案中看到一個端到端的示例,在那裡您可以在單個 8xH100 節點上訓練一個 8B 模型,上下文長度可達 128k。透過多節點訓練,您可以在多個 GPU 上將其擴充套件到 1M+ 的序列長度。您還可以無縫地將其與其他並行策略結合起來,以滿足您的需求。

技術細節

本節技術性較強,如果您不需要了解上下文並行的內部原理,可以跳過此節,直接開始構建 🚀

在接下來的章節中,我們將大量使用 `shard` (分片) 這個詞,所以我們先來定義它。如果我們將一個張量稱為在第 `D` 維上,跨 `N` 個 rank `sharded`(分片),我們的意思是這個張量被分成 `N` 部分,其中張量的每個部分的形狀為 `[..., D//N, ...]`。

那麼它是如何工作的呢?

上下文並行透過在序列維度上對 `Q、K 和 V` 矩陣進行分片來工作。每個 rank 都有其分配的 `Q` 分片,我們稱之為 `Q_i`。在整個計算過程中,這個矩陣只保留在該 rank 上。同樣,每個 rank 都有自己的 `K` 和 `V` 分片,我們稱之為 `K_i` 和 `V_i`。然後,每個 rank 用自己的 `Q_i`、`K_i` 和 `V_i` 計算注意力,我們稱之為 `attn_i`。在此計算過程中,會啟動一個通訊核心來從所有其他 rank 收集 `K` 和 `V`。使用哪種通訊原語取決於 `context_parallel_shard_rotation` 選項。這樣,每個 rank 首先用 `Q_i`、`K_i` 和 `V_i` 計算本地注意力,然後用所有其他 rank 的 `K_j` 和 `V_j` 計算。由於每個 rank 持有的 `Q、K 和 V` 矩陣都是在序列維度上分片的,因此結果矩陣更小,可以容納在單個 GPU 上。

我們可以用以下虛擬碼來形式化這個過程

comm_kernel = {"allgather": allgather, "alltoall": alltoall}[context_parallel_shard_rotation]
Qi, Ki, Vi = shard(Q, K, V, seq_dim)
attn[i] = attn(Qi, Ki, Vi)
for j in range(context_parallel_size):
    Kj, Vj = comm_kernel()
    attn[j] = attn(Qi, Kj, Vj) # [batch, num_heads, seq_len // context_parallel_size, head_dim]

final_attn = combine(attn)

all-to-all vs all-gather

all-gather

那麼 all-to-all 和 all-gather 有什麼區別呢?使用 all-gather,通訊非常簡單。在我們計算完本地注意力 `attn_i` 之後(或者更確切地說,之前,因為它通常耗時更長),我們會啟動一個 all-gather 來收集所有其他 rank 的 `K` 和 `V`。當這個通訊完成後,每個 rank 就擁有了所有其他 rank 的 `K` 和 `V`,並可以依次與它們計算注意力。在理想情況下,all-gather 的完成時間恰好與 `attn_i` 的計算完成時間一致。然而,在實踐中這從未發生,因此理想的實際重疊是在 `attn_i` 的全部計算與一部分通訊重疊時實現的,然後為了開始用 `K_j` 和 `V_j` 進行計算,我們等待 all-gather 完成。

all-to-all

All-to-all,有時也稱為 `ring-rotation`,利用了一種環狀的通訊模式。在完成 `attn_i` 計算後,會啟動一個 all-to-all 操作,將 `K_i` 和 `V_i` 傳送給相鄰的 rank。然後我們重複這個操作 `context_parallel_size-1` 次,這樣每個 rank 都能看到所有其他 rank 的 `K` 和 `V` 的分片一次。在理想情況下,我們預取相鄰 rank 的分片 `K_i+1` 和 `V_i+1`,並且這個通訊過程與我們當前 `attn_i` 的計算完全重疊。同樣,現實中這種完美的重疊從未發生。鑑於這種方法的性質,如果我們沒有實現完美的重疊,其代價要比使用 all-gather 大得多。

如何選擇正確的輪換方法?

理論上,all-to-all 應該是更好的選擇。但實際上,它很少如此。因此,我們預設使用 all-gather,因為它更有可能獲得更好的效能。`torchtitan` 團隊的廣泛基準測試也表明,all-to-all 很少優於 all-gather。儘管如此,我們仍然提供兩種選擇,因為您可能會發現其中一種更適合您的用例。

您可以直接在下圖的效能分析器輸出中看到這個問題

all-to-all profiler output
圖 1:紅色部分顯示了等待 all-to-all 核心完成時的空閒時間。在第一個藍色條中高亮顯示的部分,您可以看到它需要大約 250 微秒才能完成,這個過程在每次注意力呼叫中重複 N-1 次,其中 N 是上下文並行大小。

為何只支援 FSDP2?

我們只支援 `FSDP2` 的上下文並行,因為我們建立了一個 `context_parallel_size` 和 `dp_shard_size` 的聯合網格來充分利用其潛力。它的工作原理是:我們在大小為 `cp_size*dp_shard_size` 的聯合網格上對模型進行分片,這最大化了記憶體節省。這在某種程度上是“免費的午餐”,因為 `FSDP` 通訊與注意力的計算完全重疊,如下圖所示。

why FSDP2+CP
圖 2:在藍色矩形(Stream 23)中,您可以看到 `FSDP` 分片的預取與注意力的計算(Stream 7)完全重疊,而在紅色矩形(Stream 24)中,您可以看到 all-gather 核心導致了一個空閒時間的“氣泡”,在此期間我們的計算流(7)是空閒的。

在上圖中,您還可以注意到 all-to-all 和 all-gather 之間的區別。在 all-to-all(圖 1)中,我們每次注意力呼叫都會啟動 N-1 次通訊核心,而在 all-gather(圖 2)中,我們只啟動一次通訊核心。這導致了一個更大的“氣泡”,但每次注意力呼叫只發生一次,而在 all-to-all 中,它會發生 N-1 次。

聯合網格中的資料分發

我們確保將同一批資料分發到整個 `cp` 子組,以確保結果正確。(意味著 `cp` 子組中的每個 rank 都會收到同一批資料。)然而,我們也會將不同的批次分發到 `dp_shard` 組的每個 rank。可以這樣想象:

# 8 GPUS, --dp_shard_size 4, --cp_size 2
# mesh = [[0, 1], [2, 3], [4, 5], [6, 7]]
# model is sharded across the whole mesh (each GPU holds 1/8 of the model)
# GPUs 0,1 = batch 0
# GPUs 2,3 = batch 1
... and so on.
< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.