Accelerate 文件
執行與推遲作業
並獲得增強的文件體驗
開始使用
執行與推遲作業
當您執行常規指令碼時,指令會按順序執行。使用 Accelerate 在多個 GPU 上同時部署指令碼會帶來一個複雜問題:雖然每個程序都按順序執行所有指令,但某些程序可能會比其他程序快。
您可能需要等待所有程序都到達某個特定點後,才能執行給定的指令。例如,在確定每個程序都完成訓練之前,您不應該儲存模型;在所有模型權重都載入完畢之前,您也不想繼續訓練。要實現這一點,只需在程式碼中寫入以下行:
accelerator.wait_for_everyone()
這條指令會阻塞所有先到達的程序,直到所有其他程序都到達該點(如果您只在一個 GPU 或 CPU 上執行指令碼,這條指令將不起任何作用)。
下面列出了一些使用此功能的示例情況:
其中一些是與 main_process_first() 上下文管理器一起使用的,該管理器利用 wait_for_everyone() 在主程序上預先執行一組特定的程式碼,然後再觸發和啟動其他程序。
下載資料集
下載資料集時,您應該先在主程序上下載,然後再載入快取的資料集。
load_dataset
會在內部執行鎖定,以阻止同時進行多次下載,但如果您下載的內容未使用此庫,則應使用此方法。
with accelerator.main_process_first():
datasets = load_dataset("glue", "mrpc")
這在底層與呼叫以下程式碼相同:
# First do something on the main process
if accelerator.is_main_process:
datasets = load_dataset("glue", "mrpc")
else:
accelerator.wait_for_everyone()
# And then send it to the rest of them
if not accelerator.is_main_process:
datasets = load_dataset("glue", "mrpc")
else:
accelerator.wait_for_everyone()
儲存 state_dict
在儲存模型的 state_dict
時,由於您通常只在主程序上儲存一個檔案,您應該明確指定這一點。
if accelerator.is_main_process:
model = accelerator.unwrap_model(model)
torch.save(model.state_dict(), "weights.pth")
載入 state_dict
將 state_dict
載入到模型、最佳化器或排程器時,您應等待所有工作程序都載入完權重後,再繼續進行訓練。
with accelerator.main_process_first():
state = torch.load("weights.pth")
model.load_state_dict(state)
應用多工作程序 CPU 操作
在多個工作程序上應用 map()
操作(如分詞)時,應首先在主程序上完成,然後傳播到每個工作程序。
datasets = load_dataset("glue", "mrpc")
with accelerator.main_process_first():
tokenized_datasets = datasets.map(
tokenize_function,
batched=True,
remove_columns=["idx", "sentence1", "sentence2"],
)
應用檢查,如提前停止
要實現一個由特定程序設定標誌的檢查,應使用 set_trigger
和 check_trigger
API。這在某些情況下非常有用,例如使用提前停止和監控損失(因為每個程序上的損失會略有不同)。
當您的條件滿足時,呼叫 Accelerator.set_trigger();當檢查任何程序中該條件是否已滿足時,呼叫 Accelerator.check_trigger()。
for (x,y) in data_loader:
logits = model(x)
loss = loss_func(logits, y)
# Assume `should_do_early_stopping` is a custom defined function that returns a conditional
if should_do_early_stopping(loss):
accelerator.set_trigger()
# Later in the training script when we need to check for the breakpoint
if accelerator.check_trigger():
break