使用 PyTorch FSDP 微調 Llama 2 70B

釋出於 2023 年 9 月 13 日
在 GitHub 上更新

引言

在這篇博文中,我們將探討如何使用 PyTorch FSDP 微調 Llama 2 70B 以及相關的最佳實踐。我們將利用 Hugging Face Transformers、Accelerate 和 TRL。我們還將學習如何將 Accelerate 與 SLURM 結合使用。

全分片資料並行 (FSDP) 是一種範例,其中最佳化器狀態、梯度和引數在裝置之間分片。在前向傳播過程中,每個 FSDP 單元執行一次 *all-gather 操作*以獲取完整的權重,然後執行計算,然後丟棄來自其他裝置的分片。在前向傳播之後,計算損失,然後進行反向傳播。在反向傳播中,每個 FSDP 單元執行一次 all-gather 操作以獲取完整的權重,並執行計算以獲取區域性梯度。這些區域性梯度透過 *reduce-scatter 操作*在裝置之間平均並分片,以便每個裝置可以更新其分片的引數。有關 PyTorch FSDP 的更多資訊,請參閱這篇博文:使用 PyTorch 全分片資料並行加速大型模型訓練

FSDP Workflow

(來源:連結

使用的硬體

節點數量:2。最低要求為 1。
每個節點的 GPU 數量:8
GPU 型別:A100
GPU 視訊記憶體:80GB
節點內連線:NVLink
每個節點的記憶體:1TB
每個節點的 CPU 核心數:96
節點間連線:彈性結構介面卡

微調 LLaMa 70B 的挑戰

在嘗試使用 FSDP 微調 LLaMa 70B 時,我們遇到了三個主要挑戰

  1. FSDP 在載入預訓練模型後封裝模型。如果節點內的每個程序/rank 都載入 Llama-70B 模型,它將需要 70 * 4 * 8 GB ~ 2TB 的 CPU RAM,其中 4 是每個引數的位元組數,8 是每個節點上的 GPU 數量。這將導致 CPU RAM 記憶體不足,從而導致程序終止。

  2. 使用 FULL_STATE_DICT 並將 CPU 解除安裝到 rank 0 上儲存整個中間檢查點需要大量時間,並且由於廣播期間無限期掛起,通常會導致 NCCL 超時錯誤。然而,在訓練結束時,我們希望獲得整個模型狀態字典,而不是僅與 FSDP 相容的分片狀態字典。

  3. 我們需要提高速度並減少 VRAM 使用量,以加快訓練速度並節省計算成本。

讓我們看看如何解決上述挑戰並微調 70B 模型!

在我們開始之前,以下是重現我們結果所需的所有資源

  1. 程式碼庫:https://github.com/pacman100/DHS-LLM-Workshop/tree/main/chat_assistant/sft/training 帶 flash-attn V2

  2. FSDP 配置:https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/configs/fsdp_config.yaml

  3. SLURM 指令碼 launch.slurmhttps://gist.github.com/pacman100/1cb1f17b2f1b3139a63b764263e70b25

  4. 模型:meta-llama/Llama-2-70b-chat-hf

  5. 資料集:smangrul/code-chat-assistant-v1(LIMA+GUANACO 的混合,格式正確,可直接用於訓練)

先決條件

首先,按照這些步驟安裝 Flash Attention V2:Dao-AILab/flash-attention: 快速且記憶體高效的精確注意力 (github.com)。安裝 PyTorch 的最新 nightly 版本,CUDA ≥11.8。按照 DHS-LLM-Workshop/code_assistant/training/requirements.txt 安裝其餘要求。在這裡,我們將從主分支安裝 🤗 Accelerate 和 🤗 Transformers。

微調

解決挑戰 1

PR huggingface/transformers#25107huggingface/accelerate#1777 解決了第一個挑戰,並且不需要使用者進行程式碼更改。它執行以下操作:

  1. 在所有 rank 上建立無權重的模型(使用 meta 裝置)。
  2. 僅在 rank==0 上載入狀態字典,並使用該狀態字典在 rank 0 上設定模型權重。
  3. 對於所有其他 rank,對 meta 裝置上的每個引數執行 torch.empty(*param.size(), dtype=dtype)
  4. 因此,rank==0 將載入具有正確狀態字典的模型,而所有其他 rank 將具有隨機權重。
  5. 設定 sync_module_states=True,以便 FSDP 物件在訓練開始前負責將其廣播到所有 rank。

以下是 7B 模型在 2 個 GPU 上測量的記憶體消耗和模型引數在不同階段的輸出片段。我們可以觀察到,在載入預訓練模型期間,rank 0 和 rank 1 的 CPU 總峰值記憶體分別為 32744 MB1506 MB。因此,只有 rank 0 正在載入預訓練模型,從而實現了 CPU RAM 的高效利用。完整的日誌可以在這裡找到。

accelerator.process_index=0 GPU Memory before entering the loading : 0
accelerator.process_index=0 GPU Memory consumed at the end of the loading (end-begin): 0
accelerator.process_index=0 GPU Peak Memory consumed during the loading (max-begin): 0
accelerator.process_index=0 GPU Total Peak Memory consumed during the loading (max): 0
accelerator.process_index=0 CPU Memory before entering the loading : 926
accelerator.process_index=0 CPU Memory consumed at the end of the loading (end-begin): 26415
accelerator.process_index=0 CPU Peak Memory consumed during the loading (max-begin): 31818
accelerator.process_index=0 CPU Total Peak Memory consumed during the loading (max): 32744

accelerator.process_index=1 GPU Memory before entering the loading : 0
accelerator.process_index=1 GPU Memory consumed at the end of the loading (end-begin): 0
accelerator.process_index=1 GPU Peak Memory consumed during the loading (max-begin): 0
accelerator.process_index=1 GPU Total Peak Memory consumed during the loading (max): 0
accelerator.process_index=1 CPU Memory before entering the loading : 933
accelerator.process_index=1 CPU Memory consumed at the end of the loading (end-begin): 10
accelerator.process_index=1 CPU Peak Memory consumed during the loading (max-begin): 573
accelerator.process_index=1 CPU Total Peak Memory consumed during the loading (max): 1506

解決挑戰 2

透過在建立 FSDP 配置時選擇 SHARDED_STATE_DICT 狀態字典型別來解決。SHARDED_STATE_DICT 會單獨儲存每個 GPU 的分片,從而可以快速儲存或從中間檢查點恢復訓練。當使用 FULL_STATE_DICT 時,第一個程序(rank 0)會在 CPU 上收集整個模型,然後以標準格式儲存。

讓我們使用以下命令建立 Accelerate 配置:

accelerate config --config_file "fsdp_config.yaml"

fsdp_config

生成的配置在這裡:fsdp_config.yaml。在這裡,分片策略是 FULL_SHARD。我們使用 TRANSFORMER_BASED_WRAP 進行自動包裝策略,它使用 _no_split_module 來查詢巢狀 FSDP 自動包裝的 Transformer 塊名稱。我們使用 SHARDED_STATE_DICT 來儲存中間檢查點和最佳化器狀態,這是 PyTorch 團隊推薦的格式。請確保在訓練開始時啟用從 rank 0 廣播模組引數,如上一段中解決挑戰 1 所述。我們正在啟用 bf16 混合精度訓練。

對於最終檢查點,整個模型狀態字典,使用以下程式碼片段。

if trainer.is_fsdp_enabled:
    trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")

trainer.save_model(script_args.output_dir) # alternatively, trainer.push_to_hub() if the whole ckpt is below 50GB as the LFS limit per file is 50GB 

解決挑戰 3

Flash Attention 和啟用梯度檢查點對於加快訓練速度和減少 VRAM 使用至關重要,從而實現微調並節省計算成本。程式碼庫目前使用猴子補丁,實現位於 chat_assistant/training/llama_flash_attn_monkey_patch.py

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness 引入了一種計算精確注意力的方法,該方法透過利用底層硬體/GPU 的記憶體層次結構知識,實現更快和更記憶體高效的計算——記憶體的頻寬/速度越高,其容量就越小,因為它變得更昂貴。

如果我們遵循部落格 讓深度學習 Go Brrrr 的第一性原理,我們可以發現當前硬體上的 Attention 模組是 記憶體密集型/頻寬密集型。原因是 Attention 主要由逐元素操作組成,如下圖左側所示。我們可以觀察到,掩碼、softmax 和 dropout 操作佔據了大部分時間,而不是佔據大部分 FLOP 的矩陣乘法。

Attention Bottlenecks

(來源:連結

這正是 Flash Attention 所解決的問題。其理念是**消除冗餘的 HBM 讀/寫。**它透過將所有內容保留在 SRAM 中,執行所有中間步驟,然後才將最終結果寫回 HBM 來實現這一點,這也稱為**核心融合**。下面是其如何克服記憶體瓶頸的圖示。kernel_fusion

(來源:連結

在正向和反向傳播過程中,使用**切片**將 NxN softmax/分數計算分塊,以克服 SRAM 記憶體大小的限制。為了啟用切片,使用了線上 softmax 演算法。在反向傳播過程中,使用**重新計算**,以避免在正向傳播過程中儲存整個 NxN softmax/分數矩陣。這大大減少了記憶體消耗。

要簡化和深入瞭解 Flash Attention,請參閱部落格文章 ELI5: FlashAttention讓深度學習 Go Brrrr 從第一性原理,以及原始論文 FlashAttention: 快速且記憶體高效的精確注意力,具有 IO 感知

整合所有功能

要使用 Accelerate 啟動器和 SLURM 執行訓練,請參閱此 gist launch.slurm。下面是演示如何使用 Accelerate 啟動器執行訓練的等效命令。請注意,我們正在覆蓋 fsdp_config.yamlmain_process_ipmain_process_portmachine_ranknum_processesnum_machines 值。這裡,另一個重要的一點是儲存在所有節點之間共享。

accelerate launch \
    --config_file configs/fsdp_config.yaml \
    --main_process_ip $MASTER_ADDR \
    --main_process_port $MASTER_PORT \
    --machine_rank \$MACHINE_RANK \
    --num_processes 16 \
    --num_machines 2 \
    train.py \
    --seed 100 \
    --model_name "meta-llama/Llama-2-70b-chat-hf" \
    --dataset_name "smangrul/code-chat-assistant-v1" \
    --chat_template_format "none" \
    --add_special_tokens False \
    --append_concat_token False \
    --splits "train,test" \
    --max_seq_len 2048 \
    --max_steps 500 \
    --logging_steps 25 \
    --log_level "info" \
    --eval_steps 100 \
    --save_steps 250 \
    --logging_strategy "steps" \
    --evaluation_strategy "steps" \
    --save_strategy "steps" \
    --push_to_hub \
    --hub_private_repo True \
    --hub_strategy "every_save" \
    --bf16 True \
    --packing True \
    --learning_rate 5e-5 \
    --lr_scheduler_type "cosine" \
    --weight_decay 0.01 \
    --warmup_ratio 0.03 \
    --max_grad_norm 1.0 \
    --output_dir "/shared_storage/sourab/experiments/full-finetune-llama-chat-asst" \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --gradient_checkpointing True \
    --use_reentrant False \
    --dataset_text_field "content" \
    --use_flash_attn True \
    --ddp_timeout 5400 \
    --optim paged_adamw_32bit 

微調在約 13.5 小時內完成,以下是訓練損失圖。

train_loss

重要提示:在這裡,如果您只使用一臺帶有 8 個 A100 80GB GPU 的節點,則需要 bitsandbytes 中的 paged_adamw_32bit 最佳化器。有關分頁最佳化器的更多資訊,請參閱 https://huggingface.co/docs/bitsandbytes/main/en/optimizers#paged-optimizers

下面是使用上述模型的一個對話示例

System Prompt: You are a helpful, respectful and honest assistant. Always answer as helpfully \
as possible, while being safe. Your answers should not include any harmful, \
unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that \
your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why \
instead of answering something not correct. If you don’t know the answer to a \
question, please don’t share false information.

- Human: What is Deep Learning? Explain like a Pirate.

+ Assistant: Arrr Grumete! Are ye lookin' fer a proper explanation o' Deep Learnin'? 
+ Alright then, listen up me hearty! Deep Learnin' be a type o' machine learnin' that uses artificial neural networks to analyze data. 
+ Aye, ye heard that right, mateys! Artificial neural networks be models o' the human brain, but instead o' using neurons, they use interconnected nodes or "artificial neurons" to process information. 
+ These nodes work together to form layers o' processing units that can learn from large amounts o' data. 
+ The more data these systems receive, the better they get at recognizing patterns and makin' decisions on their own. 
+ So hoist the sails and set course for adventure with deep learnin', me hearties! It's a treasure trove o' knowledge just waitin' to be discovered. 
+ Fair winds and following seas!

- Human: Now explain it like a chef.

+ Assistant: Certainly! Here's an explanation of deep learning from a culinary perspective:

+ Think of a recipe as a sequence of steps used to transform raw ingredients into a delicious dish. 
+ Similarly, in deep learning, there are multiple layers of "ingredients" (or features) that are combined and transformed through various operations to produce a final output or prediction. 
+ Each layer builds upon the previous one, refining and enhancing the representation of the input data until it reaches the desired outcome. 
+ Just as different spices enhance the flavors of food, each layer adds depth and complexity to the model's understanding of the input, allowing it to make more accurate predictions. 
+ And much like how chefs must carefully balance the amount of salt, sugar, and other seasonings in a recipe, the training process in deep learning involves adjusting the weights and biases of the network to achieve the perfect blend of accuracy and efficiency. 
+ With enough practice and experimentation, anyone can become a master chef... I mean, a master of deep learning! Bon appétit!

整個對話格式如下:

<|system|> system message <|endoftext|> <|prompter|> Q1 <|endoftext|> <|assistant|> A1 <|endoftext|> ...

結論

我們成功地使用 PyTorch FSDP 在多節點多 GPU 環境中微調了 70B Llama 模型,同時解決了各種挑戰。我們瞭解了 🤗 Transformers 和 🤗 Accelerates 如何支援高效初始化大型模型,以克服 CPU RAM 記憶體不足的問題。隨後,我們介紹了儲存/載入中間檢查點以及如何以可直接使用的方式儲存最終模型的推薦實踐。為了加快訓練速度並減少 GPU 記憶體使用,我們強調了 Flash Attention 和梯度檢查點的重要性。總的來說,我們可以看到使用 🤗 Accelerate 的簡單配置如何實現在多節點多 GPU 環境中微調如此龐大的模型。

社群

註冊登入 以發表評論

© . This site is unofficial and not affiliated with Hugging Face, Inc.