SetFit 文件
零樣本文字分類
並獲得增強的文件體驗
開始使用
零樣本文字分類
雖然 SetFit 是為少樣本學習而設計的,但該方法也可以應用於沒有標記資料的場景。主要的技巧是建立類似於分類任務的*合成示例*,然後使用它們訓練 SetFit 模型。
值得注意的是,這種簡單的技術通常優於 🤗 Transformers 中的零樣本管道,並且預測速度可以快 5 倍(或更多)!
在本教程中,我們將探討如何
- SetFit 可以應用於零樣本分類
- 新增合成示例還可以為少樣本分類提供效能提升。
設定
如果您在 Colab 或其他雲平臺上執行此 Notebook,則需要安裝 `setfit` 庫。取消註釋以下單元格並執行它
# %pip install setfit matplotlib
為了基準測試“零樣本”方法的效能,我們將使用以下資料集和預訓練模型
dataset_id = "emotion"
model_id = "sentence-transformers/paraphrase-mpnet-base-v2"
接下來,我們將從 Hugging Face Hub 下載參考資料集
from datasets import load_dataset
reference_dataset = load_dataset(dataset_id)
reference_dataset
DatasetDict({
train: Dataset({
features: ['text', 'label'],
num_rows: 16000
})
validation: Dataset({
features: ['text', 'label'],
num_rows: 2000
})
test: Dataset({
features: ['text', 'label'],
num_rows: 2000
})
})
現在我們已經設定好,讓我們建立一些合成數據進行訓練!
建立合成數據集
我們需要做的第一件事是建立一個合成示例資料集。在 `setfit` 中,我們可以透過將 `get_templated_dataset()` 函式應用於虛擬資料集來做到這一點。此函式需要幾個主要內容
- 一個用於分類的候選標籤列表。我們將在此處使用參考資料集中的標籤,但這可以是與任務和當前資料集相關的任何內容。
- 一個用於生成示例的模板。預設情況下,它是 `"This sentence is {}"`,其中 `{}` 將由一個候選標籤填充
- 一個樣本大小 $N$,它將為每個類建立 $N$ 個合成示例。我們發現 $N=8$ 通常效果最好。
有了這些資訊,我們首先從資料集中提取一些候選標籤
# Extract ClassLabel feature from "label" column
label_features = reference_dataset["train"].features["label"]
# Label names to classify with
candidate_labels = label_features.names
candidate_labels
['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']
Hugging Face Hub 上的一些資料集在標籤列中沒有 `ClassLabel` 特性。在這種情況下,您應該首先計算 id2label 對映,然後手動計算候選標籤,如下所示
def get_id2label(dataset):
# The column with the label names
label_names = dataset.unique("label_text")
# The column with the label IDs
label_ids = dataset.unique("label")
id2label = dict(zip(label_ids, label_names))
# Sort by label ID
return {key: val for key, val in sorted(id2label.items(), key = lambda x: x[0])}
id2label = get_id2label(reference_dataset["train"])
candidate_labels = list(id2label.values())
現在我們有了標籤,建立合成示例就變得很簡單了
from datasets import Dataset
from setfit import get_templated_dataset
# A dummy dataset to fill with synthetic examples
dummy_dataset = Dataset.from_dict({})
train_dataset = get_templated_dataset(dummy_dataset, candidate_labels=candidate_labels, sample_size=8)
train_dataset
Dataset({
features: ['text', 'label'],
num_rows: 48
})
您可能會發現,透過將 `template` 引數從預設的 `"The sentence is {}"` 調整為 `"This sentence is {}"` 或 `"This example is {}"` 等變體,可以獲得更好的效能。
由於我們的資料集有 6 個類,我們選擇了 8 的樣本大小,因此我們的合成數據集包含 $6\times 8=48$ 個示例。如果我們檢視一些示例
train_dataset.shuffle()[:3]
{'text': ['This sentence is love',
'This sentence is fear',
'This sentence is joy'],
'label': [2, 4, 1]}
我們可以看到每個輸入都採用模板的形式,並具有與之關聯的相應標籤。
我們不要在這些示例上訓練 SetFit 模型!
微調模型
要訓練 SetFit 模型,首先要從 Hub 下載預訓練檢查點。我們可以透過使用 `SetFitModel.from_pretrained()` 方法來做到這一點
from setfit import SetFitModel
model = SetFitModel.from_pretrained(model_id)
在這裡,我們從 Hub 下載了一個預訓練的 Sentence Transformer,並添加了一個邏輯分類頭來建立 SetFit 模型。如訊息所示,我們需要在一些標記示例上訓練這個模型。我們可以透過使用 Trainer 類來做到這一點
from setfit import Trainer
trainer = Trainer(
model=model,
train_dataset=train_dataset,
eval_dataset=reference_dataset["test"]
)
現在我們已經建立了一個訓練器,我們可以訓練它了!同時,讓我們記錄訓練和評估模型所需的時間
%%time trainer.train() zeroshot_metrics = trainer.evaluate() zeroshot_metrics
***** Running training *****
Num examples = 1920
Num epochs = 1
Total optimization steps = 120
Total train batch size = 16
***** Running evaluation *****
{'accuracy': 0.5345}
CPU times: user 12.9 s, sys: 2.37 s, total: 15.2 s
Wall time: 11 s
太好了,現在我們有了一個參考分數,讓我們與 🤗 Transformers 中的零樣本管道進行比較。
與 🤗 Transformers 中的零樣本管道進行比較
🤗 Transformers 提供了一個零樣本管道,將文字分類構建為自然語言推理任務。讓我們載入管道並將其放在 GPU 上以實現快速推理
from transformers import pipeline
pipe = pipeline("zero-shot-classification", device=0)
現在我們有了模型,讓我們生成一些預測。我們將使用與 SetFit 相同的候選標籤,並增加批處理大小以加快速度
%%time
zeroshot_preds = pipe(reference_dataset["test"]["text"], batch_size=16, candidate_labels=candidate_labels)
CPU times: user 1min 10s, sys: 166 ms, total: 1min 11s
Wall time: 53.1 s
請注意,這比 SetFit 生成預測的時間長了近 5 倍!好的,那麼它的表現如何呢?由於每個預測都是按分數排名的標籤名稱字典
zeroshot_preds[0]
{'sequence': 'im feeling rather rotten so im not very ambitious right now',
'labels': ['sadness', 'anger', 'surprise', 'fear', 'joy', 'love'],
'scores': [0.7367985844612122,
0.10041674226522446,
0.09770156443119049,
0.05880110710859299,
0.004266355652362108,
0.0020156768150627613]}
我們可以使用 `label` 列中的 `str2int()` 函式將它們轉換為整數。
preds = [label_features.str2int(pred["labels"][0]) for pred in zeroshot_preds]
**注意:** 如前所述,如果您使用的資料集的標籤列沒有 `ClassLabel` 特性,則需要手動計算標籤對映,例如
id2label = get_id2label(reference_dataset["train"])
label2id = {v:k for k,v in id2label.items()}
preds = [label2id[pred["labels"][0]] for pred in zeroshot_preds]
最後一步是使用 🤗 Evaluate 計算準確率
import evaluate
metric = evaluate.load("accuracy")
transformers_metrics = metric.compute(predictions=preds, references=reference_dataset["test"]["label"])
transformers_metrics
{'accuracy': 0.3765}
與 SetFit 相比,這種方法的效能明顯更差。讓我們透過將合成示例與一些標記示例相結合來結束我們的分析。
用合成示例增強標記資料
如果您有一些標記示例,新增合成數據通常可以提高效能。為了模擬這一點,我們首先從參考資料集中取樣 8 個標記示例
from setfit import sample_dataset
train_dataset = sample_dataset(reference_dataset["train"])
train_dataset
Dataset({
features: ['text', 'label'],
num_rows: 48
})
預熱一下,我們將用這些真實標籤訓練一個 SetFit 模型
model = SetFitModel.from_pretrained(model_id)
trainer = Trainer(
model=model,
train_dataset=train_dataset,
eval_dataset=reference_dataset["test"]
)
trainer.train()
fewshot_metrics = trainer.evaluate()
fewshot_metrics
{'accuracy': 0.4705}
請注意,對於這個特定的資料集,使用真實標籤的效能*比*使用合成示例訓練的效能*更差*!在我們的實驗中,我們發現差異很大程度上取決於具體的資料集。由於 SetFit 模型訓練速度快,您總是可以嘗試兩種方法並選擇最佳的一種。
無論如何,現在讓我們向訓練集中新增一些合成示例
augmented_dataset = get_templated_dataset(train_dataset, candidate_labels=candidate_labels, sample_size=8)
augmented_dataset
Dataset({
features: ['text', 'label'],
num_rows: 96
})
和以前一樣,我們可以用增強資料集訓練和評估 SetFit
model = SetFitModel.from_pretrained(model_id)
trainer = Trainer(
model=model,
train_dataset=augmented_dataset,
eval_dataset=reference_dataset["test"]
)
trainer.train()
augmented_metrics = trainer.evaluate()
augmented_metrics
{'accuracy': 0.613}
太好了,這大大提升了我們的效能,比純粹的合成示例提高了幾個百分點。
讓我們繪製最終結果進行比較
import pandas as pd
df = pd.DataFrame.from_dict({"Method":["Transformers (zero-shot)", "SetFit (zero-shot)", "SetFit (augmented)"], "Accuracy": [transformers_metrics["accuracy"], zeroshot_metrics["accuracy"], augmented_metrics["accuracy"]]})
df.plot(kind="barh", x="Method");