SetFit 文件

回撥

Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

回撥

SetFit 模型可以透過回撥函式(例如用於日誌記錄或提前停止)進行影響。

本指南將向您展示它們是什麼以及如何使用它們。

SetFit 中的回撥

回撥是自定義 SetFit 訓練器 中訓練迴圈行為的物件,它們可以檢查訓練迴圈狀態(用於進度報告、日誌記錄、訓練期間檢查嵌入)並做出決策(例如提前停止)。

特別是,訓練器 使用一個 TrainerControl,它可以受到回撥的影響來停止訓練、儲存模型、評估或記錄,以及一個 TrainerState,它在訓練期間跟蹤一些訓練迴圈指標,例如到目前為止的訓練步數。

SetFit 依賴於 transformers 中實現的回撥,如 transformers 文件 此處 所述。

預設回撥

SetFit 使用 TrainingArguments.report_to 引數來指定應啟用哪些內建回撥。此引數預設為 "all",這意味著將啟用所有已安裝的 transformers 中的第三方回撥。例如,TensorBoardCallbackWandbCallback

除此之外,PrinterCallbackProgressCallback 始終啟用以顯示訓練進度,DefaultFlowCallback 也始終啟用以正確更新 TrainerControl

使用回撥

如前所述,您可以使用 TrainingArguments.report_to 來精確指定您希望啟用哪些回撥。例如:

from setfit import TrainingArguments

args = TrainingArguments(
    ...,
    report_to="wandb",
    ...,
)
# or 
args = TrainingArguments(
    ...,
    report_to=["wandb", "tensorboard"],
    ...,
)

您還可以使用 Trainer.add_callback()Trainer.pop_callback()Trainer.remove_callback() 來影響訓練器回撥,您還可以透過 訓練器 初始化來指定回撥,例如:

from setfit import Trainer

...

trainer = Trainer(
    model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
)
trainer.train()

自定義回撥

SetFit 支援自定義回撥,其方式與 transformers 相同:透過子類化 TrainerCallback。此類別實現了許多可以重寫的 on_... 方法。例如,以下指令碼展示了一個自定義回撥,它在訓練期間儲存訓練和評估嵌入的 tSNE 繪圖。

import os
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

class EmbeddingPlotCallback(TrainerCallback):
    """Simple embedding plotting callback that plots the tSNE of the training and evaluation datasets throughout training."""
    def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        os.makedirs("logs", exist_ok=True)

    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: SetFitModel, **kwargs):
        train_embeddings = model.encode(train_dataset["text"])
        eval_embeddings = model.encode(eval_dataset["text"])

        fig, (train_ax, eval_ax) = plt.subplots(ncols=2)

        train_X = TSNE(n_components=2).fit_transform(train_embeddings)
        train_ax.scatter(*train_X.T, c=train_dataset["label"], label=train_dataset["label"])
        train_ax.set_title("Training embeddings")

        eval_X = TSNE(n_components=2).fit_transform(eval_embeddings)
        eval_ax.scatter(*eval_X.T, c=eval_dataset["label"], label=eval_dataset["label"])
        eval_ax.set_title("Evaluation embeddings")

        fig.suptitle(f"tSNE of training and evaluation embeddings at step {state.global_step} of {state.max_steps}.")
        fig.savefig(f"logs/step_{state.global_step}.png")

,其中

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    callbacks=[EmbeddingPlotCallback()]
)
trainer.train()

來自 EmbeddingPlotCallbackon_evaluate 將在每次評估呼叫時觸發。在本例中,它產生了以下繪圖:

第 20 步 第 40 步
step_20 step_40
第 60 步 第 80 步
step_60 step_80
< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.