Accelerate 文件

檢查點

Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

檢查點

當使用 Accelerate 訓練 PyTorch 模型時,您可能經常需要儲存和繼續訓練狀態。這需要儲存和載入模型、最佳化器、RNG 生成器和 GradScaler。在 Accelerate 內部有兩個便捷函式可以快速實現這一點。

  • 使用 save_state() 將上述所有內容儲存到資料夾位置。
  • 使用 load_state() 載入之前透過 save_state 儲存的所有內容。

要進一步自定義透過 save_state() 儲存狀態的位置和方式,可以使用 ProjectConfiguration 類。例如,如果啟用了 automatic_checkpoint_naming,那麼每個儲存的檢查點都將位於 Accelerator.project_dir/checkpoints/checkpoint_{checkpoint_number}

需要注意的是,這些狀態應來自同一個訓練指令碼,而不應來自兩個不同的指令碼。

  • 透過使用 register_for_checkpointing(),您可以註冊自定義物件,以便透過前面兩個函式自動儲存或載入,只要該物件具有 state_dict **和** load_state_dict 功能即可。這可以包括學習率排程器等物件。

下面是一個在訓練期間使用檢查點儲存和重新載入狀態的簡短示例。

from accelerate import Accelerator
import torch

accelerator = Accelerator(project_dir="my/save/path")

my_scheduler = torch.optim.lr_scheduler.StepLR(my_optimizer, step_size=1, gamma=0.99)
my_model, my_optimizer, my_training_dataloader = accelerator.prepare(my_model, my_optimizer, my_training_dataloader)

# Register the LR scheduler
accelerator.register_for_checkpointing(my_scheduler)

# Save the starting state
accelerator.save_state()

device = accelerator.device
my_model.to(device)

# Perform training
for epoch in range(num_epochs):
    for batch in my_training_dataloader:
        my_optimizer.zero_grad()
        inputs, targets = batch
        inputs = inputs.to(device)
        targets = targets.to(device)
        outputs = my_model(inputs)
        loss = my_loss_function(outputs, targets)
        accelerator.backward(loss)
        my_optimizer.step()
    my_scheduler.step()

# Restore the previous state
accelerator.load_state("my/save/path/checkpointing/checkpoint_0")

恢復 DataLoader 的狀態

從檢查點恢復後,如果狀態是在一個 epoch 的中間儲存的,您可能還希望從當前 DataLoader 的特定點恢復。您可以使用 skip_first_batches() 來實現。

from accelerate import Accelerator

accelerator = Accelerator(project_dir="my/save/path")

train_dataloader = accelerator.prepare(train_dataloader)
accelerator.load_state("my_state")

# Assume the checkpoint was saved 100 steps into the epoch
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, 100)

# After the first iteration, go back to `train_dataloader`

# First epoch
for batch in skipped_dataloader:
    # Do something
    pass

# Second epoch
for batch in train_dataloader:
    # Do something
    pass
< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.