SetFit 文件

快速入門

Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

快速入門

本快速入門面向準備深入程式碼並檢視如何訓練和使用 🤗 SetFit 模型的開發人員。我們建議從本快速入門開始,然後繼續閱讀教程操作指南以獲取更多資料。此外,概念指南有助於準確解釋 SetFit 的工作原理。

首先安裝 🤗 SetFit

pip install setfit

如果您有支援 CUDA 的顯示卡,建議安裝支援 CUDA 的 torch,以便更快地進行訓練和推理

pip install torch --index-url https://download.pytorch.org/whl/cu118

SetFit

SetFit 是一個高效的框架,可以使用少量訓練資料訓練低延遲文字分類模型。在本快速入門中,您將學習如何訓練 SetFit 模型、如何使用它進行推理以及如何將其儲存到 Hugging Face Hub。

訓練

在本節中,您將載入一個句子轉換器模型,並進一步對其進行微調,以將電影評論分類為正面或負面。為了訓練模型,我們需要準備以下三項:1) 一個模型,2) 一個資料集,以及 3) 訓練引數

1. 使用我們選擇的句子轉換器模型初始化一個 SetFit 模型。考慮使用MTEB 排行榜來指導您選擇哪個句子轉換器模型。我們將使用BAAI/bge-small-en-v1.5,這是一個小巧但效能良好的模型。

>>> from setfit import SetFitModel

>>> model = SetFitModel.from_pretrained("BAAI/bge-small-en-v1.5")

2a. 接下來,載入SetFit/sst2資料集的“訓練”和“測試”拆分。請注意,資料集有 "text""label" 列:這正是 🤗 SetFit 期望的格式。如果您的資料集有不同的列,那麼您可以在第 4 步中使用 Trainer 的 `column_mapping` 引數將列名對映到 "text""label"

>>> from datasets import load_dataset

>>> dataset = load_dataset("SetFit/sst2")
>>> dataset
DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'label_text'],
        num_rows: 6920
    })
    test: Dataset({
        features: ['text', 'label', 'label_text'],
        num_rows: 1821
    })
    validation: Dataset({
        features: ['text', 'label', 'label_text'],
        num_rows: 872
    })
})

2b. 在實際場景中,擁有約 7,000 個高質量標記訓練樣本是非常罕見的,因此我們將大幅縮小訓練資料集,以便更好地瞭解 🤗 SetFit 在實際設定中如何工作。具體來說,`sample_dataset` 函式將為每個類別僅取樣 8 個樣本。測試集不受影響,以便更好地進行評估。

>>> from setfit import sample_dataset

>>> train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)
>>> train_dataset
Dataset({
    features: ['text', 'label', 'label_text'],
    num_rows: 16
})
>>> test_dataset = dataset["test"]
>>> test_dataset
Dataset({
    features: ['text', 'label', 'label_text'],
    num_rows: 1821
})

2c. 我們可以將資料集中的標籤應用於模型,這樣預測輸出就可以是可讀的類別。您也可以直接將標籤提供給 `SetFitModel.from_pretrained()`。

>>> model.labels = ["negative", "positive"]

3. 準備訓練引數以進行訓練。請注意,使用 🤗 SetFit 進行訓練在幕後包含兩個階段:微調嵌入訓練分類頭。因此,某些訓練引數可以是元組,其中兩個值分別用於這兩個階段。

`num_epochs` 和 `max_steps` 引數通常用於增加和減少總訓練步數。請注意,使用 SetFit 時,更好的效能是透過更多資料,而不是更多訓練來實現的!如果您有大量資料,即使訓練不到 1 個 epoch 也無需擔心。

>>> from setfit import TrainingArguments

>>> args = TrainingArguments(
...     batch_size=32,
...     num_epochs=10,
... )

4. 初始化訓練器並執行訓練。

>>> from setfit import Trainer

>>> trainer = Trainer(
...     model=model,
...     args=args,
...     train_dataset=train_dataset,
... )
>>> trainer.train()
***** Running training *****
  Num examples = 5
  Num epochs = 10
  Total optimization steps = 50
  Total train batch size = 32
{'embedding_loss': 0.2077, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.2}                                                                                                                
{'embedding_loss': 0.0097, 'learning_rate': 0.0, 'epoch': 10.0}                                                                                                                                 
{'train_runtime': 14.705, 'train_samples_per_second': 108.807, 'train_steps_per_second': 3.4, 'epoch': 10.0}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:08<00:00,  5.70it/s]

5. 使用提供的測試資料集進行評估。

>>> trainer.evaluate(test_dataset)
***** Running evaluation *****
{'accuracy': 0.8511806699615596}

隨意嘗試增加每個類別的樣本數量,以觀察準確性的提高。作為一項挑戰,您可以嘗試調整每個類別的樣本數、學習率、 epoch 數、最大步數以及基礎句子轉換器模型,以嘗試在少量資料下將準確性提高到 90% 以上。

儲存 🤗 SetFit 模型

訓練後,您可以將 🤗 SetFit 模型儲存到本地檔案系統或 Hugging Face Hub。透過提供 `save_directory`,使用 `SetFitModel.save_pretrained()` 將模型儲存到本地目錄

>>> model.save_pretrained("setfit-bge-small-v1.5-sst2-8-shot")

或者,透過提供 `repo_id`,使用 `SetFitModel.push_to_hub()` 將模型推送到 Hugging Face Hub

>>> model.push_to_hub("tomaarsen/setfit-bge-small-v1.5-sst2-8-shot")

載入 🤗 SetFit 模型

可以透過提供 1) 來自 Hugging Face Hub 的 `repo_id` 或 2) 本地目錄的路徑來使用 `SetFitModel.from_pretrained()` 載入 🤗 SetFit 模型

>>> model = SetFitModel.from_pretrained("tomaarsen/setfit-bge-small-v1.5-sst2-8-shot") # Load from the Hugging Face Hub

>>> model = SetFitModel.from_pretrained("setfit-bge-small-v1.5-sst2-8-shot") # Load from a local directory

推理

訓練好 🤗 SetFit 模型後,可以使用SetFitModel.predict()SetFitModel.call()對評論進行分類以進行推理

>>> preds = model.predict([
...     "It's a charming and often affecting journey.",
...     "It's slow -- very, very slow.",
...     "A sometimes tedious film.",
... ])
>>> preds
['positive' 'negative' 'negative']

這些預測依賴於 `model.labels`。如果未設定,它將以訓練期間使用的格式返回預測,例如 `tensor([1, 0, 0])`。

下一步是什麼?

您已完成 🤗 SetFit 快速入門!您現在可以訓練、儲存、載入 🤗 SetFit 模型並進行推理!

接下來的步驟,請查閱我們的操作指南,瞭解如何進行超引數搜尋、知識蒸餾或零樣本文字分類等更具體的操作。如果您有興趣深入瞭解 🤗 SetFit 的工作原理,請泡杯咖啡,閱讀我們的概念指南

端到端

此程式碼片段展示了整個快速入門的端到端示例

from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset
from datasets import load_dataset

# Initializing a new SetFit model
model = SetFitModel.from_pretrained("BAAI/bge-small-en-v1.5", labels=["negative", "positive"])

# Preparing the dataset
dataset = load_dataset("SetFit/sst2")
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)
test_dataset = dataset["test"]

# Preparing the training arguments
args = TrainingArguments(
    batch_size=32,
    num_epochs=10,
)

# Preparing the trainer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
)
trainer.train()

# Evaluating
metrics = trainer.evaluate(test_dataset)
print(metrics)
# => {'accuracy': 0.8511806699615596}

# Saving the trained model
model.save_pretrained("setfit-bge-small-v1.5-sst2-8-shot")
# or
model.push_to_hub("tomaarsen/setfit-bge-small-v1.5-sst2-8-shot")

# Loading a trained model
model = SetFitModel.from_pretrained("tomaarsen/setfit-bge-small-v1.5-sst2-8-shot") # Load from the Hugging Face Hub
# or
model = SetFitModel.from_pretrained("setfit-bge-small-v1.5-sst2-8-shot") # Load from a local directory

# Performing inference
preds = model.predict([
    "It's a charming and often affecting journey.",
    "It's slow -- very, very slow.",
    "A sometimes tedious film.",
])
print(preds)
# => ["positive", "negative", "negative"]
< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.