SetFit 文件
回撥
並獲得增強的文件體驗
開始使用
回撥
SetFit 模型可以透過回撥函式(例如用於日誌記錄或提前停止)進行影響。
本指南將向您展示它們是什麼以及如何使用它們。
SetFit 中的回撥
回撥是自定義 SetFit 訓練器 中訓練迴圈行為的物件,它們可以檢查訓練迴圈狀態(用於進度報告、日誌記錄、訓練期間檢查嵌入)並做出決策(例如提前停止)。
特別是,訓練器 使用一個 TrainerControl
,它可以受到回撥的影響來停止訓練、儲存模型、評估或記錄,以及一個 TrainerState
,它在訓練期間跟蹤一些訓練迴圈指標,例如到目前為止的訓練步數。
SetFit 依賴於 transformers
中實現的回撥,如 transformers
文件 此處 所述。
預設回撥
SetFit 使用 TrainingArguments.report_to
引數來指定應啟用哪些內建回撥。此引數預設為 "all"
,這意味著將啟用所有已安裝的 transformers
中的第三方回撥。例如,TensorBoardCallback
或 WandbCallback
。
除此之外,PrinterCallback
或 ProgressCallback
始終啟用以顯示訓練進度,DefaultFlowCallback
也始終啟用以正確更新 TrainerControl
。
使用回撥
如前所述,您可以使用 TrainingArguments.report_to
來精確指定您希望啟用哪些回撥。例如:
from setfit import TrainingArguments
args = TrainingArguments(
...,
report_to="wandb",
...,
)
# or
args = TrainingArguments(
...,
report_to=["wandb", "tensorboard"],
...,
)
您還可以使用 Trainer.add_callback()、Trainer.pop_callback() 和 Trainer.remove_callback() 來影響訓練器回撥,您還可以透過 訓練器 初始化來指定回撥,例如:
from setfit import Trainer
...
trainer = Trainer(
model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
)
trainer.train()
自定義回撥
SetFit 支援自定義回撥,其方式與 transformers
相同:透過子類化 TrainerCallback
。此類別實現了許多可以重寫的 on_...
方法。例如,以下指令碼展示了一個自定義回撥,它在訓練期間儲存訓練和評估嵌入的 tSNE 繪圖。
import os
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
class EmbeddingPlotCallback(TrainerCallback):
"""Simple embedding plotting callback that plots the tSNE of the training and evaluation datasets throughout training."""
def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
os.makedirs("logs", exist_ok=True)
def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: SetFitModel, **kwargs):
train_embeddings = model.encode(train_dataset["text"])
eval_embeddings = model.encode(eval_dataset["text"])
fig, (train_ax, eval_ax) = plt.subplots(ncols=2)
train_X = TSNE(n_components=2).fit_transform(train_embeddings)
train_ax.scatter(*train_X.T, c=train_dataset["label"], label=train_dataset["label"])
train_ax.set_title("Training embeddings")
eval_X = TSNE(n_components=2).fit_transform(eval_embeddings)
eval_ax.scatter(*eval_X.T, c=eval_dataset["label"], label=eval_dataset["label"])
eval_ax.set_title("Evaluation embeddings")
fig.suptitle(f"tSNE of training and evaluation embeddings at step {state.global_step} of {state.max_steps}.")
fig.savefig(f"logs/step_{state.global_step}.png")
,其中
trainer = Trainer( model=model, args=args, train_dataset=train_dataset, eval_dataset=eval_dataset, callbacks=[EmbeddingPlotCallback()] ) trainer.train()
來自 EmbeddingPlotCallback
的 on_evaluate
將在每次評估呼叫時觸發。在本例中,它產生了以下繪圖:
第 20 步 | 第 40 步 |
---|---|
第 60 步 | 第 80 步 |