使用Sentence Transformers訓練速度提高400倍的靜態嵌入模型

釋出於2025年1月15日
在 GitHub 上更新

總結

本部落格文章介紹了一種訓練靜態嵌入模型的方法,該模型在CPU上的執行速度比最先進的嵌入模型快100到400倍,同時保持了大部分質量。這開啟了許多令人興奮的應用場景,包括裝置上和瀏覽器內執行、邊緣計算、低功耗和嵌入式應用。

我們應用此方法訓練了兩個極其高效的嵌入模型:用於英語檢索的sentence-transformers/static-retrieval-mrl-en-v1,以及用於多語言相似度任務的sentence-transformers/static-similarity-mrl-multilingual-v1。這些模型在CPU上比all-mpnet-base-v2multilingual-e5-small等常見模型快100到400倍,同時在各種基準測試中至少達到其效能的85%

今天,我們釋出:

  • 上述兩個模型(用於英語檢索和多語言相似度)。
  • 我們遵循的詳細訓練策略,從構思到資料集選擇,再到實施和評估。
  • 兩個基於開源sentence transformers庫的訓練指令碼。
  • 兩份包含訓練期間收集的訓練和評估指標的Weights and Biases報告。
  • 我們使用的詳細資料集列表:30個用於訓練,13個用於評估。

我們還討論了潛在的改進,並鼓勵社群探索這些改進並在此工作的基礎上進行構建!

點選檢視已釋出模型的使用片段

這些模型的使用非常簡單,與常規的Sentence Transformers流程相同

英語檢索

from sentence_transformers import SentenceTransformer

# Download from the 🤗 Hub
model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")
# Run inference
sentences = [
    'Gadofosveset-enhanced MR angiography of carotid arteries: does steady-state imaging improve accuracy of first-pass imaging?',
    'To evaluate the diagnostic accuracy of gadofosveset-enhanced magnetic resonance (MR) angiography in the assessment of carotid artery stenosis, with digital subtraction angiography (DSA) as the reference standard, and to determine the value of reading first-pass, steady-state, and "combined" (first-pass plus steady-state) MR angiograms.',
    'In a longitudinal study we investigated in vivo alterations of CVO during neuroinflammation, applying Gadofluorine M- (Gf) enhanced magnetic resonance imaging (MRI) in experimental autoimmune encephalomyelitis, an animal model of multiple sclerosis. SJL/J mice were monitored by Gadopentate dimeglumine- (Gd-DTPA) and Gf-enhanced MRI after adoptive transfer of proteolipid-protein-specific T cells. Mean Gf intensity ratios were calculated individually for different CVO and correlated to the clinical disease course. Subsequently, the tissue distribution of fluorescence-labeled Gf as well as the extent of cellular inflammation was assessed in corresponding histological slices.',
]
embeddings = model.encode(sentences)
print(embeddings.shape)
# [3, 1024]

# Get the similarity scores for the embeddings
similarities = model.similarity(embeddings[0], embeddings[1:])
print(similarities)
# tensor([[0.7649, 0.3279]])

多語言相似度

from sentence_transformers import SentenceTransformer

# Download from the 🤗 Hub
model = SentenceTransformer("sentence-transformers/static-similarity-mrl-multilingual-v1", device="cpu")
# Run inference
sentences = [
    'It is known for its dry red chili powder.',
    'It is popular for dried red chili powder.',
    'These monsters will move in large groups.',
]
embeddings = model.encode(sentences)
print(embeddings.shape)
# [3, 1024]

# Get the similarity scores for the embeddings
similarities = model.similarity(embeddings, embeddings)
print(similarities)
# tensor([[ 1.0000,  0.8388, -0.0012],
#         [ 0.8388,  1.0000,  0.0445],
#         [-0.0012,  0.0445,  1.0000]])

NanoBEIR performance vs inference speed

目錄

什麼是嵌入?

嵌入是自然語言處理中最通用的工具之一,使從業者能夠解決各種任務。本質上,嵌入是更復雜物件(如文字、影像、音訊等)的數值表示。

embedding model

嵌入模型總是生成相同固定大小的嵌入。然後,您可以透過計算各個嵌入的相似度來計算複雜物件的相似度。

embedding similarity

這有大量的用例,並作為推薦系統、檢索、異常檢測、單次或少次學習、相似度搜索、聚類、釋義檢測、分類等的基礎。

現代嵌入

當今許多嵌入模型都包含少量轉換步驟。遵循這些步驟稱為“推理”。

embedding pipeline

`Tokenizer` 和 `Pooler` 分別負責 `Encoder` 的預處理和後處理。前者將文字切分為 `Encoder` 可理解的標記(又稱單詞或子詞),而後者將所有標記的嵌入組合成整個文字的一個嵌入。

在此管道中,`Encoder` 通常是一個帶有注意力層的語言模型,它允許每個標記在其他標記的**上下文**中進行計算。例如,`bank` 可能是一個標記,但如果文字指的是“河岸”或金融機構,則該標記的嵌入可能會有所不同。

具有許多注意力層的大型編碼器模型將有效地利用上下文來生成有用的嵌入,但這樣做的代價是推理速度慢。值得注意的是,在管道中,`Encoder` 步驟通常佔據了幾乎所有計算時間。

靜態嵌入

靜態嵌入指不使用大型緩慢的基於注意力的模型,而是依賴預計算標記嵌入的`Encoder`模型組。靜態嵌入在 Transformer 架構開發之前就已使用多年。常見示例包括GLoVeword2vec。最近,Model2Vec已被用於將預訓練嵌入模型轉換為靜態嵌入模型。

對於靜態嵌入,`Encoder` 步驟就像字典查詢一樣簡單:給定標記,返回預計算的標記嵌入。因此,推理突然不再受 `Encoder` 階段的瓶頸,從而使速度提高**幾個數量級**。這篇部落格文章表明,對質量的影響可以非常小!

我們的方法

我們著手使用現代技術重新審視靜態嵌入模型並訓練它們。我們的大部分收益來自對比學習損失函式的使用,我們將很快解釋。此外,透過使用套娃表示學習,我們可以獲得額外的速度改進,這使得使用嵌入向量的截斷版本成為可能。

我們將使用Sentence Transformers庫進行訓練。有關此庫如何用於訓練嵌入模型的更一般概述,請考慮閱讀使用Sentence Transformers v3訓練和微調嵌入模型部落格文章或Sentence Transformers訓練概述文件

訓練細節

重新構想靜態嵌入的目標是,在這些高效嵌入模型上試驗現代嵌入模型微調技術。特別是,與 GLoVe 和 word2vec 不同,我們將使用:

  1. **對比學習**:在大多數機器學習中,您輸入 $X$ 並期望輸出 $Y$,然後訓練模型,使透過模型輸入的 $X$ 產生接近 $Y$ 的結果。對於嵌入模型,我們沒有 $Y$:我們事先不知道好的嵌入是什麼。

    相反,在對比學習中,我們有多個輸入 $X_1$ 和 $X_2$,以及一個相似度。我們將兩個輸入都透過模型,之後我們可以**對比**生成的兩個嵌入,從而得到預測的相似度。如果真實相似度低,我們可以將嵌入推得更遠;如果真實相似度高,則可以將嵌入拉得更近。

  2. **套娃表示學習(MRL)**:套娃嵌入模型(部落格文章)是一種巧妙的訓練方法,允許使用者在效能損失最小的情況下將嵌入模型截斷為更小的維度。它不僅使用正常大小的嵌入進行對比損失函式計算,還使用其截斷版本。因此,模型學習將資訊主要儲存在嵌入的開頭。

    截斷後的嵌入將在下游應用(如檢索、分類和聚類)中更快。

對於未來的研究,我們留下了各種其他現代訓練方法以提高資料質量。請參閱下一步瞭解具體想法。

訓練要求

如Sentence Transformers中的訓練概述文件所示,訓練由3到5個元件組成:

  1. 資料集
  2. 損失函式
  3. 訓練引數(可選)
  4. 評估器(可選)
  5. 訓練器

在以下部分中,我們將詳細闡述我們對每個元件的思考過程。

模型靈感

根據我們的經驗,嵌入模型要麼1) 專門用於檢索,要麼2) 用於各種任務(分類、聚類、語義文字相似度等)。我們著手訓練了這兩種模型。

對於檢索模型,可用的多語言檢索訓練資料量有限,因此我們選擇僅使用英語模型。相反,我們決定訓練一個多語言通用相似度模型,因為對於此任務來說,多語言資料更容易獲取。

對於這些模型,我們希望使用StaticEmbedding模組,它實現了高效的tokenize方法以避免填充,以及高效的forward方法來計算和池化嵌入。這就像使用一個torchEmbeddingBag一樣簡單,它不過是一個高效的Embedding(即嵌入的查詢表)加上平均池化。

我們可以通過幾種方式初始化它:`StaticEmbedding.from_model2vec`載入Model2Vec 模型`StaticEmbedding.from_distillation`執行Model2Vec風格的蒸餾,或者使用`Tokenizer`和嵌入維度進行初始化以獲得隨機權重。

根據我們的發現,當使用大量資料進行完全訓練時,最後一個選項效果最好。為了匹配all-mpnet-base-v2bge-large-en-v1.5等常見模型,我們選擇將嵌入維度設定為1024,即我們的嵌入向量每個包含1024個值。

英語檢索

對於英語檢索模型,我們依賴google-bert/bert-base-uncased分詞器。因此,模型初始化如下所示:

from sentence_transformers import SentenceTransformer
from sentence_transformers.models import StaticEmbedding
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_pretrained("google-bert/bert-base-uncased")
static_embedding = StaticEmbedding(tokenizer, embedding_dim=1024)

model = SentenceTransformer(modules=[static_embedding])

`modules` 列表中的第一個條目必須實現 `tokenize`,最後一個必須生成池化嵌入。這裡兩者都符合,所以我們可以開始訓練這個模型了。

多語言相似度

對於多語言相似度模型,我們轉而依賴`google-bert/bert-base-multilingual-uncased`分詞器,這是我們初始化程式碼中唯一改變的地方

from sentence_transformers import SentenceTransformer
from sentence_transformers.models import StaticEmbedding
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_pretrained("google-bert/bert-base-multilingual-uncased")
static_embedding = StaticEmbedding(tokenizer, embedding_dim=1024)

model = SentenceTransformer(modules=[static_embedding])

訓練資料集選擇

除了數十個Sentence Transformer模型之外,Hugging Face上的Sentence Transformers組織還託管了70多個數據集(截至撰寫本文時)

除此之外,許多資料集已標記為`sentence-transformers`,以表明它們對訓練嵌入模型有用

英語檢索

對於英語檢索資料集,我們主要尋找具有以下特徵的任何資料集:

  • 問答對,可選地帶有負例(即錯誤答案),以及
  • 與BEIR基準(即MTEB上的檢索選項卡)沒有重疊。我們的目標是避免在這些資料集上進行訓練,以便我們可以將MTEB用作零樣本基準。

我們選擇了以下資料集:

多語言相似度

對於多語言相似度資料集,我們的目標是選擇包含以下特徵的資料集:

  • 跨語言的平行句子,即多種語言中的相同文字,或
  • 正例對,即具有高度相似性的對,可選地帶有負例(即低相似性)。

我們選擇了以下包含平行句子的資料集:

以及以下包含某種正例對的資料集:

程式碼

載入這些資料集相當簡單,例如:

from datasets import load_dataset, Dataset

gooaq_dataset = load_dataset("sentence-transformers/gooaq", split="train")
gooaq_dataset_dict = gooaq_dataset.train_test_split(test_size=10_000, seed=12)
gooaq_train_dataset: Dataset = gooaq_dataset_dict["train"]
gooaq_eval_dataset: Dataset = gooaq_dataset_dict["test"]

print(gooaq_train_dataset)
"""
Dataset({
    features: ['question', 'answer'],
    num_rows: 3002496
})
"""

print(gooaq_eval_dataset)
"""
Dataset({
    features: ['question', 'answer'],
    num_rows: 10000
})
"""

gooaq資料集尚未進行訓練-評估劃分,因此我們可以使用`train_test_split`建立一個。否則,我們可以直接載入預計算的劃分,例如使用`split="eval"`。

請注意,`train_test_split`確實意味著資料集必須載入到記憶體中,而否則它只保留在磁碟上。這種增加的記憶體對於訓練來說並不理想,因此建議1) 載入資料,2) 分割資料,3) 使用`save_to_disk`將其儲存到磁碟。在訓練之前,您可以然後使用`load_from_disk`再次載入它。

損失函式選擇

在 Sentence Transformers 中,您的損失模型必須與您的訓練資料格式匹配。損失概述旨在概述哪些損失與哪些格式相容。

特別是,我們目前的資料有以下格式:

  • (錨點, 正例) 對,無標籤
  • (錨點, 正例, 負例) 三元組,無標籤
  • (錨點, 正例, 負例_1, ..., 負例_n) 元組,無標籤

對於這些格式,我們有一些極佳的選擇:

  1. `MultipleNegativesRankingLoss` (MNRL):也稱為批內負樣本損失或 InfoNCE 損失,這種損失已用於訓練現代嵌入模型數年。簡而言之,該損失最佳化以下目標:

    給定一個錨點(例如一個問題),在批次中的所有正例和負例(例如所有答案)中,將最高相似度分配給對應的正例(即答案)。

    如果您提供可選的負例,它們將僅用作額外選項(也稱為批內負例),模型必須從中選擇正確的正例。在合理範圍內,這種“選擇”越困難,模型就會變得越強大。因此,更大的批次大小會產生更多的批內負例,從而提高效能(達到一定程度)。

  2. `CachedMultipleNegativesRankingLoss` (CMNRL):這是 MNRL 的一個擴充套件,它實現了 GradCache,這種方法允許任意增加批次大小而不增加記憶體。

    除非您已經可以使用MNRL在記憶體中容納足夠大的批處理大小,否則建議使用此損失而非MNRL。在這種情況下,您可以使用MNRL來節省CMNRL帶來的20%訓練速度成本。

  3. `GISTEmbedLoss` (GIST):這也是 MNRL 的一個擴充套件,它使用一個 `guide` Sentence Transformer 模型從模型必須“選擇”正確正例的選項列表中刪除潛在的假負例。

    假負例會損害效能,但難的正負例(接近正確但不完全正確的文字)可以幫助提高效能,因此這種過濾需要謹慎權衡。

由於這些靜態嵌入模型極其微小,我們可以在我們的硬體(一塊24GB視訊記憶體的RTX 3090)上輕鬆容納我們期望的2048樣本批次大小,因此我們不需要使用CMNRL。

此外,由於我們正在訓練如此快的模型,來自`GISTEmbedLoss`的`guide`會使訓練慢很多。因此,我們選擇為我們的模型使用`MultipleNegativesRankingLoss`

如果我們要再次嘗試這些實驗,我們會選擇更大的批次大小,例如使用 CMNRL 的 16384。如果您嘗試,請告訴我們結果如何!

程式碼

用法相當簡單:

from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import MultipleNegativesRankingLoss

# Prepare a model to train
tokenizer = Tokenizer.from_pretrained("google-bert/bert-base-uncased")
static_embedding = StaticEmbedding(tokenizer, embedding_dim=1024)
model = SentenceTransformer(modules=[static_embedding])

# Initialize the MNRL loss given the model
loss = MultipleNegativesRankingLoss(model)

套娃表示學習

除了常規損失函式之外,Sentence Transformers 還實現了一些損失修飾符。它們在標準損失函式之上工作,但以不同的方式應用它們,以試圖向訓練好的嵌入模型注入有用的特性。

一個非常有趣的例子是`MatryoshkaLoss`,它將訓練好的模型轉換為一個**套娃模型**。這允許使用者在效能損失最小的情況下截斷輸出嵌入到更小的維度,這意味著由於維度更小,檢索或聚類可以加速。

程式碼

`MatryoshkaLoss` 應用於正常的損失之上。建議在 `matryoshka_dims` 列表中也包含正常的嵌入維度。

from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import MultipleNegativesRankingLoss, MatryoshkaLoss

# Prepare a model to train
tokenizer = Tokenizer.from_pretrained("google-bert/bert-base-uncased")
static_embedding = StaticEmbedding(tokenizer, embedding_dim=1024)
model = SentenceTransformer(modules=[static_embedding])

# Initialize the MNRL loss given the model
base_loss = MultipleNegativesRankingLoss(model)
loss = MatryoshkaLoss(model, base_loss, matryoshka_dims=[1024, 768, 512, 256, 128, 64, 32])

訓練引數選擇

Sentence Transformers 支援大量的訓練引數,其中最有價值的引數已在訓練概述 > 訓練引數文件中列出。

我們使用相同的核心訓練引數來訓練兩個模型:

  • 訓練週期數: 1
    • 我們有足夠的資料,如果想訓練更多,可以新增更多資料,而不是多次訓練相同的資料。
  • `per_device_train_batch_size`/`per_device_eval_batch_size`: 2048
    • 2048 維度可以輕鬆地在我們的 RTX 3090 上執行。多篇論文(Xiao 等Li 等)表明,即使更大的批次大小也能提高效能。對於未來的版本,我們將使用 `CachedMultipleNegativesRankingLoss` 和更大的批次大小,例如 16384。
  • `learning_rate`: 2e-1
    • 注意!這比正常嵌入模型訓練的損失(通常約為 2e-5)**大得多**。
  • 預熱比率: 0.1
    • 0.1 或 10% 是一個非常標準的預熱比率,用於平滑地將高學習率引入模型。
  • `bf16`: True
    • 如果您的 GPU 支援 `bf16`,那麼使用它進行訓練通常是合理的。否則,如果支援 `fp16`,您可以使用 `fp16=True`。
  • `batch_sampler`: `BatchSamplers.NO_DUPLICATES`
    • 所有具有批內負例的損失(例如 MNRL)都受益於此批取樣器,它避免了批內重複。重複通常會導致假負例,從而削弱訓練後的模型。
  • `multi_dataset_batch_sampler`: `MultiDatasetBatchSamplers.PROPORTIONAL`
    • 當您使用多個數據集進行訓練時,資料集大小通常不相同。發生這種情況時,您可以選擇:
      • 迴圈:從每個資料集抽取相同數量的批次,直到其中一個耗盡。您將獲得均勻的資料分佈,但並非所有資料都將被使用。
      • 按比例:抽取每個資料集,直到所有資料集都耗盡。您將使用所有資料,但資料分佈不均勻。我們選擇了這種方式,因為我們不太關心資料不平衡問題。

除了這些核心引數之外,我們還設定了一些用於跟蹤和除錯的訓練引數:`eval_strategy`、`eval_steps`、`save_strategy`、`save_steps`、`save_total_limit`、`logging_steps`、`logging_first_step` 和 `run_name`。

程式碼

最終,我們為這兩個模型使用了這些`SentenceTransformerTrainingArguments`:

run_name = "static-retrieval-mrl-en-v1"
# or 
# run_name = "static-similarity-mrl-multilingual-v1"

args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=f"models/{run_name}",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=2048,
    per_device_eval_batch_size=2048,
    learning_rate=2e-1,
    warmup_ratio=0.1,
    fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=True,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=1000,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=2,
    logging_steps=1000,
    logging_first_step=True,
    run_name=run_name,  # Used if `wandb`, `tensorboard`, or `neptune`, etc. is installed
)

評估器選擇

如果我們向 Sentence Transformer 訓練器提供一個評估資料集,那麼在評估時我們將得到一個評估損失。這對於跟蹤我們是否過擬合很有用,但在實際下游效能方面意義不大。

因此,Sentence Transformers 還支援評估器。與訓練損失不同,它們提供定性指標,例如資訊檢索的 NDCG、MAP、MRR,語義文字相似度的 Spearman 相關係數,或三元組準確率(`similarity(anchor, positive)` > `similarity(anchor, negative)` 的樣本數量)。

由於其簡單性,我們將為檢索模型使用`NanoBEIREvaluator`。該評估器在NanoBEIR資料集集合上執行資訊檢索基準測試。該資料集是更大的(因此更慢的)BEIR基準的子集,BEIR基準通常用作 MTEB 排行榜中的檢索選項卡。

程式碼

由於所有資料集都已預定義,我們可以無需任何引數載入評估器

from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import NanoBEIREvaluator

# Load an example pre-trained model to finetune further
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Initialize the NanoBEIR Evaluator
evaluator = NanoBEIREvaluator()

# Run it on any Sentence Transformer model
evaluator(model)

硬體詳情

我們正在消費級硬體上訓練這些模型,具體如下:

  • GPU:RTX 3090
  • CPU:i7-13700K
  • 記憶體:32GB

總體訓練指令碼

本節包含兩個模型的最終訓練指令碼,其中結合了所有先前描述的元件(資料集、損失函式、訓練引數、評估器、訓練器)。

英語檢索

點選展開
import random
import logging
from datasets import load_dataset, Dataset, DatasetDict
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerModelCardData,
)
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers, MultiDatasetBatchSamplers
from sentence_transformers.evaluation import NanoBEIREvaluator
from sentence_transformers.models.StaticEmbedding import StaticEmbedding

from transformers import AutoTokenizer

logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
)
random.seed(12)


def load_train_eval_datasets():
    """
    Either load the train and eval datasets from disk or load them from the datasets library & save them to disk.

    Upon saving to disk, we quit() to ensure that the datasets are not loaded into memory before training.
    """
    try:
        train_dataset = DatasetDict.load_from_disk("datasets/train_dataset")
        eval_dataset = DatasetDict.load_from_disk("datasets/eval_dataset")
        return train_dataset, eval_dataset
    except FileNotFoundError:
        print("Loading gooaq dataset...")
        gooaq_dataset = load_dataset("sentence-transformers/gooaq", split="train")
        gooaq_dataset_dict = gooaq_dataset.train_test_split(test_size=10_000, seed=12)
        gooaq_train_dataset: Dataset = gooaq_dataset_dict["train"]
        gooaq_eval_dataset: Dataset = gooaq_dataset_dict["test"]
        print("Loaded gooaq dataset.")

        print("Loading msmarco dataset...")
        msmarco_dataset = load_dataset("sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1", "triplet", split="train")
        msmarco_dataset_dict = msmarco_dataset.train_test_split(test_size=10_000, seed=12)
        msmarco_train_dataset: Dataset = msmarco_dataset_dict["train"]
        msmarco_eval_dataset: Dataset = msmarco_dataset_dict["test"]
        print("Loaded msmarco dataset.")

        print("Loading squad dataset...")
        squad_dataset = load_dataset("sentence-transformers/squad", split="train")
        squad_dataset_dict = squad_dataset.train_test_split(test_size=10_000, seed=12)
        squad_train_dataset: Dataset = squad_dataset_dict["train"]
        squad_eval_dataset: Dataset = squad_dataset_dict["test"]
        print("Loaded squad dataset.")

        print("Loading s2orc dataset...")
        s2orc_dataset = load_dataset("sentence-transformers/s2orc", "title-abstract-pair", split="train[:100000]")
        s2orc_dataset_dict = s2orc_dataset.train_test_split(test_size=10_000, seed=12)
        s2orc_train_dataset: Dataset = s2orc_dataset_dict["train"]
        s2orc_eval_dataset: Dataset = s2orc_dataset_dict["test"]
        print("Loaded s2orc dataset.")

        print("Loading allnli dataset...")
        allnli_train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train")
        allnli_eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")
        print("Loaded allnli dataset.")

        print("Loading paq dataset...")
        paq_dataset = load_dataset("sentence-transformers/paq", split="train")
        paq_dataset_dict = paq_dataset.train_test_split(test_size=10_000, seed=12)
        paq_train_dataset: Dataset = paq_dataset_dict["train"]
        paq_eval_dataset: Dataset = paq_dataset_dict["test"]
        print("Loaded paq dataset.")

        print("Loading trivia_qa dataset...")
        trivia_qa = load_dataset("sentence-transformers/trivia-qa", split="train")
        trivia_qa_dataset_dict = trivia_qa.train_test_split(test_size=5_000, seed=12)
        trivia_qa_train_dataset: Dataset = trivia_qa_dataset_dict["train"]
        trivia_qa_eval_dataset: Dataset = trivia_qa_dataset_dict["test"]
        print("Loaded trivia_qa dataset.")

        print("Loading msmarco_10m dataset...")
        msmarco_10m_dataset = load_dataset("bclavie/msmarco-10m-triplets", split="train")
        msmarco_10m_dataset_dict = msmarco_10m_dataset.train_test_split(test_size=10_000, seed=12)
        msmarco_10m_train_dataset: Dataset = msmarco_10m_dataset_dict["train"]
        msmarco_10m_eval_dataset: Dataset = msmarco_10m_dataset_dict["test"]
        print("Loaded msmarco_10m dataset.")

        print("Loading swim_ir dataset...")
        swim_ir_dataset = load_dataset("nthakur/swim-ir-monolingual", "en", split="train").select_columns(["query", "text"])
        swim_ir_dataset_dict = swim_ir_dataset.train_test_split(test_size=10_000, seed=12)
        swim_ir_train_dataset: Dataset = swim_ir_dataset_dict["train"]
        swim_ir_eval_dataset: Dataset = swim_ir_dataset_dict["test"]
        print("Loaded swim_ir dataset.")

        # NOTE: 20 negatives
        print("Loading pubmedqa dataset...")
        pubmedqa_dataset = load_dataset("sentence-transformers/pubmedqa", "triplet-20", split="train")
        pubmedqa_dataset_dict = pubmedqa_dataset.train_test_split(test_size=100, seed=12)
        pubmedqa_train_dataset: Dataset = pubmedqa_dataset_dict["train"]
        pubmedqa_eval_dataset: Dataset = pubmedqa_dataset_dict["test"]
        print("Loaded pubmedqa dataset.")

        # NOTE: A lot of overlap with anchor/positives
        print("Loading miracl dataset...")
        miracl_dataset = load_dataset("sentence-transformers/miracl", "en-triplet-all", split="train")
        miracl_dataset_dict = miracl_dataset.train_test_split(test_size=10_000, seed=12)
        miracl_train_dataset: Dataset = miracl_dataset_dict["train"]
        miracl_eval_dataset: Dataset = miracl_dataset_dict["test"]
        print("Loaded miracl dataset.")

        # NOTE: A lot of overlap with anchor/positives
        print("Loading mldr dataset...")
        mldr_dataset = load_dataset("sentence-transformers/mldr", "en-triplet-all", split="train")
        mldr_dataset_dict = mldr_dataset.train_test_split(test_size=10_000, seed=12)
        mldr_train_dataset: Dataset = mldr_dataset_dict["train"]
        mldr_eval_dataset: Dataset = mldr_dataset_dict["test"]
        print("Loaded mldr dataset.")

        # NOTE: A lot of overlap with anchor/positives
        print("Loading mr_tydi dataset...")
        mr_tydi_dataset = load_dataset("sentence-transformers/mr-tydi", "en-triplet-all", split="train")
        mr_tydi_dataset_dict = mr_tydi_dataset.train_test_split(test_size=10_000, seed=12)
        mr_tydi_train_dataset: Dataset = mr_tydi_dataset_dict["train"]
        mr_tydi_eval_dataset: Dataset = mr_tydi_dataset_dict["test"]
        print("Loaded mr_tydi dataset.")

        train_dataset = DatasetDict({
            "gooaq": gooaq_train_dataset,
            "msmarco": msmarco_train_dataset,
            "squad": squad_train_dataset,
            "s2orc": s2orc_train_dataset,
            "allnli": allnli_train_dataset,
            "paq": paq_train_dataset,
            "trivia_qa": trivia_qa_train_dataset,
            "msmarco_10m": msmarco_10m_train_dataset,
            "swim_ir": swim_ir_train_dataset,
            "pubmedqa": pubmedqa_train_dataset,
            "miracl": miracl_train_dataset,
            "mldr": mldr_train_dataset,
            "mr_tydi": mr_tydi_train_dataset,
        })
        eval_dataset = DatasetDict({
            "gooaq": gooaq_eval_dataset,
            "msmarco": msmarco_eval_dataset,
            "squad": squad_eval_dataset,
            "s2orc": s2orc_eval_dataset,
            "allnli": allnli_eval_dataset,
            "paq": paq_eval_dataset,
            "trivia_qa": trivia_qa_eval_dataset,
            "msmarco_10m": msmarco_10m_eval_dataset,
            "swim_ir": swim_ir_eval_dataset,
            "pubmedqa": pubmedqa_eval_dataset,
            "miracl": miracl_eval_dataset,
            "mldr": mldr_eval_dataset,
            "mr_tydi": mr_tydi_eval_dataset,
        })

        train_dataset.save_to_disk("datasets/train_dataset")
        eval_dataset.save_to_disk("datasets/eval_dataset")
        
        # The `train_test_split` calls have put a lot of the datasets in memory, while we want it to just be on disk
        quit()
    

def main():
    # 1. Load a model to finetune with 2. (Optional) model card data
    static_embedding = StaticEmbedding(AutoTokenizer.from_pretrained("google-bert/bert-base-uncased"), embedding_dim=1024)
    model = SentenceTransformer(
        modules=[static_embedding],
        model_card_data=SentenceTransformerModelCardData(
            language="en",
            license="apache-2.0",
            model_name="Static Embeddings with BERT uncased tokenizer finetuned on various datasets",
        ),
    )

    # 3. Set up training & evaluation datasets - each dataset is trained with MNRL (with MRL)
    train_dataset, eval_dataset = load_train_eval_datasets()
    print(train_dataset)

    # 4. Define a loss function
    loss = MultipleNegativesRankingLoss(model)
    loss = MatryoshkaLoss(model, loss, matryoshka_dims=[32, 64, 128, 256, 512, 1024])

    # 5. (Optional) Specify training arguments
    run_name = "static-retrieval-mrl-en-v1"
    args = SentenceTransformerTrainingArguments(
        # Required parameter:
        output_dir=f"models/{run_name}",
        # Optional training parameters:
        num_train_epochs=1,
        per_device_train_batch_size=2048,
        per_device_eval_batch_size=2048,
        learning_rate=2e-1,
        warmup_ratio=0.1,
        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=True,  # Set to True if you have a GPU that supports BF16
        batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
        multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=250,
        save_strategy="steps",
        save_steps=250,
        save_total_limit=2,
        logging_steps=250,
        logging_first_step=True,
        run_name=run_name,  # Will be used in W&B if `wandb` is installed
    )

    # 6. (Optional) Create an evaluator & evaluate the base model
    evaluator = NanoBEIREvaluator()
    evaluator(model)

    # 7. Create a trainer & train
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
        evaluator=evaluator,
    )
    trainer.train()

    # (Optional) Evaluate the trained model on the evaluator after training
    evaluator(model)

    # 8. Save the trained model
    model.save_pretrained(f"models/{run_name}/final")

    # 9. (Optional) Push it to the Hugging Face Hub
    model.push_to_hub(run_name, private=True)

if __name__ == "__main__":
    main()

該指令碼在訓練17.8小時後生成了sentence-transformers/static-retrieval-mrl-en-v1。總共消耗了2.6千瓦時能源,排放了1千克二氧化碳。這大致相當於一個人每天撥出的二氧化碳量。

請參閱我們的Weights and Biases報告,瞭解訓練期間收集的訓練和評估指標。

多語言相似度

點選展開
import random
import logging
from datasets import load_dataset, Dataset, DatasetDict
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerModelCardData,
)
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers, MultiDatasetBatchSamplers
from sentence_transformers.models.StaticEmbedding import StaticEmbedding

from transformers import AutoTokenizer

logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
)
random.seed(12)


def load_train_eval_datasets():
    """
    Either load the train and eval datasets from disk or load them from the datasets library & save them to disk.

    Upon saving to disk, we quit() to ensure that the datasets are not loaded into memory before training.
    """
    try:
        train_dataset = DatasetDict.load_from_disk("datasets/train_dataset")
        eval_dataset = DatasetDict.load_from_disk("datasets/eval_dataset")
        return train_dataset, eval_dataset
    except FileNotFoundError:
        print("Loading wikititles dataset...")
        wikititles_dataset = load_dataset("sentence-transformers/parallel-sentences-wikititles", split="train")
        wikititles_dataset_dict = wikititles_dataset.train_test_split(test_size=10_000, seed=12)
        wikititles_train_dataset: Dataset = wikititles_dataset_dict["train"]
        wikititles_eval_dataset: Dataset = wikititles_dataset_dict["test"]
        print("Loaded wikititles dataset.")

        print("Loading tatoeba dataset...")
        tatoeba_dataset = load_dataset("sentence-transformers/parallel-sentences-tatoeba", "all", split="train")
        tatoeba_dataset_dict = tatoeba_dataset.train_test_split(test_size=10_000, seed=12)
        tatoeba_train_dataset: Dataset = tatoeba_dataset_dict["train"]
        tatoeba_eval_dataset: Dataset = tatoeba_dataset_dict["test"]
        print("Loaded tatoeba dataset.")

        print("Loading talks dataset...")
        talks_dataset = load_dataset("sentence-transformers/parallel-sentences-talks", "all", split="train")
        talks_dataset_dict = talks_dataset.train_test_split(test_size=10_000, seed=12)
        talks_train_dataset: Dataset = talks_dataset_dict["train"]
        talks_eval_dataset: Dataset = talks_dataset_dict["test"]
        print("Loaded talks dataset.")

        print("Loading europarl dataset...")
        europarl_dataset = load_dataset("sentence-transformers/parallel-sentences-europarl", "all", split="train[:5000000]")
        europarl_dataset_dict = europarl_dataset.train_test_split(test_size=10_000, seed=12)
        europarl_train_dataset: Dataset = europarl_dataset_dict["train"]
        europarl_eval_dataset: Dataset = europarl_dataset_dict["test"]
        print("Loaded europarl dataset.")

        print("Loading global voices dataset...")
        global_voices_dataset = load_dataset("sentence-transformers/parallel-sentences-global-voices", "all", split="train")
        global_voices_dataset_dict = global_voices_dataset.train_test_split(test_size=10_000, seed=12)
        global_voices_train_dataset: Dataset = global_voices_dataset_dict["train"]
        global_voices_eval_dataset: Dataset = global_voices_dataset_dict["test"]
        print("Loaded global voices dataset.")

        print("Loading jw300 dataset...")
        jw300_dataset = load_dataset("sentence-transformers/parallel-sentences-jw300", "all", split="train")
        jw300_dataset_dict = jw300_dataset.train_test_split(test_size=10_000, seed=12)
        jw300_train_dataset: Dataset = jw300_dataset_dict["train"]
        jw300_eval_dataset: Dataset = jw300_dataset_dict["test"]
        print("Loaded jw300 dataset.")

        print("Loading muse dataset...")
        muse_dataset = load_dataset("sentence-transformers/parallel-sentences-muse", split="train")
        muse_dataset_dict = muse_dataset.train_test_split(test_size=10_000, seed=12)
        muse_train_dataset: Dataset = muse_dataset_dict["train"]
        muse_eval_dataset: Dataset = muse_dataset_dict["test"]
        print("Loaded muse dataset.")

        print("Loading wikimatrix dataset...")
        wikimatrix_dataset = load_dataset("sentence-transformers/parallel-sentences-wikimatrix", "all", split="train")
        wikimatrix_dataset_dict = wikimatrix_dataset.train_test_split(test_size=10_000, seed=12)
        wikimatrix_train_dataset: Dataset = wikimatrix_dataset_dict["train"]
        wikimatrix_eval_dataset: Dataset = wikimatrix_dataset_dict["test"]
        print("Loaded wikimatrix dataset.")

        print("Loading opensubtitles dataset...")
        opensubtitles_dataset = load_dataset("sentence-transformers/parallel-sentences-opensubtitles", "all", split="train[:5000000]")
        opensubtitles_dataset_dict = opensubtitles_dataset.train_test_split(test_size=10_000, seed=12)
        opensubtitles_train_dataset: Dataset = opensubtitles_dataset_dict["train"]
        opensubtitles_eval_dataset: Dataset = opensubtitles_dataset_dict["test"]
        print("Loaded opensubtitles dataset.")

        print("Loading stackexchange dataset...")
        stackexchange_dataset = load_dataset("sentence-transformers/stackexchange-duplicates", "post-post-pair", split="train")
        stackexchange_dataset_dict = stackexchange_dataset.train_test_split(test_size=10_000, seed=12)
        stackexchange_train_dataset: Dataset = stackexchange_dataset_dict["train"]
        stackexchange_eval_dataset: Dataset = stackexchange_dataset_dict["test"]
        print("Loaded stackexchange dataset.")

        print("Loading quora dataset...")
        quora_dataset = load_dataset("sentence-transformers/quora-duplicates", "triplet", split="train")
        quora_dataset_dict = quora_dataset.train_test_split(test_size=10_000, seed=12)
        quora_train_dataset: Dataset = quora_dataset_dict["train"]
        quora_eval_dataset: Dataset = quora_dataset_dict["test"]
        print("Loaded quora dataset.")

        print("Loading wikianswers duplicates dataset...")
        wikianswers_duplicates_dataset = load_dataset("sentence-transformers/wikianswers-duplicates", split="train[:10000000]")
        wikianswers_duplicates_dict = wikianswers_duplicates_dataset.train_test_split(test_size=10_000, seed=12)
        wikianswers_duplicates_train_dataset: Dataset = wikianswers_duplicates_dict["train"]
        wikianswers_duplicates_eval_dataset: Dataset = wikianswers_duplicates_dict["test"]
        print("Loaded wikianswers duplicates dataset.")

        print("Loading all nli dataset...")
        all_nli_train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train")
        all_nli_eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")
        print("Loaded all nli dataset.")

        print("Loading simple wiki dataset...")
        simple_wiki_dataset = load_dataset("sentence-transformers/simple-wiki", split="train")
        simple_wiki_dataset_dict = simple_wiki_dataset.train_test_split(test_size=10_000, seed=12)
        simple_wiki_train_dataset: Dataset = simple_wiki_dataset_dict["train"]
        simple_wiki_eval_dataset: Dataset = simple_wiki_dataset_dict["test"]
        print("Loaded simple wiki dataset.")

        print("Loading altlex dataset...")
        altlex_dataset = load_dataset("sentence-transformers/altlex", split="train")
        altlex_dataset_dict = altlex_dataset.train_test_split(test_size=10_000, seed=12)
        altlex_train_dataset: Dataset = altlex_dataset_dict["train"]
        altlex_eval_dataset: Dataset = altlex_dataset_dict["test"]
        print("Loaded altlex dataset.")

        print("Loading flickr30k captions dataset...")
        flickr30k_captions_dataset = load_dataset("sentence-transformers/flickr30k-captions", split="train")
        flickr30k_captions_dataset_dict = flickr30k_captions_dataset.train_test_split(test_size=10_000, seed=12)
        flickr30k_captions_train_dataset: Dataset = flickr30k_captions_dataset_dict["train"]
        flickr30k_captions_eval_dataset: Dataset = flickr30k_captions_dataset_dict["test"]
        print("Loaded flickr30k captions dataset.")

        print("Loading coco captions dataset...")
        coco_captions_dataset = load_dataset("sentence-transformers/coco-captions", split="train")
        coco_captions_dataset_dict = coco_captions_dataset.train_test_split(test_size=10_000, seed=12)
        coco_captions_train_dataset: Dataset = coco_captions_dataset_dict["train"]
        coco_captions_eval_dataset: Dataset = coco_captions_dataset_dict["test"]
        print("Loaded coco captions dataset.")

        print("Loading nli for simcse dataset...")
        nli_for_simcse_dataset = load_dataset("sentence-transformers/nli-for-simcse", "triplet", split="train")
        nli_for_simcse_dataset_dict = nli_for_simcse_dataset.train_test_split(test_size=10_000, seed=12)
        nli_for_simcse_train_dataset: Dataset = nli_for_simcse_dataset_dict["train"]
        nli_for_simcse_eval_dataset: Dataset = nli_for_simcse_dataset_dict["test"]
        print("Loaded nli for simcse dataset.")

        print("Loading negation dataset...")
        negation_dataset = load_dataset("jinaai/negation-dataset", split="train")
        negation_dataset_dict = negation_dataset.train_test_split(test_size=100, seed=12)
        negation_train_dataset: Dataset = negation_dataset_dict["train"]
        negation_eval_dataset: Dataset = negation_dataset_dict["test"]
        print("Loaded negation dataset.")

        train_dataset = DatasetDict({
            "wikititles": wikititles_train_dataset,
            "tatoeba": tatoeba_train_dataset,
            "talks": talks_train_dataset,
            "europarl": europarl_train_dataset,
            "global_voices": global_voices_train_dataset,
            "jw300": jw300_train_dataset,
            "muse": muse_train_dataset,
            "wikimatrix": wikimatrix_train_dataset,
            "opensubtitles": opensubtitles_train_dataset,
            "stackexchange": stackexchange_train_dataset,
            "quora": quora_train_dataset,
            "wikianswers_duplicates": wikianswers_duplicates_train_dataset,
            "all_nli": all_nli_train_dataset,
            "simple_wiki": simple_wiki_train_dataset,
            "altlex": altlex_train_dataset,
            "flickr30k_captions": flickr30k_captions_train_dataset,
            "coco_captions": coco_captions_train_dataset,
            "nli_for_simcse": nli_for_simcse_train_dataset,
            "negation": negation_train_dataset,
        })
        eval_dataset = DatasetDict({
            "wikititles": wikititles_eval_dataset,
            "tatoeba": tatoeba_eval_dataset,
            "talks": talks_eval_dataset,
            "europarl": europarl_eval_dataset,
            "global_voices": global_voices_eval_dataset,
            "jw300": jw300_eval_dataset,
            "muse": muse_eval_dataset,
            "wikimatrix": wikimatrix_eval_dataset,
            "opensubtitles": opensubtitles_eval_dataset,
            "stackexchange": stackexchange_eval_dataset,
            "quora": quora_eval_dataset,
            "wikianswers_duplicates": wikianswers_duplicates_eval_dataset,
            "all_nli": all_nli_eval_dataset,
            "simple_wiki": simple_wiki_eval_dataset,
            "altlex": altlex_eval_dataset,
            "flickr30k_captions": flickr30k_captions_eval_dataset,
            "coco_captions": coco_captions_eval_dataset,
            "nli_for_simcse": nli_for_simcse_eval_dataset,
            "negation": negation_eval_dataset,
        })

        train_dataset.save_to_disk("datasets/train_dataset")
        eval_dataset.save_to_disk("datasets/eval_dataset")
        
        # The `train_test_split` calls have put a lot of the datasets in memory, while we want it to just be on disk
        quit()

def main():
    # 1. Load a model to finetune with 2. (Optional) model card data
    static_embedding = StaticEmbedding(AutoTokenizer.from_pretrained("google-bert/bert-base-multilingual-uncased"), embedding_dim=1024)
    model = SentenceTransformer(
        modules=[static_embedding],
        model_card_data=SentenceTransformerModelCardData(
            license="apache-2.0",
            model_name="Static Embeddings with BERT Multilingual uncased tokenizer finetuned on various datasets",
        ),
    )

    # 3. Set up training & evaluation datasets - each dataset is trained with MNRL (with MRL)
    train_dataset, eval_dataset = load_train_eval_datasets()
    print(train_dataset)

    # 4. Define a loss function
    loss = MultipleNegativesRankingLoss(model)
    loss = MatryoshkaLoss(model, loss, matryoshka_dims=[32, 64, 128, 256, 512, 1024])

    # 5. (Optional) Specify training arguments
    run_name = "static-similarity-mrl-multilingual-v1"
    args = SentenceTransformerTrainingArguments(
        # Required parameter:
        output_dir=f"models/{run_name}",
        # Optional training parameters:
        num_train_epochs=1,
        per_device_train_batch_size=2048,
        per_device_eval_batch_size=2048,
        learning_rate=2e-1,
        warmup_ratio=0.1,
        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=True,  # Set to True if you have a GPU that supports BF16
        batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
        multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=1000,
        save_strategy="steps",
        save_steps=1000,
        save_total_limit=2,
        logging_steps=1000,
        logging_first_step=True,
        run_name=run_name,  # Will be used in W&B if `wandb` is installed
    )

    # 6. Create a trainer & train
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
    )
    trainer.train()

    # 7. Save the trained model
    model.save_pretrained(f"models/{run_name}/final")

    # 8. (Optional) Push it to the Hugging Face Hub
    model.push_to_hub(run_name, private=True)

if __name__ == "__main__":
    main()

這個模型只比流行但慢得多的multilingual-e5-small模型損失了大約8%的效能,正如即將到來的效能 > 多語言相似度部分所示。

請參閱我們的Weights and Biases 報告,瞭解訓練期間收集的訓練和評估損失。

用法

這些模型的使用非常簡單,與常規的Sentence Transformers流程相同

英語檢索

from sentence_transformers import SentenceTransformer

# Download from the 🤗 Hub
model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")
# Run inference
sentences = [
    'Gadofosveset-enhanced MR angiography of carotid arteries: does steady-state imaging improve accuracy of first-pass imaging?',
    'To evaluate the diagnostic accuracy of gadofosveset-enhanced magnetic resonance (MR) angiography in the assessment of carotid artery stenosis, with digital subtraction angiography (DSA) as the reference standard, and to determine the value of reading first-pass, steady-state, and "combined" (first-pass plus steady-state) MR angiograms.',
    'In a longitudinal study we investigated in vivo alterations of CVO during neuroinflammation, applying Gadofluorine M- (Gf) enhanced magnetic resonance imaging (MRI) in experimental autoimmune encephalomyelitis, an animal model of multiple sclerosis. SJL/J mice were monitored by Gadopentate dimeglumine- (Gd-DTPA) and Gf-enhanced MRI after adoptive transfer of proteolipid-protein-specific T cells. Mean Gf intensity ratios were calculated individually for different CVO and correlated to the clinical disease course. Subsequently, the tissue distribution of fluorescence-labeled Gf as well as the extent of cellular inflammation was assessed in corresponding histological slices.',
]
embeddings = model.encode(sentences)
print(embeddings.shape)
# [3, 1024]

# Get the similarity scores for the embeddings
similarities = model.similarity(embeddings[0], embeddings[1:])
print(similarities)
# tensor([[0.7649, 0.3279]])

即將到來的效能 > 英語檢索部分將顯示,這些結果非常可靠,與常用的基於 Transformer 的編碼器模型(如all-mpnet-base-v2)相差不到 15%。

多語言相似度

from sentence_transformers import SentenceTransformer

# Download from the 🤗 Hub
model = SentenceTransformer("sentence-transformers/static-similarity-mrl-multilingual-v1", device="cpu")
# Run inference
sentences = [
    'It is known for its dry red chili powder.',
    'It is popular for dried red chili powder.',
    'These monsters will move in large groups.',
]
embeddings = model.encode(sentences)
print(embeddings.shape)
# [3, 1024]

# Get the similarity scores for the embeddings
similarities = model.similarity(embeddings, embeddings)
print(similarities)
# tensor([[ 1.0000,  0.8388, -0.0012],
#         [ 0.8388,  1.0000,  0.0445],
#         [-0.0012,  0.0445,  1.0000]])

與流行的但速度慢得多的multilingual-e5-small相比,該模型僅損失約8%的效能,如即將到來的效能 > 多語言相似度部分所示。

套娃降維截斷

要降低計算出的嵌入的維度,您只需傳遞 `truncate_dim` 引數即可。這適用於所有 Sentence Transformer 模型。

from sentence_transformers import SentenceTransformer

# Download from the 🤗 Hub
model = SentenceTransformer(
    "sentence-transformers/static-retrieval-mrl-en-v1",
    device="cpu",
    truncate_dim=256,
)
# Run inference
sentences = [
    'Gadofosveset-enhanced MR angiography of carotid arteries: does steady-state imaging improve accuracy of first-pass imaging?',
    'To evaluate the diagnostic accuracy of gadofosveset-enhanced magnetic resonance (MR) angiography in the assessment of carotid artery stenosis, with digital subtraction angiography (DSA) as the reference standard, and to determine the value of reading first-pass, steady-state, and "combined" (first-pass plus steady-state) MR angiograms.',
    'In a longitudinal study we investigated in vivo alterations of CVO during neuroinflammation, applying Gadofluorine M- (Gf) enhanced magnetic resonance imaging (MRI) in experimental autoimmune encephalomyelitis, an animal model of multiple sclerosis. SJL/J mice were monitored by Gadopentate dimeglumine- (Gd-DTPA) and Gf-enhanced MRI after adoptive transfer of proteolipid-protein-specific T cells. Mean Gf intensity ratios were calculated individually for different CVO and correlated to the clinical disease course. Subsequently, the tissue distribution of fluorescence-labeled Gf as well as the extent of cellular inflammation was assessed in corresponding histological slices.',
]
embeddings = model.encode(sentences)
print(embeddings.shape)
# [3, 256]

# Get the similarity scores for the embeddings
similarities = model.similarity(embeddings[0], embeddings[1:])
print(similarities)
# tensor([[0.7844, 0.3561]])

第三方庫

該模型還可與各種第三方庫開箱即用,例如LangChainLlamaIndexHaystacktxtai

LangChain

# pip install langchain langchain_huggingface
from langchain_huggingface import HuggingFaceEmbeddings

model_name = "sentence-transformers/static-retrieval-mrl-en-v1"
model_kwargs = {'device': 'cpu'} # you can use 'truncate_dim' here
model = HuggingFaceEmbeddings(
    model_name=model_name,
    model_kwargs=model_kwargs,
)

LlamaIndex

# pip install llama-index llama-index-embeddings-huggingface
from llama_index.core import Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

# Set up the HuggingFaceEmbedding class with the required model to use with llamaindex core.
model_name = "sentence-transformers/static-retrieval-mrl-en-v1"
device = "cpu"
embed_model = HuggingFaceEmbedding(
    model_name=model_name,
    device=device,
    # truncate_dim=256, # you can use 'truncate_dim' here
)
Settings.embed_model = embed_model

Haystack

# pip install haystack sentence-transformers
from haystack.components.embedders import (
    SentenceTransformersDocumentEmbedder,
    SentenceTransformersTextEmbedder,
)

model_name = "sentence-transformers/static-retrieval-mrl-en-v1"
device = "cpu"
document_embedder = SentenceTransformersDocumentEmbedder(
    model=model_name,
    device=device,
    # truncate_dim=256, # you can use 'truncate_dim' here
)
text_embedder = SentenceTransformersTextEmbedder(
    model=model_name,
    device=device,
    # truncate_dim=256, # you can use 'truncate_dim' here
)

txtai

# pip install txtai sentence-transformers
from txtai import Embeddings

model_name = "sentence-transformers/static-retrieval-mrl-en-v1"
embeddings = Embeddings(path=model_name)

效能

英語檢索

訓練完成後,我們評估了最終模型sentence-transformers/static-retrieval-mrl-en-v1在NanoBEIR(正常維度和Matryoshka維度)以及BEIR上的效能。

NanoBEIR

我們評估了sentence-transformers/static-retrieval-mrl-en-v1在NanoBEIR上的效能,並將其與我們在硬體上計算的推理速度進行了對比。對於推理速度測試,我們計算了每秒在CPU或GPU上GooAQ資料集計算的查詢嵌入數量。

我們評估了三種類型的模型:

  1. 基於注意力的密集嵌入模型,例如傳統的 Sentence Transformer 模型,如`all-mpnet-base-v2``bge-base-en-v1.5``gte-large-en-v1.5`

  2. 基於靜態嵌入的模型,例如 `static-retrieval-mrl-en-v1`, `potion-base-8M`, `M2V_base_output`, 和 `glove.6B.300d`

  3. 稀疏詞袋模型,BM25,通常是一個強大的基線。

    點選展開BM25實現細節

    我們依賴於高效的bm25s實現,在標記化和使用英文`PyStemmer`進行詞幹提取後,對標記使用`model.get_scores()`。

**注意:**許多基於注意力的密集嵌入模型在(Nano)BEIR 評估資料集的訓練集上進行了微調。這使得模型在此基準測試中具有不公平的優勢,並可能導致實際檢索任務的下游效能降低。

static-retrieval-mrl-en-v1 有意未在這些資料集上進行訓練。

點選檢視下一頁兩張圖表的所有數值
模型 NanoBEIR NDCG@10 CPU (每秒句子數) GPU (每秒句子數)
zeta-alpha-ai/Zeta-Alpha-E5-Mistral 0.6860 0.00* 0.00*
Alibaba-NLP/gte-large-en-v1.5 0.6808 56.01 965.95
Salesforce/SFR-Embedding-Mistral 0.6800 0.00* 0.00*
mixedbread-ai/mxbai-embed-large-v1 0.6567 79.83 1376.80
BAAI/bge-large-en-v1.5 0.6592 80.94 1315.03
intfloat/e5-mistral-7b-instruct 0.6530 0.00* 0.00*
Alibaba-NLP/gte-base-en-v1.5 0.6411 197.85 3142.94
BAAI/bge-base-en-v1.5 0.6376 264.83 4363.04
BAAI/bge-small-en-v1.5 0.6267 888.46 10159.97
nomic-ai/nomic-embed-text-v1.5 0.6179 86.86 2843.03
jinaai/jina-embeddings-v3 0.6174 0.55 3377.56
BAAI/bge-m3 0.6054 80.63 1434.82
sentence-transformers/all-mpnet-base-v2 0.5757 270.40 4043.13
TaylorAI/gte-tiny 0.5692 1752.26 17215.15
sentence-transformers/all-MiniLM-L6-v2 0.5623 1739.31 16942.46
mixedbread-ai/mxbai-embed-xsmall-v1 0.5557 1749.42 16773.76
sentence-transformers/all-MiniLM-L12-v2 0.5533 909.72 9915.69
sentence-transformers/static-retrieval-mrl-en-v1 0.5032 107419.51 97171.47
bm25 0.4518 49706.77 49706.77
minishlab/potion-base-8M 0.4421 124029.91 122384.10
minishlab/potion-base-4M 0.4225 123082.88 123612.54
minishlab/M2V_base_glove 0.4077 142173.77 146154.73
minishlab/M2V_base_glove_subword 0.3914 127426.83 131412.56
minishlab/M2V_base_output 0.3851 84191.93 85738.36
minishlab/potion-base-2M 0.3666 128994.27 122358.16
sentence-transformers/glove.6B.300d 0.3293 76519.74 62782.23
sentence-transformers/glove.840B.300d 0.2899 86348.98 75350.36
  • *:對於 7B LLM,我們沒有進行推理實驗,因為它們的推理速度在圖中將無法區分。
  • 我們進行了實驗以確定每個模型的最佳批處理大小。
GPU

NanoBEIR performance vs inference speed

CPU

NanoBEIR performance vs inference speed

我們可以從這些資料中得出一些顯著結論:

  1. `static-retrieval-mrl-en-v1` 的效能優於所有其他靜態嵌入模型,如 GloVe 或 Model2Vec。
  2. `static-retrieval-mrl-en-v1` 是唯一優於 BM25 的靜態嵌入模型。
  3. `static-retrieval-mrl-en-v1` 的效能:
    • 與常用模型`all-mpnet-base-v2`相比,效能達到**87.4%**,
    • 在GPU上快**24倍**,
    • 在CPU上快**397倍**。
  4. `static-retrieval-mrl-en-v1` 在 CPU 上比在 GPU 上更快:此模型可以在任何地方以極快的速度執行,包括消費級 PC、小型伺服器、手機或瀏覽器中。

Matryoshka 評估

此外,我們透過將輸出嵌入截斷到較低維度,進行了 Matryoshka 式降維,並對 NanoBEIR 效能結果進行了實驗。

NanoBEIR performance vs Matryoshka dimensionality reduction

這些發現表明,例如將維度減少 2 倍,效能僅下降 1.47%(0.5031 NDCG@10 對 0.4957 NDCG@10),而實際上檢索速度卻提高了 2 倍。

多語言相似度

我們還評估了最終的 sentence-transformers/static-similarity-mrl-multilingual-v1 模型在 5 種語言上的表現,這些語言在 MTEB 上有大量基準測試。

我們希望重申,此模型不適用於檢索用例。相反,我們評估的是語義文字相似度 (STS)、分類和對分類。我們與出色的輕量級 multilingual-e5-small 模型進行了比較。

STS, Classification, Pair Classification on MTEB

在所有測試語言中,static-similarity-mrl-multilingual-v1 相對於 multilingual-e5-small 在 STS 上平均達到 92.3%,在對分類上達到 95.52%,在分類上達到 86.52%

Texts per second processed

為了彌補這種效能下降,static-similarity-mrl-multilingual-v1 在 CPU 裝置上比 multilingual-e5-small 快約 125 倍,在 GPU 裝置上快約 10 倍。由於注意力模型的超線性性質,與靜態嵌入模型的線性性質相比,編碼令牌數量的增加將使加速效果更大。

Matryoshka 評估

最後,我們透過將輸出嵌入截斷到較低維度,進行了 Matryoshka 式降維,並對英語 STS 在 MTEB 效能上的影響進行了實驗。

English STS MTEB performance vs Matryoshka dimensionality reduction

如您所見,您可以輕鬆地將維度減少 2 倍或 4 倍,而效能損失很小(0.15% 或 0.56%)。如果您的下游任務的速度或儲存成本是瓶頸,這應該可以幫助您緩解一些擔憂。

結論

這篇部落格文章描述了我們從構思到完成模型的所有步驟,以及關於兩個結果模型(static-retrieval-mrl-en-v1static-similarity-mrl-multilingual-v1)的使用和評估的詳細資訊。

評估結果表明:

  • 基於靜態嵌入的模型可以超過常見基於注意力稠密模型效能的 85%
  • 基於靜態嵌入的模型在 GPU 上比常見的有效替代方案(如 all-mpnet-base-v2multilingual-e5-small)快 10 倍到 25 倍,在 CPU 上快 100 倍到 400 倍。文字越長,這種加速效果就越大。
  • 使用 Matryoshka 損失進行訓練可以顯著保持下游效能

如果您需要一個高效的僅支援 CPU 的稠密嵌入模型來執行檢索或相似性任務,那麼 static-retrieval-mrl-en-v1static-similarity-mrl-multilingual-v1 將是以最小成本提供極其高效的解決方案,並且其效能出人意料地接近基於注意力的稠密模型。

後續步驟

試一試!如果您已經在某個地方使用了 Sentence Transformer 模型,請隨意將其替換為 static-retrieval-mrl-en-v1static-similarity-mrl-multilingual-v1。或者,更好的是:根據您感興趣的任務和語言的代表性資料訓練您自己的模型。

此外,關於訓練好的模型,還有一些問題有待解決。

  1. 由於基於靜態嵌入的模型不受位置嵌入或超線性時間複雜度的瓶頸,因此它們可以具有任意高的最大序列長度。然而,在某個時刻,大數定律可能會“規範化”所有真正長文件的嵌入,使其不再有用。

    需要進行更多的實驗來確定一個好的截斷點。目前,我們將最大序列長度、分塊等留給使用者。

此外,還有一些可能的擴充套件,可能會提高此模型的效能,我們很高興將其留給其他模型作者。我們也歡迎合作。

  1. 困難負樣本挖掘:搜尋相似但不相關的文字以提高訓練資料難度。
  2. 模型聚合:結合以相同方式訓練的多個模型(使用不同種子或資料分佈)的權重。
  3. 課程學習:從難度逐漸增加的示例中進行訓練。
  4. 引導式批內假負樣本過濾:透過高效的預訓練嵌入模型排除假負樣本。
  5. 隨機權重初始化的種子最佳化:用各種種子訓練最初的步驟,以找到一個有用的權重初始化。
  6. 分詞器再訓練:使用現代文字和學習成果對分詞器進行再訓練。
  7. 梯度快取:透過 CachedMultipleNegativesRankingLoss 應用 GradCache 可以實現更大的批次,這通常會帶來更好的效能。
  8. 模型蒸餾:除了僅使用有監督的訓練資料進行訓練外,我們還可以透過更大的嵌入模型輸入無監督資料,並將這些嵌入蒸餾到基於靜態嵌入的學生模型中。

致謝

我要感謝 Stéphan TulkensThomas van DongenThe Minish Lab 團隊,他們透過 Model2Vec 工作讓我關注到了靜態嵌入模型。此外,我還要感謝 Vaibhav SrivastavPedro Cuenca 在這篇部落格文章中的幫助,以及 Antoine Chaffin 提出的釋出檢查點。

最後,非常感謝所有致力於嵌入模型、資料集和開源 Python 包的研究人員。你們為行業增添了力量,我將站在你們的肩膀上。希望有一天,你們也能站在我的肩膀上。

社群

英偉達:你需要購買 GPU 和機器。
開發者:不,我們只需調整演算法即可。
埃隆·馬斯克:看看我買的 GPU,你們這些窮鬼。

非常感謝您的帖子,工作很棒,

我已經訓練了一些英語和西班牙語模型

  • NickyNicky/StaticEmbedding-MatryoshkaLoss-gemma-2-2b-en-es
  • NickyNicky/StaticEmbedding-MatryoshkaLoss-gemma-2-2b-gooaq-en

我想知道如何增加或減少
“max_length 示例 371”

當我檢查“print(model.max_seq_length) # -> Inf”時。

這可能嗎,怎麼做?我找不到相關文件

非常感謝

·
文章作者

你好!

這些模型做得真好!我是否理解正確,其中一個模型在所有資料集上達到了 NanoBEIR 的 0.5623 NDCG@10?這比 static-retrieval-mrl-en-v1 的 0.5032 NDCG@10 提升了很大。

我想知道如何增加或減少
“max_length 示例 371”

你指的是模型卡 這裡 的“max”嗎?
image.png

那只是關於訓練資料的一些近似統計;取自前 1000 個樣本。儘管不建議使用(遠)大於訓練資料的序列長度的文字,但實際的最大序列長度確實是無限的。它在這裡定義:https://github.com/UKPLab/sentence-transformers/blob/cccab8303aaf6e18f069b0da578b3d162bf8442a/sentence_transformers/models/StaticEmbedding.py#L106-L108

簡而言之:模型永遠不會截斷序列,因為該方法

  1. 具有線性複雜度(資料量增加 2 倍 -> 速度慢 2 倍),這與 Transformer 模型(資料量增加 2 倍 -> 速度慢(遠)超過 2 倍)不同。
  2. 不受可能對最大序列長度施加限制的位置嵌入的影響。

所以,靜態模型沒有最大序列長度。它們只是要求使用者注意不要輸入過大的文件,因為所有文件如果足夠長,最終都會嵌入得非常相似。

  • Tom Aarsen

這太酷了!我很驚訝你做得比 model2vec 更好——區別真的只是使用了(更好的)對比損失預訓練公式嗎?

·
文章作者

是的!架構是相同的。事實上,這篇部落格文章中描述的模型所使用的 StaticEmbedding 模組與在 Sentence Transformers 中載入 Model2Vec 模型時使用的模組實際上是相同的。

from sentence_transformers import SentenceTransformer
from sentence_transformers.models import StaticEmbedding
from tokenizers import Tokenizer

# Pre-distilled embeddings:
static_embedding = StaticEmbedding.from_model2vec("minishlab/M2V_base_output")

model = SentenceTransformer(modules=[static_embedding])

embeddings = model.encode(["What are Pandas?", "The giant panda (Ailuropoda melanoleuca; Chinese: 大熊貓; pinyin: dàxióngmāo), also known as the panda bear or simply the panda, is a bear native to south central China."])
similarity = model.similarity(embeddings[0], embeddings[1])
# tensor([[0.9177]]) (If you use the distilled bge-base)

驚人的工作和出色的寫作!

我嘗試在部分資料(AllNLI, GooAQ, MSMacro, PAQ, S2ORC)上進行此訓練,批處理大小為 16384。耗時 5 小時。

w&b
https://api.wandb.ai/links/arunarumugam411-sui/dkcwm6gs

很棒的工作。看起來很酷!

看起來很棒,但是
1
你能闡明其背後的想法嗎?你是為每個詞元計算嵌入然後取平均值嗎?
2
你能分享 NanoBEIR 的連結嗎?

·

你有像論文一樣的詳細描述嗎,拜託

你能解釋一下嗎
來自
https://huggingface.co/blog/Pringled/model2vec
齊普夫
由於我們對空間中的詞元進行了簡單平均,因此正確加權向量非常重要。通常,一個句子轉換器會在給定上下文的情況下為我們正確加權所有詞元,但我們不再擁有這種奢侈。直觀地,我們希望使用類似逆文件頻率 (IDF) 的方法來降低非常頻繁或無趣的詞的權重。但是我們無法訪問一個語料庫來計算文件頻率。

為了克服這個問題,我們選擇使用語言科學中一個眾所周知的原理,即,給定一個按頻率排序的列表,列表中專案的頻率遵循冪律分佈。這被稱為齊普夫定律。因此,如果我們假設詞彙表是按頻率排序的,我們就可以準確地降低非常頻繁專案的權重,而無需訪問實際頻率。由於分詞器詞彙表是按頻率排序的,我們已經可以訪問一個排序列表,因此無需任何額外工作即可應用此最佳化。

所以對於假設的齊普夫輸入
[ [ 0.2,0.5,0.7] , [1.2, 0.9,0.2], [0.4, 0.3, 0.2] ,[1.3, 2.4, 3.2]]

1
根據每個向量範數對輸入進行排序
所以你得到
[ [0.4, 0.3, 0.2] , [ 0.2,0.5,0.7] , [1.2, 0.9,0.2],[1.3, 2.4, 3.2] ]
2
你將每個向量除以它的範數

[ [0.4, 0.3, 0.2]/n1 , [ 0.2,0.5,0.7]/n2 , [1.2, 0.9,0.2] /n3 ,[1.3, 2.4, 3.2]/n4 ]

3
那麼最終的嵌入是這些降權向量的平均值嗎?
( [0.4, 0.3, 0.2]/n1 + [ 0.2,0.5,0.7]/n2 + [1.2, 0.9,0.2] /n3 + [1.3, 2.4, 3.2]/n4) / 4

這是正確的演算法嗎?

太棒了!這項技術真的令人大開眼界^^

不過,對帖子標題提個小建議,一開始,我以為這篇帖子是關於更快地訓練句子嵌入模型,而不是關於訓練推理時間更快的句子嵌入模型。只是讓你們知道。

這是一種很棒的方法!

我透過整合大量日語資料集,訓練了一個靜態嵌入日語模型(static-embedding-japanese),當我們在日語多語言文字嵌入基準(JMTEB)上進行比較時,我能夠獲得僅略低於 mE5-small 的分數。

JMTEB 結果

模型 平均(微觀) 檢索 STS 分類 重排 聚類 對分類
文字嵌入-3-小型 69.18 66.39 79.46 73.06 92.92 51.06 62.27
多語言-e5-小型 67.71 67.27 80.07 67.62 93.03 46.91 62.19
靜態嵌入-日語 67.17 67.92 80.16 67.96 91.87 40.39 62.37

感謝您發表如此優秀的文章。

·
文章作者

這表現太棒了,工作出色!我也很感謝你非常詳細的模型卡——我現在就用翻譯器讀一下!

隨著一些研究方向轉向無分詞器模型,不知道字元級相似度訓練模型能達到多遠。

·

還在想,考慮到我們現在擁有的額外計算和時間,與其他靜態嵌入模型進行整合是否會帶來額外的質量改進。但是,最好的整合方式可能是什麼,也許特定領域資料集會有幫助,不同損失訓練的模型呢?

有沒有計劃讓模型可用於文字嵌入推理?

https://github.com/huggingface/text-embeddings-inference

註冊登入 以發表評論

© . This site is unofficial and not affiliated with Hugging Face, Inc.