Datasets 文件
搜尋索引
加入 Hugging Face 社群
並獲得增強的文件體驗
開始使用
搜尋索引
FAISS 和 Elasticsearch 可以在資料集中搜索樣本。當您希望從資料集中檢索與您的 NLP 任務相關的特定樣本時,這會很有用。例如,如果您正在處理開放域問答任務,您可能只想返回與回答您的問題相關的樣本。
本指南將向您展示如何為您的資料集構建索引,以便您可以進行搜尋。
FAISS
FAISS 根據文件向量表示的相似性來檢索文件。在此示例中,您將使用 DPR 模型生成向量表示。
- 從 🤗 Transformers 下載 DPR 模型
>>> from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
>>> import torch
>>> torch.set_grad_enabled(False)
>>> ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
>>> ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
- 載入您的資料集並計算向量表示
>>> from datasets import load_dataset
>>> ds = load_dataset('crime_and_punish', split='train[:100]')
>>> ds_with_embeddings = ds.map(lambda example: {'embeddings': ctx_encoder(**ctx_tokenizer(example["line"], return_tensors="pt"))[0][0].numpy()})
- 使用 Dataset.add_faiss_index() 建立索引
>>> ds_with_embeddings.add_faiss_index(column='embeddings')
- 現在,您可以使用
embeddings
索引來查詢您的資料集。載入 DPR Question Encoder,並使用 Dataset.get_nearest_examples() 搜尋問題
>>> from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
>>> q_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
>>> q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
>>> question = "Is it serious ?"
>>> question_embedding = q_encoder(**q_tokenizer(question, return_tensors="pt"))[0][0].numpy()
>>> scores, retrieved_examples = ds_with_embeddings.get_nearest_examples('embeddings', question_embedding, k=10)
>>> retrieved_examples["line"][0]
'_that_ serious? It is not serious at all. It’s simply a fantasy to amuse\r\n'
- 您可以使用 Dataset.get_index() 訪問索引並將其用於特殊操作,例如使用
range_search
進行查詢
>>> faiss_index = ds_with_embeddings.get_index('embeddings').faiss_index
>>> limits, distances, indices = faiss_index.range_search(x=question_embedding.reshape(1, -1), thresh=0.95)
- 當您完成查詢後,使用 Dataset.save_faiss_index() 將索引儲存到磁碟
>>> ds_with_embeddings.save_faiss_index('embeddings', 'my_index.faiss')
- 稍後使用 Dataset.load_faiss_index() 重新載入它
>>> ds = load_dataset('crime_and_punish', split='train[:100]')
>>> ds.load_faiss_index('embeddings', 'my_index.faiss')
Elasticsearch
與 FAISS 不同,Elasticsearch 基於精確匹配來檢索文件。
在您的機器上啟動 Elasticsearch,或者如果您尚未安裝,請參閱 Elasticsearch 安裝指南。
- 載入您要索引的資料集
>>> from datasets import load_dataset
>>> squad = load_dataset('rajpurkar/squad', split='validation')
>>> squad.add_elasticsearch_index("context", host="localhost", port="9200")
- 然後您可以使用 Dataset.get_nearest_examples() 查詢
context
索引
>>> query = "machine"
>>> scores, retrieved_examples = squad.get_nearest_examples("context", query, k=10)
>>> retrieved_examples["title"][0]
'Computational_complexity_theory'
- 如果您想重用索引,請在構建索引時定義
es_index_name
引數
>>> from datasets import load_dataset
>>> squad = load_dataset('rajpurkar/squad', split='validation')
>>> squad.add_elasticsearch_index("context", host="localhost", port="9200", es_index_name="hf_squad_val_context")
>>> squad.get_index("context").es_index_name
hf_squad_val_context
- 稍後在呼叫 Dataset.load_elasticsearch_index() 時使用索引名稱重新載入它
>>> from datasets import load_dataset
>>> squad = load_dataset('rajpurkar/squad', split='validation')
>>> squad.load_elasticsearch_index("context", host="localhost", port="9200", es_index_name="hf_squad_val_context")
>>> query = "machine"
>>> scores, retrieved_examples = squad.get_nearest_examples("context", query, k=10)
對於更高階的 Elasticsearch 用法,您可以使用自定義設定指定您自己的配置
>>> import elasticsearch as es
>>> import elasticsearch.helpers
>>> from elasticsearch import Elasticsearch
>>> es_client = Elasticsearch([{"host": "localhost", "port": "9200"}]) # default client
>>> es_config = {
... "settings": {
... "number_of_shards": 1,
... "analysis": {"analyzer": {"stop_standard": {"type": "standard", " stopwords": "_english_"}}},
... },
... "mappings": {"properties": {"text": {"type": "text", "analyzer": "standard", "similarity": "BM25"}}},
... } # default config
>>> es_index_name = "hf_squad_context" # name of the index in Elasticsearch
>>> squad.add_elasticsearch_index("context", es_client=es_client, es_config=es_config, es_index_name=es_index_name)