使用Sentence Transformers v4訓練和微調Reranker模型
Sentence Transformers是一個Python庫,用於使用和訓練嵌入和重排模型,應用於廣泛的場景,例如檢索增強生成、語義搜尋、語義文字相似度、釋義挖掘等。其v4.0更新引入了一種新的重排器(也稱為交叉編碼器模型)訓練方法,類似於v3.0更新為嵌入模型引入的方法。在這篇部落格文章中,我將向您展示如何使用它來微調一個重排器模型,該模型在您的資料上超越所有現有選項。此方法還可以從頭開始訓練極其強大的新重排器模型。
微調重排器模型涉及幾個元件:資料集、損失函式、訓練引數、評估器和訓練器類本身。我將逐一探討這些元件,並提供如何將它們用於微調強大重排器模型的實用示例。
最後,在評估部分,我將向您展示,我與這篇部落格文章一起訓練的我的小型微調tomaarsen/reranker-ModernBERT-base-gooaq-bce重排器模型,在我的評估資料集上輕鬆超越了13個最常用的公共重排器模型。它甚至擊敗了體積大4倍的模型。
使用更大的基礎模型重複此方法,結果是tomaarsen/reranker-ModernBERT-large-gooaq-bce,一個在我的資料上超越所有現有通用重排器模型的重排器模型。
如果您對微調嵌入模型感興趣,那麼也可以閱讀我之前的使用Sentence Transformers v3訓練和微調嵌入模型部落格文章。
目錄
什麼是Reranker模型?
重排器模型,通常使用交叉編碼器架構實現,旨在評估文字對(例如,查詢和文件,或兩個句子)之間的相關性。與Sentence Transformers(又稱雙編碼器、嵌入模型)不同,後者獨立地將每個文字嵌入到向量中並透過距離度量計算相似度,交叉編碼器透過共享神經網路同時處理配對文字,從而產生一個輸出分數。透過讓兩個文字相互關注,交叉編碼器模型可以勝過嵌入模型。
然而,這種優勢也帶來了權衡:交叉編碼器模型速度較慢,因為它們處理所有可能的文字對(例如,10個查詢和500個候選文件需要5,000次計算,而嵌入模型只需要510次)。這使得它們在大規模初始檢索中效率較低,但非常適合重排:最佳化由更快的Sentence Transformer模型首先識別出的前k個結果。最強的搜尋系統通常採用這種兩階段的“檢索和重排”方法。
在這篇部落格文章中,我將互換使用“重排器模型”和“交叉編碼器模型”。
為什麼要微調?
重排器模型通常面臨一個具有挑戰性的問題:
在這些高度相關的文件中,哪一個最能回答問題?
通用重排器模型經過訓練,可以在各種領域和主題中充分回答這個問題,這阻礙了它們在您的特定領域中發揮最大潛力。透過微調,模型可以學習專注於對您而言重要的領域和/或語言。
在這篇部落格文章的評估部分,我將展示在您的領域中訓練的模型可以超越任何通用重排器模型,即使這些基線模型大得多。不要低估在您的領域中進行微調的力量!
訓練元件
訓練重排器模型涉及以下元件:
- 資料集:用於訓練和/或評估的資料。
- 損失函式:衡量模型效能並指導最佳化過程的函式。
- 訓練引數(可選):影響訓練效能、跟蹤和除錯的引數。
- 評估器(可選):用於在訓練前、訓練中或訓練後評估模型的類。
- 訓練器:將所有訓練元件整合在一起。
讓我們仔細看看每個元件。
資料集
CrossEncoderTrainer
使用datasets.Dataset
或datasets.DatasetDict
例項進行訓練和評估。您可以從Hugging Face Datasets Hub載入資料,或者使用您喜歡的任何格式(例如CSV、JSON、Parquet、Arrow或SQL)的本地資料。
注意: 許多可以直接與Sentence Transformers一起使用的公共資料集已在Hugging Face Hub上被標記為sentence-transformers
,因此您可以在https://huggingface.co/datasets?other=sentence-transformers上輕鬆找到它們。考慮瀏覽這些資料集,以找到可能對您的任務、領域或語言有用的現成資料集。
Hugging Face Hub上的資料
您可以使用load_dataset
函式從Hugging Face Hub中的資料集載入資料。
from datasets import load_dataset
train_dataset = load_dataset("sentence-transformers/natural-questions", split="train")
print(train_dataset)
"""
Dataset({
features: ['query', 'answer'],
num_rows: 100231
})
"""
一些資料集,如nthakur/swim-ir-monolingual
,有多個不同資料格式的子集。您需要指定子集名稱和資料集名稱,例如dataset = load_dataset("nthakur/swim-ir-monolingual", "de", split="train")
。
本地資料(CSV、JSON、Parquet、Arrow、SQL)
您也可以使用load_dataset
載入某些檔案格式的本地資料
from datasets import load_dataset
dataset = load_dataset("csv", data_files="my_file.csv")
# or
dataset = load_dataset("json", data_files="my_file.json")
需要預處理的本地資料
如果您的本地資料需要預處理,您可以使用datasets.Dataset.from_dict
。這允許您使用列表字典初始化資料集。
from datasets import Dataset
queries = []
documents = []
# Open a file, perform preprocessing, filtering, cleaning, etc.
# and append to the lists
dataset = Dataset.from_dict({
"query": queries,
"document": documents,
})
字典中的每個鍵都將成為結果資料集中的一列。
資料集格式
重要的是,您的資料集格式必須與您的損失函式匹配(或者您選擇與您的資料集格式和模型匹配的損失函式)。驗證資料集格式和模型是否與損失函式配合使用涉及三個步驟:
- 根據損失概述表,所有未命名為“label”、“labels”、“score”或“scores”的列都被視為*輸入*。剩餘列的數量必須與您選擇的損失函式的有效輸入數量匹配。
- 如果您的損失函式根據損失概述表需要一個*標籤*,那麼您的資料集必須有一個名為“label”、“labels”、“score”或“scores”的**列**。此列會自動作為標籤。
- 模型輸出標籤的數量與損失概述表根據損失函式的要求匹配。
例如,給定一個包含列["text1", "text2", "label"]
的資料集,其中“label”列的浮點相似度分數範圍為0到1,以及一個輸出1個標籤的模型,我們可以將其與BinaryCrossEntropyLoss
一起使用,因為:
- 資料集具有“label”列,這是此損失函式所必需的。
- 資料集有2個非標籤列,正好是此損失函式所需的數量。
- 模型有1個輸出標籤,正好是此損失函式所需的數量。
如果您的列順序不正確,請務必使用Dataset.select_columns
重新排序您的資料集列。例如,如果您的資料集有["good_answer", "bad_answer", "question"]
作為列,那麼此資料集技術上可以與需要(錨點,正例,負例)三元組的損失一起使用,但good_answer
列將作為錨點,bad_answer
作為正例,question
作為負例。
此外,如果您的資料集有多餘的列(例如sample_id、metadata、source、type),您應該使用Dataset.remove_columns
將其刪除,否則它們將被用作輸入。您還可以使用Dataset.select_columns
僅保留所需列。
難例挖掘
訓練重排器模型的成功通常取決於*負例*的質量,即查詢-負例得分應低的段落。負例可以分為兩種型別:
- 軟負例:完全不相關的段落。也稱為簡單負例。
- 難負例:看起來可能與查詢相關但實際上不相關的段落。
一個簡潔的例子是:
- 查詢:Apple在哪裡成立的?
- 軟負例:卡什河大橋是一座帕克桁架橋,橫跨阿肯色州核桃嶺和帕拉古爾德之間的卡什河。
- 難負例:富士蘋果是一種在20世紀30年代後期開發,並於1962年上市的蘋果栽培品種。
最強的交叉編碼器模型通常經過訓練來識別難負例,因此能夠“挖掘”難負例進行訓練是很有價值的。Sentence Transformers支援一個強大的mine_hard_negatives
函式,可以在給定查詢-答案對資料集的情況下提供幫助:
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import mine_hard_negatives
# Load the GooAQ dataset: https://huggingface.co/datasets/sentence-transformers/gooaq
train_dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(100_000))
print(train_dataset)
# Mine hard negatives using a very efficient embedding model
embedding_model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")
hard_train_dataset = mine_hard_negatives(
train_dataset,
embedding_model,
num_negatives=5, # How many negatives per question-answer pair
range_min=10, # Skip the x most similar samples
range_max=100, # Consider only the x most similar samples
max_score=0.8, # Only consider samples with a similarity score of at most x
margin=0.1, # Similarity between query and negative samples should be x lower than query-positive similarity
sampling_strategy="top", # Randomly sample negatives from the range
batch_size=4096, # Use a batch size of 4096 for the embedding model
output_format="labeled-pair", # The output format is (query, passage, label), as required by BinaryCrossEntropyLoss
use_faiss=True, # Using FAISS is recommended to keep memory usage low (pip install faiss-gpu or pip install faiss-cpu)
)
print(hard_train_dataset)
print(hard_train_dataset[1])
點選檢視此指令碼的輸出。
Dataset({
features: ['question', 'answer'],
num_rows: 100000
})
Batches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 22/22 [00:01<00:00, 13.74it/s]
Batches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 36.49it/s]
Querying FAISS index: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:19<00:00, 2.80s/it]
Metric Positive Negative Difference
Count 100,000 436,925
Mean 0.5882 0.4040 0.2157
Median 0.5989 0.4024 0.1836
Std 0.1425 0.0905 0.1013
Min -0.0514 0.1405 0.1014
25% 0.4993 0.3377 0.1352
50% 0.5989 0.4024 0.1836
75% 0.6888 0.4681 0.2699
Max 0.9748 0.7486 0.7545
Skipped 2420871 potential negatives (23.97%) due to the margin of 0.1.
Skipped 43 potential negatives (0.00%) due to the maximum score of 0.8.
Could not find enough negatives for 63075 samples (12.62%). Consider adjusting the range_max, range_min, margin and max_score parameters if you'd like to find more valid negatives.
Dataset({
features: ['question', 'answer', 'label'],
num_rows: 536925
})
{
'question': 'how to transfer bookmarks from one laptop to another?',
'answer': 'Using an External Drive Just about any external drive, including a USB thumb drive, or an SD card can be used to transfer your files from one laptop to another. Connect the drive to your old laptop; drag your files to the drive, then disconnect it and transfer the drive contents onto your new laptop.',
'label': 0
}
損失函式
損失函式有助於評估模型在一組資料上的效能並指導訓練過程。適用於您任務的正確損失函式取決於您擁有的資料以及您想要實現的目標。您可以在損失概述中找到可用損失函式的完整列表。
大多數損失函式都易於設定——您只需提供您正在訓練的CrossEncoder
模型:
from datasets import load_dataset
from sentence_transformers import CrossEncoder
from sentence_transformers.cross_encoder.losses import CachedMultipleNegativesRankingLoss
# Load a model to train/finetune
model = CrossEncoder("xlm-roberta-base", num_labels=1) # num_labels=1 is for rerankers
# Initialize the CachedMultipleNegativesRankingLoss, which requires pairs of
# related texts or triplets
loss = CachedMultipleNegativesRankingLoss(model)
# Load an example training dataset that works with our loss function:
train_dataset = load_dataset("sentence-transformers/gooaq", split="train")
...
訓練引數
您可以使用CrossEncoderTrainingArguments
類自定義訓練過程。此類別允許您調整可能影響訓練速度並幫助您瞭解訓練期間發生的事情的引數。
有關最有用的訓練引數的更多資訊,請檢視交叉編碼器 > 訓練概述 > 訓練引數。值得一讀,以充分利用您的訓練。
以下是如何設定CrossEncoderTrainingArguments
的示例:
from sentence_transformers.cross_encoder import CrossEncoderTrainingArguments
args = CrossEncoderTrainingArguments(
# Required parameter:
output_dir="models/reranker-MiniLM-msmarco-v1",
# Optional training parameters:
num_train_epochs=1,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
learning_rate=2e-5,
warmup_ratio=0.1,
fp16=True, # Set to False if you get an error that your GPU can't run on FP16
bf16=False, # Set to True if you have a GPU that supports BF16
batch_sampler=BatchSamplers.NO_DUPLICATES, # losses that use "in-batch negatives" benefit from no duplicates
# Optional tracking/debugging parameters:
eval_strategy="steps",
eval_steps=100,
save_strategy="steps",
save_steps=100,
save_total_limit=2,
logging_steps=100,
run_name="reranker-MiniLM-msmarco-v1", # Will be used in W&B if `wandb` is installed
)
評估器
為了在訓練期間跟蹤模型的效能,您可以將eval_dataset
傳遞給CrossEncoderTrainer
。但是,您可能需要除評估損失之外更詳細的指標。這就是評估器可以幫助您在訓練的不同階段使用特定指標評估模型效能的地方。您可以根據需要使用評估資料集、評估器、兩者或都不用。評估策略和頻率由eval_strategy
和eval_steps
訓練引數控制。
Sentence Transformers包含以下內建評估器:
評估器 | 所需資料 |
---|---|
CrossEncoderClassificationEvaluator |
帶有類標籤的對(二分類或多分類) |
CrossEncoderCorrelationEvaluator |
帶有相似度分數的對 |
CrossEncoderNanoBEIREvaluator |
無需資料 |
CrossEncoderRerankingEvaluator |
{'query': '...', 'positive': [...], 'negative': [...]} 字典列表。負例可以使用mine_hard_negatives 挖掘。 |
您還可以使用SequentialEvaluator
將多個評估器組合成一個,然後將其傳遞給CrossEncoderTrainer
。您也可以直接將評估器列表傳遞給訓練器。
有時您沒有所需的評估資料來自行準備這些評估器,但您仍然希望跟蹤模型在某些常見基準上的表現。在這種情況下,您可以將這些評估器與來自Hugging Face的資料一起使用。
使用STSb的CrossEncoderCorrelationEvaluator
STS基準測試(又稱STSb)是一個常用基準資料集,用於衡量模型對“一個人正在給蛇喂老鼠”等短文字語義相似度的理解。
歡迎瀏覽Hugging Face上的sentence-transformers/stsb資料集。
from datasets import load_dataset
from sentence_transformers import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CrossEncoderCorrelationEvaluator
# Load a model
model = CrossEncoder("cross-encoder/stsb-TinyBERT-L4")
# Load the STSB dataset (https://huggingface.co/datasets/sentence-transformers/stsb)
eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")
pairs = list(zip(eval_dataset["sentence1"], eval_dataset["sentence2"]))
# Initialize the evaluator
dev_evaluator = CrossEncoderCorrelationEvaluator(
sentence_pairs=pairs,
scores=eval_dataset["score"],
name="sts_dev",
)
# You can run evaluation like so:
# results = dev_evaluator(model)
# Later, you can provide this evaluator to the trainer to get results during training
使用GooAQ挖掘負例的CrossEncoderRerankingEvaluator
為CrossEncoderRerankingEvaluator
準備資料可能很困難,因為除了查詢-正例資料外,您還需要負例。
mine_hard_negatives
函式有一個方便的include_positives
引數,可以將其設定為True
以同時挖掘正例文字。當將其作為documents
(必須是1. 已排序且2. 包含正例)提供給CrossEncoderRerankingEvaluator
時,評估器不僅會評估交叉編碼器的重排效能,還會評估用於挖掘的嵌入模型原始排名。
例如:
CrossEncoderRerankingEvaluator: Evaluating the model on the gooaq-dev dataset:
Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 49.0, Mean 49.1, Max 50.0
Base -> Reranked
MAP: 53.28 -> 67.28
MRR@10: 52.40 -> 66.65
NDCG@10: 59.12 -> 71.35
請注意,預設情況下,如果您使用帶有documents
的CrossEncoderRerankingEvaluator
,評估器將使用*所有*正例進行重排,即使它們不在文件中。這對於從評估器中獲得更強的訊號很有用,但確實會給出略微不切實際的效能。畢竟,最大效能現在是100,而通常它的上限取決於第一階段檢索器是否實際檢索到了正例。
您可以透過在初始化CrossEncoderRerankingEvaluator
時設定always_rerank_positives=False
來啟用真實行為。使用這種真實的兩階段效能重複相同的指令碼會得到:
CrossEncoderRerankingEvaluator: Evaluating the model on the gooaq-dev dataset:
Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 49.0, Mean 49.1, Max 50.0
Base -> Reranked
MAP: 53.28 -> 66.12
MRR@10: 52.40 -> 65.61
NDCG@10: 59.12 -> 70.10
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CrossEncoderRerankingEvaluator
from sentence_transformers.util import mine_hard_negatives
# Load a model
model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")
# Load the GooAQ dataset: https://huggingface.co/datasets/sentence-transformers/gooaq
full_dataset = load_dataset("sentence-transformers/gooaq", split=f"train").select(range(100_000))
dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
train_dataset = dataset_dict["train"]
eval_dataset = dataset_dict["test"]
print(eval_dataset)
"""
Dataset({
features: ['question', 'answer'],
num_rows: 1000
})
"""
# Mine hard negatives using a very efficient embedding model
embedding_model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")
hard_eval_dataset = mine_hard_negatives(
eval_dataset,
embedding_model,
corpus=full_dataset["answer"], # Use the full dataset as the corpus
num_negatives=50, # How many negatives per question-answer pair
batch_size=4096, # Use a batch size of 4096 for the embedding model
output_format="n-tuple", # The output format is (query, positive, negative1, negative2, ...) for the evaluator
include_positives=True, # Key: Include the positive answer in the list of negatives
use_faiss=True, # Using FAISS is recommended to keep memory usage low (pip install faiss-gpu or pip install faiss-cpu)
)
print(hard_eval_dataset)
"""
Dataset({
features: ['question', 'answer', 'negative_1', 'negative_2', 'negative_3', 'negative_4', 'negative_5', 'negative_6', 'negative_7', 'negative_8', 'negative_9', 'negative_10', 'negative_11', 'negative_12', 'negative_13', 'negative_14', 'negative_15', 'negative_16', 'negative_17', 'negative_18', 'negative_19', 'negative_20', 'negative_21', 'negative_22', 'negative_23', 'negative_24', 'negative_25', 'negative_26', 'negative_27', 'negative_28', 'negative_29', 'negative_30', 'negative_31', 'negative_32', 'negative_33', 'negative_34', 'negative_35', 'negative_36', 'negative_37', 'negative_38', 'negative_39', 'negative_40', 'negative_41', 'negative_42', 'negative_43', 'negative_44', 'negative_45', 'negative_46', 'negative_47', 'negative_48', 'negative_49', 'negative_50'],
num_rows: 1000
})
"""
reranking_evaluator = CrossEncoderRerankingEvaluator(
samples=[
{
"query": sample["question"],
"positive": [sample["answer"]],
"documents": [sample[column_name] for column_name in hard_eval_dataset.column_names[2:]],
}
for sample in hard_eval_dataset
],
batch_size=32,
name="gooaq-dev",
)
# You can run evaluation like so
results = reranking_evaluator(model)
"""
CrossEncoderRerankingEvaluator: Evaluating the model on the gooaq-dev dataset:
Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 49.0, Mean 49.1, Max 50.0
Base -> Reranked
MAP: 53.28 -> 67.28
MRR@10: 52.40 -> 66.65
NDCG@10: 59.12 -> 71.35
"""
# {'gooaq-dev_map': 0.6728370126462222, 'gooaq-dev_mrr@10': 0.6665190476190477, 'gooaq-dev_ndcg@10': 0.7135068904582963, 'gooaq-dev_base_map': 0.5327714512001362, 'gooaq-dev_base_mrr@10': 0.5239674603174603, 'gooaq-dev_base_ndcg@10': 0.5912299141913905}
訓練器
CrossEncoderTrainer
是所有先前元件的集合。我們只需指定訓練器與模型、訓練引數(可選)、訓練資料集、評估資料集(可選)、損失函式、評估器(可選),然後就可以開始訓練了。讓我們看看一個所有這些元件都結合在一起的指令碼:
import logging
import traceback
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import (
CrossEncoder,
CrossEncoderModelCardData,
CrossEncoderTrainer,
CrossEncoderTrainingArguments,
)
from sentence_transformers.cross_encoder.evaluation import (
CrossEncoderNanoBEIREvaluator,
CrossEncoderRerankingEvaluator,
)
from sentence_transformers.cross_encoder.losses.BinaryCrossEntropyLoss import BinaryCrossEntropyLoss
from sentence_transformers.evaluation.SequentialEvaluator import SequentialEvaluator
from sentence_transformers.util import mine_hard_negatives
# Set the log level to INFO to get more information
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
def main():
model_name = "answerdotai/ModernBERT-base"
train_batch_size = 16
num_epochs = 1
num_hard_negatives = 5 # How many hard negatives should be mined for each question-answer pair
# 1a. Load a model to finetune with 1b. (Optional) model card data
model = CrossEncoder(
model_name,
model_card_data=CrossEncoderModelCardData(
language="en",
license="apache-2.0",
model_name="ModernBERT-base trained on GooAQ",
),
)
print("Model max length:", model.max_length)
print("Model num labels:", model.num_labels)
# 2a. Load the GooAQ dataset: https://huggingface.co/datasets/sentence-transformers/gooaq
logging.info("Read the gooaq training dataset")
full_dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(100_000))
dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
train_dataset = dataset_dict["train"]
eval_dataset = dataset_dict["test"]
logging.info(train_dataset)
logging.info(eval_dataset)
# 2b. Modify our training dataset to include hard negatives using a very efficient embedding model
embedding_model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")
hard_train_dataset = mine_hard_negatives(
train_dataset,
embedding_model,
num_negatives=num_hard_negatives, # How many negatives per question-answer pair
margin=0, # Similarity between query and negative samples should be x lower than query-positive similarity
range_min=0, # Skip the x most similar samples
range_max=100, # Consider only the x most similar samples
sampling_strategy="top", # Sample the top negatives from the range
batch_size=4096, # Use a batch size of 4096 for the embedding model
output_format="labeled-pair", # The output format is (query, passage, label), as required by BinaryCrossEntropyLoss
use_faiss=True,
)
logging.info(hard_train_dataset)
# 2c. (Optionally) Save the hard training dataset to disk
# hard_train_dataset.save_to_disk("gooaq-hard-train")
# Load again with:
# hard_train_dataset = load_from_disk("gooaq-hard-train")
# 3. Define our training loss.
# pos_weight is recommended to be set as the ratio between positives to negatives, a.k.a. `num_hard_negatives`
loss = BinaryCrossEntropyLoss(model=model, pos_weight=torch.tensor(num_hard_negatives))
# 4a. Define evaluators. We use the CrossEncoderNanoBEIREvaluator, which is a light-weight evaluator for English reranking
nano_beir_evaluator = CrossEncoderNanoBEIREvaluator(
dataset_names=["msmarco", "nfcorpus", "nq"],
batch_size=train_batch_size,
)
# 4b. Define a reranking evaluator by mining hard negatives given query-answer pairs
# We include the positive answer in the list of negatives, so the evaluator can use the performance of the
# embedding model as a baseline.
hard_eval_dataset = mine_hard_negatives(
eval_dataset,
embedding_model,
corpus=full_dataset["answer"], # Use the full dataset as the corpus
num_negatives=30, # How many documents to rerank
batch_size=4096,
include_positives=True,
output_format="n-tuple",
use_faiss=True,
)
logging.info(hard_eval_dataset)
reranking_evaluator = CrossEncoderRerankingEvaluator(
samples=[
{
"query": sample["question"],
"positive": [sample["answer"]],
"documents": [sample[column_name] for column_name in hard_eval_dataset.column_names[2:]],
}
for sample in hard_eval_dataset
],
batch_size=train_batch_size,
name="gooaq-dev",
always_rerank_positives=False,
)
# 4c. Combine the evaluators & run the base model on them
evaluator = SequentialEvaluator([reranking_evaluator, nano_beir_evaluator])
evaluator(model)
# 5. Define the training arguments
short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
run_name = f"reranker-{short_model_name}-gooaq-bce"
args = CrossEncoderTrainingArguments(
# Required parameter:
output_dir=f"models/{run_name}",
# Optional training parameters:
num_train_epochs=num_epochs,
per_device_train_batch_size=train_batch_size,
per_device_eval_batch_size=train_batch_size,
learning_rate=2e-5,
warmup_ratio=0.1,
fp16=False, # Set to False if you get an error that your GPU can't run on FP16
bf16=True, # Set to True if you have a GPU that supports BF16
dataloader_num_workers=4,
load_best_model_at_end=True,
metric_for_best_model="eval_gooaq-dev_ndcg@10",
# Optional tracking/debugging parameters:
eval_strategy="steps",
eval_steps=4000,
save_strategy="steps",
save_steps=4000,
save_total_limit=2,
logging_steps=1000,
logging_first_step=True,
run_name=run_name, # Will be used in W&B if `wandb` is installed
seed=12,
)
# 6. Create the trainer & start training
trainer = CrossEncoderTrainer(
model=model,
args=args,
train_dataset=hard_train_dataset,
loss=loss,
evaluator=evaluator,
)
trainer.train()
# 7. Evaluate the final model, useful to include these in the model card
evaluator(model)
# 8. Save the final model
final_output_dir = f"models/{run_name}/final"
model.save_pretrained(final_output_dir)
# 9. (Optional) save the model to the Hugging Face Hub!
# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
try:
model.push_to_hub(run_name)
except Exception:
logging.error(
f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
f"`huggingface-cli login`, followed by loading the model using `model = CrossEncoder({final_output_dir!r})` "
f"and saving it using `model.push_to_hub('{run_name}')`."
)
if __name__ == "__main__":
main()
在此示例中,我正在從answerdotai/ModernBERT-base
進行微調,這是一個尚未成為交叉編碼器模型的基礎模型。這通常比微調現有重排器模型(如Alibaba-NLP/gte-multilingual-reranker-base
)需要更多的訓練資料。我使用了來自GooAQ資料集的99k個查詢-答案對,之後我使用sentence-transformers/static-retrieval-mrl-en-v1嵌入模型挖掘難負例。這導致了578k個帶標籤的對:99k個正例對(即標籤=1)和479k個負例對(即標籤=0)。
我使用了BinaryCrossEntropyLoss
,它非常適合這些帶標籤的對。我還設定了兩種評估形式:CrossEncoderNanoBEIREvaluator
用於評估NanoBEIR基準,以及CrossEncoderRerankingEvaluator
用於評估上述靜態嵌入模型對前30個結果進行重排的效能。之後,我定義了一組相當標準的超引數,包括學習率、預熱比率、bf16、最後載入最佳模型以及一些除錯引數。最後,我運行了訓練器,進行了訓練後評估,並將模型儲存到本地和Hugging Face Hub。
執行此指令碼後,tomaarsen/reranker-ModernBERT-base-gooaq-bce模型已為我上傳。請參閱即將到來的評估部分,其中有證據表明該模型優於13種常用開源替代方案,包括更大的模型。我還使用answerdotai/ModernBERT-large
作為基礎模型運行了該模型,結果是tomaarsen/reranker-ModernBERT-large-gooaq-bce。
評估結果會自動儲存在生成的模型卡中,模型卡中包含基礎模型、語言、許可證、評估結果、訓練和評估資料集資訊、超引數、訓練日誌等。無需任何額外工作,您上傳的模型將包含您的潛在使用者確定模型是否適合他們所需的所有資訊。
回撥
交叉編碼器訓練器支援各種transformers.TrainerCallback
子類,包括:
- 如果安裝了
wandb
,則使用WandbCallback
將訓練指標記錄到W&B。 - 如果可以訪問
tensorboard
,則使用TensorBoardCallback
將訓練指標記錄到TensorBoard。 - 如果安裝了
codecarbon
,則使用CodeCarbonCallback
跟蹤訓練期間的碳排放。
只要安裝了所需的依賴項,這些功能就會自動使用,您無需進行任何指定。
有關這些回撥以及如何建立自己的回撥的更多資訊,請參閱Transformers回撥文件。
多資料集訓練
通常,表現最佳的通用模型是同時在多個數據集上訓練的。然而,由於每個資料集的格式不同,這種方法可能具有挑戰性。幸運的是,CrossEncoderTrainer
允許您在多個數據集上進行訓練,而無需統一格式。此外,它提供了為每個資料集應用不同損失函式的靈活性。以下是同時使用多個數據集進行訓練的步驟:
- 使用
datasets.Dataset
例項字典(或datasets.DatasetDict
)作為train_dataset
(以及可選的eval_dataset
)。 - (可選)使用損失函式字典,將資料集名稱對映到損失。僅當您希望為不同資料集使用不同損失函式時才需要。
每個訓練/評估批次將只包含來自一個數據集的樣本。從多個數據集中取樣批次的順序由MultiDatasetBatchSamplers
列舉定義,該列舉可以透過multi_dataset_batch_sampler
傳遞給CrossEncoderTrainingArguments
。有效選項包括:
MultiDatasetBatchSamplers.ROUND_ROBIN
:從每個資料集迴圈取樣,直到其中一個耗盡。使用此策略,可能不會使用每個資料集中的所有樣本,但每個資料集的取樣頻率相同。MultiDatasetBatchSamplers.PROPORTIONAL
(預設):根據每個資料集的大小按比例取樣。使用此策略,將使用每個資料集中的所有樣本,並且從較大的資料集中取樣的頻率更高。
訓練技巧
交叉編碼器模型有其獨特的特點,因此這裡有一些技巧可以幫助您:
交叉編碼器模型很容易過擬合,因此建議使用像
CrossEncoderNanoBEIREvaluator
或CrossEncoderRerankingEvaluator
這樣的評估器,並結合load_best_model_at_end
和metric_for_best_model
訓練引數,以便在訓練結束後加載具有最佳評估效能的模型。交叉編碼器對強硬負例(
mine_hard_negatives
)特別敏感。它們教導模型非常嚴格,例如在區分回答問題和與問題相關的段落時很有用。- 請注意,如果您只使用難負例,您的模型在較簡單任務上的表現可能會出人意料地變差。這可能意味著,對第一階段檢索系統(例如使用SentenceTransformer模型)檢索到的前200個結果進行重排,實際上可能比重排前100個結果得到更差的前10個結果。同時使用隨機負例和難負例進行訓練可以緩解這種情況。
不要低估
BinaryCrossEntropyLoss
的力量,儘管它比學習排序(LambdaLoss、ListNetLoss)或批內負例(CachedMultipleNegativesRankingLoss、MultipleNegativesRankingLoss)損失更簡單,但它仍然是一個非常強大的選擇,並且其資料易於準備,尤其是在使用mine_hard_negatives
時。
評估
我對我模型在訓練器部分中的重排評估,與GooAQ開發集上的幾個基線進行了比較,重排評估器中同時使用了always_rerank_positives=False
和always_rerank_positives=True
。這分別代表了真實(僅重排檢索器找到的內容)和評估(重排所有正例,即使檢索器未找到)格式。
提醒一下,我使用了極其高效的sentence-transformers/static-retrieval-mrl-en-v1
靜態嵌入模型來檢索前30個用於重排。
模型 | 模型引數 | 重排前30後GooAQ NDCG@10 | 重排前30+所有正例後GooAQ NDCG@10 |
---|---|---|---|
無重排,僅檢索器 | - | 59.12 | 59.12 |
cross-encoder/ms-marco-MiniLM-L6-v2 | 22.7M | 69.56 | 72.09 |
jinaai/jina-reranker-v1-tiny-en | 33M | 66.83 | 69.54 |
jinaai/jina-reranker-v1-turbo-en | 37.8M | 72.01 | 76.10 |
jinaai/jina-reranker-v2-base-multilingual | 278M | 74.87 | 78.88 |
BAAI/bge-reranker-base | 278M | 70.98 | 74.31 |
BAAI/bge-reranker-large | 560M | 73.20 | 77.46 |
BAAI/bge-reranker-v2-m3 | 568M | 73.56 | 77.55 |
mixedbread-ai/mxbai-rerank-xsmall-v1 | 70.8M | 66.63 | 69.41 |
mixedbread-ai/mxbai-rerank-base-v1 | 184M | 70.43 | 74.39 |
mixedbread-ai/mxbai-rerank-large-v1 | 435M | 74.03 | 78.66 |
mixedbread-ai/mxbai-rerank-base-v2 | 494M | 73.03 | 76.76 |
mixedbread-ai/mxbai-rerank-large-v2 | 1.54B | 75.40 | 80.04 |
Alibaba-NLP/gte-reranker-modernbert-base | 150M | 73.18 | 77.49 |
tomaarsen/reranker-ModernBERT-base-gooaq-bce | 150M | 77.14 | 83.51 |
tomaarsen/reranker-ModernBERT-large-gooaq-bce | 396M | 79.42 | 85.81 |
點選檢視評估指令碼和資料集
這是評估指令碼:
import logging
from pprint import pprint
from datasets import load_dataset
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CrossEncoderRerankingEvaluator
# Set the log level to INFO to get more information
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
def main():
model_name = "tomaarsen/reranker-ModernBERT-base-gooaq-bce"
eval_batch_size = 64
# 1. Load a model to evaluate
model = CrossEncoder(model_name)
# 2. Load the GooAQ dataset: https://huggingface.co/datasets/tomaarsen/gooaq-reranker-blogpost-datasets
logging.info("Read the gooaq reranking dataset")
hard_eval_dataset = load_dataset("tomaarsen/gooaq-reranker-blogpost-datasets", "rerank", split="eval")
# 4. Create reranking evaluators. We use `always_rerank_positives=False` for a realistic evaluation
# where only all top 30 documents are reranked, and `always_rerank_positives=True` for an evaluation
# where the positive answer is always reranked as well.
samples = [
{
"query": sample["question"],
"positive": [sample["answer"]],
"documents": [sample[column_name] for column_name in hard_eval_dataset.column_names[2:]],
}
for sample in hard_eval_dataset
]
reranking_evaluator = CrossEncoderRerankingEvaluator(
samples=samples,
batch_size=eval_batch_size,
name="gooaq-dev-realistic",
always_rerank_positives=False,
)
realistic_results = reranking_evaluator(model)
pprint(realistic_results)
reranking_evaluator = CrossEncoderRerankingEvaluator(
samples=samples,
batch_size=eval_batch_size,
name="gooaq-dev-evaluation",
always_rerank_positives=True,
)
evaluation_results = reranking_evaluator(model)
pprint(evaluation_results)
if __name__ == "__main__":
main()
它使用了我的tomaarsen/gooaq-reranker-blogpost-datasets
資料集中的rerank
子集。該資料集包含:
pair
子集,train
分割:99k個訓練樣本直接取自GooAQ。這不直接用於訓練,而是用於準備hard-labeled-pair
子集,後者用於訓練。pair
子集,eval
分割:1k個訓練樣本直接取自GooAQ,與之前的99k沒有重疊。這不直接用於評估,而是用於準備rerank
子集,後者用於評估。hard-labeled-pair
子集,train
分割:578k個帶標籤的對用於訓練,透過使用來自pair
子集和train
分割的99k個樣本與sentence-transformers/static-retrieval-mrl-en-v1進行挖掘。該資料集用於訓練。rerank
子集,eval
分割:1k個樣本,包含問題、答案以及由sentence-transformers/static-retrieval-mrl-en-v1使用我的GooAQ子集中完整的100k訓練和評估答案檢索到的30個文件。該排名已經具有59.12的NDCG@10。
僅使用gooaq資料集中300萬訓練對中的9.9萬對,並在我的RTX 3090上僅訓練30分鐘,我的小型1.5億引數tomaarsen/reranker-ModernBERT-base-gooaq-bce模型就輕鬆超越了所有小於10億引數的通用重排器。更大的tomaarsen/reranker-ModernBERT-large-gooaq-bce訓練時間不到一小時,並在實際設定中以高達79.42的NDCG@10獨佔鰲頭。GooAQ訓練和評估資料集與這些基線模型的訓練目標非常吻合,因此在更小眾的領域進行訓練時,差異應該更大。
請注意,這並不意味著tomaarsen/reranker-ModernBERT-large-gooaq-bce是*所有*領域中最強的模型:它只是*我們*領域中最強的。這完全沒有問題,因為我們只需要這個重排器在我們的資料上表現良好。
不要低估在您的領域中微調重排器模型的力量。透過微調(小型)重排器,您可以同時提高搜尋效能和搜尋堆疊的延遲!
附加資源
訓練示例
這些頁面包含帶解釋的訓練示例以及訓練指令碼程式碼連結。您可以使用它們來熟悉重排器訓練迴圈:
文件
如需進一步學習,您可能還希望探索Sentence Transformers上的以下資源:
這裡有一個您可能感興趣的高階頁面: