開源 AI 食譜文件
在零樣本文字分類中使用 SetFit 進行資料標註的建議
並獲得增強的文件體驗
開始使用
在零樣本文字分類中使用 SetFit 進行資料標註的建議
作者:David Berenstein 和 Sara Han Díaz
建議是讓標註團隊工作更輕鬆、更快捷的絕佳方式。這些預選選項將使標註過程更高效,因為他們只需要糾正建議即可。在本例中,我們將演示如何使用 SetFit 實現零樣本方法,為 Argilla 中的一個數據集獲取一些初步建議,該資料集結合了兩種文字分類任務,包括 LabelQuestion
和 MultiLabelQuestion
。
Argilla 是一個面向 AI 工程師和領域專家的協作工具,他們需要為自己的專案構建高質量的資料集。透過使用 Argilla,每個人都可以透過結合人工和機器反饋,加快資料整理過程,從而構建出穩健的語言模型。
反饋是資料整理過程中的關鍵部分,Argilla 也提供了一種管理和視覺化反饋的方法,以便後續使用整理好的資料來改進語言模型。在本教程中,我們將展示一個真實案例,說明如何透過提供建議來簡化標註者的工作。為實現這一目標,您將學習如何使用 SetFit 訓練零樣本情感和主題分類器,然後用它們來為資料集建議標籤。
在本教程中,我們將遵循以下步驟:
- 在 Argilla 中建立一個數據集。
- 使用 SetFit 訓練零樣本分類器。
- 使用訓練好的分類器為資料集獲取建議。
- 在 Argilla 中視覺化建議。
讓我們開始吧!
環境設定
在本教程中,您需要執行一個 Argilla 伺服器。如果您已經部署了 Argilla,可以跳過此步驟。否則,您可以按照本指南在 HF Spaces 或本地快速部署 Argilla。完成後,請執行以下步驟:
- 使用
pip
安裝 Argilla 客戶端和所需的第三方庫
!pip install argilla
!pip install setfit==1.0.3 transformers==4.40.2 huggingface_hub==0.23.5
- 匯入必要的庫
import argilla as rg
from datasets import load_dataset
from setfit import SetFitModel, Trainer, get_templated_dataset
- 如果您使用 Docker 快速啟動映象或 Hugging Face Spaces 執行 Argilla,您需要使用
API_URL
和API_KEY
初始化 Argilla 客戶端
# Replace api_url with your url if using Docker
# Replace api_key if you configured a custom API key
# Uncomment the last line and set your HF_TOKEN if your space is private
client = rg.Argilla(
api_url="https://[your-owner-name]-[your_space_name].hf.space",
api_key="[your-api-key]",
# headers={"Authorization": f"Bearer {HF_TOKEN}"}
)
配置資料集
在本示例中,我們將載入 banking77 資料集,這是一個流行的開源資料集,包含銀行領域的客戶請求。
data = load_dataset("PolyAI/banking77", split="test")
Argilla 使用 Dataset
類,這使您可以輕鬆建立資料集並管理資料和反饋。首先需要對 Dataset
進行配置。在 Settings
中,我們可以指定標註*指南*、待標註資料將新增到的*欄位*以及給標註者的*問題*。此外,還可以新增更多功能。更多資訊,請檢視 Argilla 操作指南。
對於我們的用例,我們需要一個文字欄位和兩個不同的問題。我們將使用該資料集的原始標籤對請求中提到的主題進行多標籤分類,並且我們還將設定一個標籤問題,將請求的情感分類為“積極”、“中性”或“消極”。
settings = rg.Settings(
fields=[rg.TextField(name="text")],
questions=[
rg.MultiLabelQuestion(
name="topics",
title="Select the topic(s) of the request",
labels=data.info.features["label"].names,
visible_labels=10,
),
rg.LabelQuestion(
name="sentiment",
title="What is the sentiment of the message?",
labels=["positive", "neutral", "negative"],
),
],
)
dataset = rg.Dataset(
name="setfit_tutorial_dataset",
settings=settings,
)
dataset.create()
訓練模型
現在,我們將使用從 HF 載入的資料以及為資料集配置的標籤和問題,為資料集中的每個問題訓練一個零樣本文字分類模型。如前幾節所述,我們將在兩個分類器中都使用 SetFit 框架對 Sentence Transformers 進行少樣本微調。此外,我們將使用的模型是 all-MiniLM-L6-v2,這是一個句子嵌入模型,它在一個包含 10 億個句子對的資料集上使用對比目標進行了微調。
def train_model(question_name, template, multi_label=False):
train_dataset = get_templated_dataset(
candidate_labels=dataset.questions[question_name].labels,
sample_size=8,
template=template,
multi_label=multi_label,
)
# Train a model using the training dataset we just built
if multi_label:
model = SetFitModel.from_pretrained(
"sentence-transformers/all-MiniLM-L6-v2",
multi_target_strategy="one-vs-rest",
)
else:
model = SetFitModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
trainer = Trainer(model=model, train_dataset=train_dataset)
trainer.train()
return model
topic_model = train_model(
question_name="topics",
template="The customer request is about {}",
multi_label=True,
)
# topic_model.save_pretrained(
# "/path-to-your-models-folder/topic_model"
# )
sentiment_model = train_model(question_name="sentiment", template="This message is {}", multi_label=False)
# topic_model.save_pretrained(
# "/path-to-your-models-folder/sentiment_model"
# )
進行預測
訓練步驟結束後,我們可以對資料進行預測。
def get_predictions(texts, model, question_name):
probas = model.predict_proba(texts, as_numpy=True)
labels = dataset.questions[question_name].labels
for pred in probas:
yield [{"label": label, "score": score} for label, score in zip(labels, pred)]
data = data.map(
lambda batch: {
"topics": list(get_predictions(batch["text"], topic_model, "topics")),
"sentiment": list(get_predictions(batch["text"], sentiment_model, "sentiment")),
},
batched=True,
)
data.to_pandas().head()
將記錄寫入 Argilla
利用我們生成的資料和預測結果,我們現在可以構建包含模型建議的記錄 (標註團隊將要標註的每個資料項)。對於 LabelQuestion
,我們將使用獲得最高機率分數的標籤;對於 MultiLabelQuestion
,我們將包含所有分數高於某個閾值的標籤。在本例中,我們決定使用 2/len(labels)
,但您可以根據自己的資料進行實驗,選擇一個更嚴格或更寬鬆的閾值。
請注意,更寬鬆的閾值 (接近或等於
1/len(labels)
) 會建議更多標籤,而更嚴格的閾值 (介於 2 和 3 之間) 則會選擇更少 (或沒有) 標籤。
def add_suggestions(record):
suggestions = []
# Get label with max score for sentiment question
sentiment = max(record["sentiment"], key=lambda x: x["score"])["label"]
suggestions.append(rg.Suggestion(question_name="sentiment", value=sentiment))
# Get all labels above a threshold for topics questions
threshold = 2 / len(dataset.questions["topics"].labels)
topics = [label["label"] for label in record["topics"] if label["score"] >= threshold]
if topics:
suggestions.append(rg.Suggestion(question_name="topics", value=topics))
return suggestions
records = [rg.Record(fields={"text": record["text"]}, suggestions=add_suggestions(record)) for record in data]
一旦我們對結果滿意,就可以將記錄寫入上面配置的資料集。現在您可以在 Argilla 中訪問該資料集並檢視建議。
dataset.records.log(records)
以下是 UI 顯示我們模型建議的樣子:
或者,您也可以將 Argilla 資料集儲存並載入到 Hugging Face Hub 中。有關如何操作的更多資訊,請參閱 Argilla 文件。
# Export to HuggingFace Hub
dataset.to_hub(repo_id="argilla/my_setfit_dataset")
# Import from HuggingFace Hub
dataset = rg.Dataset.from_hub(repo_id="argilla/my_setfit_dataset")
結論
在本教程中,我們介紹瞭如何使用 SetFit 庫透過零樣本方法向 Argilla 資料集新增建議。這將透過減少標註團隊需要做出的決策和編輯次數來提高標註過程的效率。
檢視以下連結獲取更多資源:
< > 在 GitHub 上更新