Accelerate 文件
DDP 通訊鉤子
加入 Hugging Face 社群
並獲得增強的文件體驗
開始使用
DDP 通訊鉤子
分散式資料並行(Distributed Data Parallel, DDP)通訊鉤子提供了一個通用介面,透過覆蓋 `DistributedDataParallel` 中預設的 allreduce 操作來控制梯度如何在工作程序間通訊。PyTorch 提供了一些內建的通訊鉤子,使用者可以輕鬆應用任何這些鉤子來最佳化通訊。
- FP16 壓縮鉤子:透過將梯度轉換為半精度浮點格式(`torch.float16`)來壓縮梯度,從而減少通訊開銷。
- BF16 壓縮鉤子:與 FP16 類似,但使用 Brain 浮點格式(`torch.bfloat16`),這在某些硬體上可能更高效。
- PowerSGD 鉤子:一種高階梯度壓縮演算法,提供高壓縮率,可以加速受頻寬限制的分散式訓練。
在本教程中,您將瞭解如何快速設定 DDP 通訊鉤子,並使用 Accelerate 中提供的工具進行訓練,這可以像新增一行新程式碼一樣簡單!本教程演示瞭如何使用 DDP 通訊鉤子來最佳化使用 Accelerate 庫的分散式訓練中的梯度通訊。
FP16 壓縮鉤子
PyTorch
Accelerate
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks
from accelerate.test_utils.testing import get_backend
device_type, _, _ = get_backend()
device_id = getattr(torch, device_type, torch.cuda).current_device()
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.layer = torch.nn.Linear(10, 10)
def forward(self, x):
return self.layer(x)
model = MyModel()
model = DDP(model, device_ids=[device_id])
model.register_comm_hook(state=None, hook=default_hooks.fp16_compress_hook)
# Training loop
for data, targets in data_loader:
outputs = model(data)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
optimizer.zero_grad()
BF16 壓縮鉤子
BF16 壓縮鉤子 API 是實驗性的,需要 2.9.6 以上版本的 NCCL。
PyTorch
Accelerate
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks
from accelerate.test_utils.testing import get_backend
device_type, _, _ = get_backend()
device_id = getattr(torch, device_type, torch.cuda).current_device()
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.layer = torch.nn.Linear(10, 10)
def forward(self, x):
return self.layer(x)
model = MyModel()
model = DDP(model, device_ids=[device_id])
model.register_comm_hook(state=None, hook=default_hooks.bf16_compress_hook)
# Training loop
for data, targets in data_loader:
outputs = model(data)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
optimizer.zero_grad()
PowerSGD 鉤子
PowerSGD 通常需要與模型梯度大小相同的額外記憶體來啟用誤差反饋,這可以補償有偏的壓縮通訊並提高準確性。
PyTorch
Accelerate
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook
from accelerate.test_utils.testing import get_backend
device_type, _, _ = get_backend()
device_id = getattr(torch, device_type, torch.cuda).current_device()
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.layer = torch.nn.Linear(10, 10)
def forward(self, x):
return self.layer(x)
model = MyModel()
model = DDP(model, device_ids=[device_id])
state = powerSGD_hook.PowerSGDState(process_group=None)
model.register_comm_hook(state=state, hook=powerSGD_hook.powerSGD_hook)
# Training loop
for data, targets in data_loader:
outputs = model(data)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
optimizer.zero_grad()
DDP 通訊鉤子工具
還有兩個額外的工具,用於支援通訊鉤子的可選功能。
comm_wrapper
`comm_wrapper` 是一個選項,用於為通訊鉤子包裝額外的功能。例如,它可以用於將 FP16 壓縮與其他通訊策略結合使用。當前支援的包裝器有 `no`、`fp16` 和 `bf16`。
from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs
import torch
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.layer = torch.nn.Linear(10, 10)
def forward(self, x):
return self.layer(x)
# DDP Communication Hook setup
ddp_kwargs = DistributedDataParallelKwargs(
comm_hook=DDPCommunicationHookType.POWER_SGD,
comm_wrapper=DDPCommunicationHookType.FP16
)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
model = MyModel()
optimizer = torch.optim.Adam(model.parameters())
data_loader = DataLoader(dataset, batch_size=16)
model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)
# Training loop
for data, targets in data_loader:
outputs = model(data)
loss = criterion(outputs, targets)
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
comm_state_option
`comm_state_option` 允許您傳遞某些通訊鉤子所需的附加狀態資訊。這對於像 `PowerSGD` 這樣需要在訓練步驟之間維護超引數和內部狀態的有狀態鉤子特別有用。下面是一個展示如何將 `comm_state_option` 與 `PowerSGD` 鉤子一起使用的示例。
from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs
import torch
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.layer = torch.nn.Linear(10, 10)
def forward(self, x):
return self.layer(x)
# DDP Communication Hook setup
ddp_kwargs = DistributedDataParallelKwargs(
comm_hook=DDPCommunicationHookType.POWER_SGD,
comm_state_option={"matrix_approximation_rank": 2}
)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
model = MyModel()
optimizer = torch.optim.Adam(model.parameters())
data_loader = DataLoader(dataset, batch_size=16)
model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)
# Training loop
for data, targets in data_loader:
outputs = model(data)
loss = criterion(outputs, targets)
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
有關更高階的用法和其他鉤子,請參閱 PyTorch DDP 通訊鉤子文件。
< > 在 GitHub 上更新