SetFit 文件

零樣本文字分類

Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

零樣本文字分類

您的類名可能已經很好地描述了您想要分類的文字。使用 🤗 SetFit,您可以將這些類名與強大的預訓練 Sentence Transformer 模型一起使用,從而無需任何訓練樣本即可獲得一個強大的基線模型。

本指南將向您展示如何執行零樣本文字分類。

測試資料集

我們將使用 dair-ai/emotion 資料集來測試零樣本模型的效能。

from datasets import load_dataset

test_dataset = load_dataset("dair-ai/emotion", "split", split="test")

此資料集將類名儲存在資料集 Features 中,因此我們將按如下方式提取類:

classes = test_dataset.features["label"].names
# => ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']

否則,我們可以手動設定類列表。

合成數據集

然後,我們可以使用 get_templated_dataset() 根據這些類名合成生成一個虛擬資料集。

from setfit import get_templated_dataset

train_dataset = get_templated_dataset()
print(train_dataset)
# => Dataset({
#     features: ['text', 'label'],
#     num_rows: 48
# })
print(train_dataset[0])
# {'text': 'This sentence is sadness', 'label': 0}

訓練

我們可以像往常一樣使用此資料集來訓練 SetFit 模型。

from setfit import SetFitModel, Trainer, TrainingArguments

model = SetFitModel.from_pretrained("BAAI/bge-small-en-v1.5")

args = TrainingArguments(
    batch_size=32,
    num_epochs=1,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)
trainer.train()
***** Running training *****
  Num examples = 60
  Num epochs = 1
  Total optimization steps = 60
  Total train batch size = 32
{'embedding_loss': 0.2628, 'learning_rate': 3.3333333333333333e-06, 'epoch': 0.02}                                                                                 
{'embedding_loss': 0.0222, 'learning_rate': 3.7037037037037037e-06, 'epoch': 0.83}                                                                                 
{'train_runtime': 15.4717, 'train_samples_per_second': 124.098, 'train_steps_per_second': 3.878, 'epoch': 1.0}                                                     
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60/60 [00:09<00:00,  6.35it/s]

訓練後,我們可以評估模型。

metrics = trainer.evaluate()
print(metrics)
***** Running evaluation *****
{'accuracy': 0.591}

並執行預測。

preds = model.predict([
    "i am just feeling cranky and blue",
    "i feel incredibly lucky just to be able to talk to her",
    "you're pissing me off right now",
    "i definitely have thalassophobia, don't get me near water like that",
    "i did not see that coming at all",
])
print([classes[idx] for idx in preds])
['sadness', 'joy', 'anger', 'fear', 'surprise']

這些預測看起來都正確!

基準

為了表明 SetFit 的零樣本效能良好,我們將它與 transformers 中的零樣本分類模型進行比較。

from transformers import pipeline
from datasets import load_dataset
import evaluate

# Prepare the testing dataset
test_dataset = load_dataset("dair-ai/emotion", "split", split="test")
classes = test_dataset.features["label"].names

# Set up the zero-shot classification pipeline from transformers
# Uses 'facebook/bart-large-mnli' by default
pipe = pipeline("zero-shot-classification", device=0)
zeroshot_preds = pipe(test_dataset["text"], batch_size=16, candidate_labels=classes)
preds = [classes.index(pred["labels"][0]) for pred in zeroshot_preds]

# Compute the accuracy
metric = evaluate.load("accuracy")
transformers_accuracy = metric.compute(predictions=preds, references=test_dataset["label"])
print(transformers_accuracy)
{'accuracy': 0.3765}

憑藉 59.1% 的準確率,0-shot SetFit 顯著優於 transformers 推薦的零樣本模型。

預測延遲

除了獲得更高的準確率,SetFit 的速度也快得多。讓我們計算 SetFit 使用 BAAI/bge-small-en-v1.5transformers 使用 facebook/bart-large-mnli 的延遲。兩項測試均在 GPU 上執行。

import time

start_t = time.time()
pipe(test_dataset["text"], batch_size=32, candidate_labels=classes)
delta_t = time.time() - start_t
print(f"`transformers` with `facebook/bart-large-mnli` latency: {delta_t / len(test_dataset['text']) * 1000:.4f}ms per sentence")
`transformers` with `facebook/bart-large-mnli` latency: 31.1765ms per sentence
import time

start_t = time.time()
model.predict(test_dataset["text"])
delta_t = time.time() - start_t
print(f"SetFit with `BAAI/bge-small-en-v1.5` latency: {delta_t / len(test_dataset['text']) * 1000:.4f}ms per sentence")
SetFit with `BAAI/bge-small-en-v1.5` latency: 0.4600ms per sentence

因此,使用 BAAI/bge-small-en-v1.5 的 SetFit 比使用 facebook/bart-large-mnlitransformers 快 67 倍,同時更準確。

zero_shot_transformers_vs_setfit

< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.