Google TPU 文件
完全分片資料平行處理 (FSDP) v2
加入 Hugging Face 社群
並獲得增強的文件體驗
開始使用
完全分片資料平行處理 (FSDP) v2
概覽
在 TPU 上微調大型語言模型 (LLM) 時,跨裝置的模型分片對於記憶體效率與提升訓練效能至關重要。optimum.tpu.fsdp_v2 模組提供了實作「完全分片資料平行處理」(Fully Sharded Data Parallel) 訓練的工具,並使用專為 TPU 裝置最佳化的 SPMD (單一程式多重資料,Single Program Multiple Data) 架構。
FSDP_v2 特色
- 跨 TPU 裝置的模型權重分片
- 支援梯度檢查點 (Gradient checkpointing)
- 針對常見模型架構的自動設定
- 與 PyTorch/XLA 的 SPMD 實作整合
基本使用
以下說明如何啟用並設定 FSDP_v2 以進行訓練:
from optimum.tpu import fsdp_v2
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Enable FSDP_v2
fsdp_v2.use_fsdp_v2()
# Load model and tokenizer
model_id = "meta-llama/Llama-2-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16
)
# Get FSDP training configuration
fsdp_args = fsdp_v2.get_fsdp_training_args(model)設定選項
get_fsdp_training_args() 函式會回傳一個包含模型特定設定的字典,例如:
{
'fsdp': 'full_shard',
'fsdp_config': {
'transformer_layer_cls_to_wrap': ['LlamaDecoderLayer'], # Model-specific
'xla': True,
'xla_fsdp_v2': True,
'xla_fsdp_grad_ckpt': True
}
}關鍵參數
transformer_layer_cls_to_wrap:指定哪些模型層需要使用 FSDP 進行封裝xla:啟用 XLA 最佳化xla_fsdp_v2:啟用 FSDP_v2 實作xla_fsdp_grad_ckpt:啟用梯度檢查點以提升記憶體效率
進階用法
自訂層封裝
您可以自訂哪些層需要使用 FSDP 進行封裝
custom_fsdp_args = fsdp_v2.get_fsdp_training_args(
model,
layer_cls_to_wrap=['CustomTransformerLayer']
)與 Transformers Trainer 整合
FSDP_v2 設定可直接用於 Transformers 的 Trainer
from transformers import Trainer, TrainingArguments
# Or for instruction fine-tuning:
# from trl import SFTTrainer
trainer = Trainer( # or SFTTrainer
model=model,
args=TrainingArguments(**fsdp_args), # Unpack FSDP configuration
train_dataset=dataset,
...
)