SetFit 文件
零樣本文字分類
加入 Hugging Face 社群
並獲得增強的文件體驗
開始使用
零樣本文字分類
您的類名可能已經很好地描述了您想要分類的文字。使用 🤗 SetFit,您可以將這些類名與強大的預訓練 Sentence Transformer 模型一起使用,從而無需任何訓練樣本即可獲得一個強大的基線模型。
本指南將向您展示如何執行零樣本文字分類。
測試資料集
我們將使用 dair-ai/emotion 資料集來測試零樣本模型的效能。
from datasets import load_dataset
test_dataset = load_dataset("dair-ai/emotion", "split", split="test")
此資料集將類名儲存在資料集 Features
中,因此我們將按如下方式提取類:
classes = test_dataset.features["label"].names
# => ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']
否則,我們可以手動設定類列表。
合成數據集
然後,我們可以使用 get_templated_dataset() 根據這些類名合成生成一個虛擬資料集。
from setfit import get_templated_dataset
train_dataset = get_templated_dataset()
print(train_dataset)
# => Dataset({
# features: ['text', 'label'],
# num_rows: 48
# })
print(train_dataset[0])
# {'text': 'This sentence is sadness', 'label': 0}
訓練
我們可以像往常一樣使用此資料集來訓練 SetFit 模型。
from setfit import SetFitModel, Trainer, TrainingArguments
model = SetFitModel.from_pretrained("BAAI/bge-small-en-v1.5")
args = TrainingArguments(
batch_size=32,
num_epochs=1,
)
trainer = Trainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=test_dataset,
)
trainer.train()
***** Running training *****
Num examples = 60
Num epochs = 1
Total optimization steps = 60
Total train batch size = 32
{'embedding_loss': 0.2628, 'learning_rate': 3.3333333333333333e-06, 'epoch': 0.02}
{'embedding_loss': 0.0222, 'learning_rate': 3.7037037037037037e-06, 'epoch': 0.83}
{'train_runtime': 15.4717, 'train_samples_per_second': 124.098, 'train_steps_per_second': 3.878, 'epoch': 1.0}
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60/60 [00:09<00:00, 6.35it/s]
訓練後,我們可以評估模型。
metrics = trainer.evaluate()
print(metrics)
***** Running evaluation *****
{'accuracy': 0.591}
並執行預測。
preds = model.predict([
"i am just feeling cranky and blue",
"i feel incredibly lucky just to be able to talk to her",
"you're pissing me off right now",
"i definitely have thalassophobia, don't get me near water like that",
"i did not see that coming at all",
])
print([classes[idx] for idx in preds])
['sadness', 'joy', 'anger', 'fear', 'surprise']
這些預測看起來都正確!
基準
為了表明 SetFit 的零樣本效能良好,我們將它與 transformers
中的零樣本分類模型進行比較。
from transformers import pipeline
from datasets import load_dataset
import evaluate
# Prepare the testing dataset
test_dataset = load_dataset("dair-ai/emotion", "split", split="test")
classes = test_dataset.features["label"].names
# Set up the zero-shot classification pipeline from transformers
# Uses 'facebook/bart-large-mnli' by default
pipe = pipeline("zero-shot-classification", device=0)
zeroshot_preds = pipe(test_dataset["text"], batch_size=16, candidate_labels=classes)
preds = [classes.index(pred["labels"][0]) for pred in zeroshot_preds]
# Compute the accuracy
metric = evaluate.load("accuracy")
transformers_accuracy = metric.compute(predictions=preds, references=test_dataset["label"])
print(transformers_accuracy)
{'accuracy': 0.3765}
憑藉 59.1% 的準確率,0-shot SetFit 顯著優於 transformers
推薦的零樣本模型。
預測延遲
除了獲得更高的準確率,SetFit 的速度也快得多。讓我們計算 SetFit 使用 BAAI/bge-small-en-v1.5
與 transformers
使用 facebook/bart-large-mnli
的延遲。兩項測試均在 GPU 上執行。
import time
start_t = time.time()
pipe(test_dataset["text"], batch_size=32, candidate_labels=classes)
delta_t = time.time() - start_t
print(f"`transformers` with `facebook/bart-large-mnli` latency: {delta_t / len(test_dataset['text']) * 1000:.4f}ms per sentence")
`transformers` with `facebook/bart-large-mnli` latency: 31.1765ms per sentence
import time
start_t = time.time()
model.predict(test_dataset["text"])
delta_t = time.time() - start_t
print(f"SetFit with `BAAI/bge-small-en-v1.5` latency: {delta_t / len(test_dataset['text']) * 1000:.4f}ms per sentence")
SetFit with `BAAI/bge-small-en-v1.5` latency: 0.4600ms per sentence
因此,使用 BAAI/bge-small-en-v1.5
的 SetFit 比使用 facebook/bart-large-mnli
的 transformers
快 67 倍,同時更準確。