Implementing semantic cache to improve a RAG system with FAISS.


Author: Pere Martra

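In this notebook, we explore a typical RAG solution that uses an open-source model and the Chroma DB vector database. On top of it, we integrate a semantic cache system that stores user queries and decides whether to build the information-enriched prompt from the vector database or from the cache.

A semantic cache system aims to identify similar or identical user requests. When a matching request is found, the system retrieves the corresponding information from the cache, reducing the need to fetch it from the original source.

Because the comparison takes the semantic meaning of the requests into account, the system can recognize them as the same question even when they are not identical. They may be worded differently or contain inaccuracies, whether typos or errors in sentence structure, and we can still identify that the user is asking for the same information.

For example, queries such as "What is the capital of France?", "Tell me the name of the capital of France?", and "What is the capital of France" convey the same intent and should be identified as the same question.

Although the model's response may differ because of the request for a concise answer in the second example, the information retrieved from the vector database should be the same. That is why I place the cache system between the user and the vector database, not between the user and the large language model.

Most tutorials that guide you through building a RAG system are designed for single-user use and are meant to run in a test environment: in a notebook, interacting with a local vector database, and making API calls or using a locally stored model.

This architecture quickly becomes insufficient when you try to move one of these systems into production, where it may face anywhere from tens to thousands of recurring requests.

One way to improve performance is to use one or more semantic caches. The cache retains the results of previous requests, and before resolving a new request it checks whether a similar one has been received before. If so, it retrieves the information from the cache instead of running the whole process again.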
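The check-the-cache-first flow described above can be sketched in a few lines of plain Python. This is a toy illustration, not the implementation built later in this notebook: `TinySemanticCache`, `toy_embed`, and the threshold value are hypothetical, and the keyword-counting "encoder" merely stands in for a real sentence embedding model.

```python
from typing import Callable, List, Tuple

Vector = List[float]


def squared_l2(a: Vector, b: Vector) -> float:
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def toy_embed(text: str) -> Vector:
    """Hypothetical stand-in for a sentence encoder: counts a few keywords."""
    vocabulary = ["capital", "france", "vaccines", "work"]
    words = text.lower().replace("?", "").split()
    return [float(words.count(word)) for word in vocabulary]


class TinySemanticCache:
    """Check the cache first; run the expensive step only on a miss."""

    def __init__(self, embed: Callable[[str], Vector], threshold: float):
        self.embed = embed
        self.threshold = threshold
        self.entries: List[Tuple[Vector, str]] = []

    def get_or_compute(self, question: str, compute: Callable[[str], str]) -> Tuple[str, bool]:
        """Return (result, served_from_cache)."""
        query = self.embed(question)
        if self.entries:
            # Find the closest previously seen request
            vector, result = min(self.entries, key=lambda entry: squared_l2(query, entry[0]))
            if squared_l2(query, vector) <= self.threshold:
                return result, True
        # Cache miss: run the expensive step and remember its result
        result = compute(question)
        self.entries.append((query, result))
        return result, False


cache = TinySemanticCache(toy_embed, threshold=0.5)
expensive_lookup = lambda question: f"context for: {question}"
print(cache.get_or_compute("What is the capital of France?", expensive_lookup))
print(cache.get_or_compute("Tell me the name of the capital of France", expensive_lookup))
```

The second call is answered from the cache because both questions map to the same toy vector, so the expensive lookup runs only once.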

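In a RAG system, there are two time-consuming steps:

  • Retrieving the information used to build the enriched prompt.
  • Calling the large language model to obtain the response.

A semantic cache can be implemented at either point, and we could even have two caches, one for each.

Placing it at the model's response point can reduce the influence the user has over the response obtained. Our cache system could treat "Explain the French Revolution in 10 words" and "Explain the French Revolution in a hundred words" as the same query. If the cache stored model responses, users might feel that their instructions were not being followed accurately.

Both requests, however, need the same information to enrich the prompt. This is the main reason I chose to place the semantic cache between the user's request and the retrieval of information from the vector database.

This is a design decision, though. Depending on the type of responses and system requests, the cache can be placed at one point or the other. Caching model responses clearly saves the most time, but as explained above, it comes at the cost of the user's influence over the response.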

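Import and load the libraries.

To start, we need to install the necessary Python packages.

  • sentence-transformers. This library is needed to transform sentences into fixed-length vectors, also known as embeddings.
  • xformers. A package that provides libraries and utilities to facilitate working with transformers models. We install it to avoid errors when working with the model and embeddings.
  • chromadb. This is our vector database. ChromaDB is easy to use and open source, and is one of the most widely used vector databases for storing embeddings.
  • accelerate. Needed to run the model on a GPU.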
!pip install -q transformers==4.38.1
!pip install -q accelerate==0.27.2
!pip install -q sentence-transformers==2.5.1
!pip install -q xformers==0.0.24
!pip install -q chromadb==0.4.24
!pip install -q datasets==2.17.1
import numpy as np
import pandas as pd

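Load the dataset

Because we are working in a free and limited environment with only a few GB of memory, I limit the number of dataset rows used with the variable `MAX_ROWS`.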

# Log in to Hugging Face. This is mandatory to use the Gemma model,
# and recommended to access public models and datasets.
from getpass import getpass
if 'hf_key' not in locals():
  hf_key = getpass("Your Hugging Face API Key: ")
!huggingface-cli login --token $hf_key
from datasets import load_dataset

data = load_dataset("keivalya/MedQuad-MedicalQnADataset", split="train")

ChromaDB requires the data to have a unique identifier. The statement below provides one by creating a new column called `id`.

data = data.to_pandas()
data["id"] = data.index
data.head(10)
MAX_ROWS = 15000
DOCUMENT = "Answer"
TOPIC = "qtype"
# Because it is just a sample, we select a small portion of the dataset.
subset_data = data.head(MAX_ROWS)

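Import and configure the vector database

To store the information, I chose ChromaDB, one of the best-known and most widely used open-source vector databases.

First we need to import ChromaDB.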

import chromadb

Now we only need to specify the path where the vector database will be stored.

chroma_client = chromadb.PersistentClient(path="/path/to/persist/directory")

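Populate and query the ChromaDB database

Data in ChromaDB is stored in collections. If the collection already exists, we need to delete it.

In the next lines, we create the collection by calling the `create_collection` function of the `chroma_client` created above.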

collection_name = "news_collection"
if collection_name in [c.name for c in chroma_client.list_collections()]:
    chroma_client.delete_collection(name=collection_name)

collection = chroma_client.create_collection(name=collection_name)

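We are now ready to add the data to the collection with the `add` function. It needs three key pieces of information:

  • In documents, we store the content of the `Answer` column of the dataset.
  • In metadatas, we can provide a list of topics. I used the values of the `qtype` column.
  • In ids, we need to provide a unique identifier for each row. I create the IDs from the range of `MAX_ROWS`.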
collection.add(
    documents=subset_data[DOCUMENT].tolist(),
    metadatas=[{TOPIC: topic} for topic in subset_data[TOPIC].tolist()],
    ids=[f"id{x}" for x in range(MAX_ROWS)],
)
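The three lists passed to `add` need to have the same length. A minimal sketch, assuming plain Python lists as input, that derives the IDs from the number of documents instead of from `MAX_ROWS`, so the lengths still match if the dataset holds fewer rows than expected. The helper name `build_add_payload` is hypothetical.

```python
from typing import Dict, List


def build_add_payload(documents: List[str], topics: List[str], topic_key: str = "qtype") -> Dict[str, list]:
    """Build equally sized `documents`, `metadatas`, and `ids` lists."""
    if len(documents) != len(topics):
        raise ValueError("documents and topics must have the same length")
    return {
        "documents": documents,
        "metadatas": [{topic_key: topic} for topic in topics],
        # One ID per document, in the same "id<N>" style used above
        "ids": [f"id{position}" for position in range(len(documents))],
    }


payload = build_add_payload(["first answer", "second answer"], ["symptoms", "treatment"])
print(payload["ids"])
```

The resulting dictionary could then be unpacked into the call, for example `collection.add(**payload)`.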

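Once the information is in the database, we can query it and ask for data that matches our needs. The search is performed over the content of the documents, and it does not look for exact words or phrases. The results are based on the similarity between the search terms and the document content.

Metadata is not directly involved in the initial search, but it can be used to filter or refine the results after retrieval, allowing further customization and precision.

Let's define a function to query the ChromaDB database.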

def query_database(query_text, n_results=10):
    results = collection.query(query_texts=query_text, n_results=n_results)
    return results
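The metadata stored with each document can also narrow a query, because Chroma's `query` method accepts a `where` argument. A minimal sketch, assuming the `qtype` metadata key used when the collection was populated; the helper name `query_database_by_topic` and the topic value in the usage comment are illustrative assumptions, and the collection is passed in explicitly so the function does not depend on a global.

```python
def query_database_by_topic(collection, query_text, topic, n_results=10, topic_key="qtype"):
    """Query `collection`, keeping only documents whose metadata matches `topic`."""
    return collection.query(
        query_texts=query_text,
        n_results=n_results,
        # Metadata filter: only documents tagged with this topic are considered
        where={topic_key: topic},
    )


# Example call (the topic value is an assumption about the dataset's labels):
# results = query_database_by_topic(collection, ["What are the signs of flu?"], "symptoms", 3)
```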

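Create the semantic cache system

To implement the cache system, we use Faiss, a library that stores embeddings in memory. It is quite similar to what Chroma does, but without its persistence.

For this purpose, we create a class called `semantic_cache` that works with its own encoder and provides the functions the user needs to run queries.

In this class, we first query the cache implemented with Faiss, which holds the previous requests. If the closest cached request is within a specified distance threshold, the cached content is returned. Otherwise, the result is fetched from the Chroma database.

The cache is stored in a .json file.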

!pip install -q faiss-cpu==1.8.0
import faiss
from sentence_transformers import SentenceTransformer
import time
import json

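The `init_cache()` function below initializes the semantic cache.

It uses the FlatL2 index, which might not be the fastest but is ideal for small datasets. Depending on the characteristics of the cached data and the expected dataset size, another index such as HNSW or IVF could be used.

I chose this index because it fits this example well. It can be used with high-dimensional vectors, uses little memory at this scale, and performs well on small datasets.

I outline the main characteristics of the various indexes available in Faiss.

  • FlatL2 or FlatIP. Well suited to small datasets. It may not be the fastest, but its memory consumption is modest.
  • LSH. Works well on small datasets and is recommended for vectors of up to 128 dimensions.
  • HNSW. Very fast, but requires a large amount of RAM.
  • IVF. Works well with large datasets without consuming much memory or compromising performance.

More information about the different indexes available in Faiss can be found at this link: https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index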
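A flat index performs an exhaustive search: it compares the query with every stored vector and keeps the closest ones. The sketch below shows that idea in plain Python. It is an illustration rather than Faiss code, and `flat_l2_search` is a hypothetical name. Like `IndexFlatL2`, it reports squared L2 distances.

```python
from typing import List, Tuple


def flat_l2_search(stored: List[List[float]], query: List[float], k: int = 1) -> List[Tuple[float, int]]:
    """Exhaustively compare `query` with every stored vector.

    Returns up to `k` (squared_distance, row_id) pairs, nearest first.
    """
    scored = [
        (sum((q - s) ** 2 for q, s in zip(query, vector)), row_id)
        for row_id, vector in enumerate(stored)
    ]
    return sorted(scored)[:k]


stored_vectors = [[0.0, 0.0], [1.0, 1.0], [3.0, 4.0]]
print(flat_l2_search(stored_vectors, [1.0, 2.0], k=2))
```

The cost grows linearly with the number of stored vectors, which is why this kind of index suits a small cache and why approximate indexes such as HNSW or IVF become attractive for larger ones.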

def init_cache():
    index = faiss.IndexFlatL2(768)
    if index.is_trained:
        print("Index trained")

    # Initialize Sentence Transformer model
    encoder = SentenceTransformer("all-mpnet-base-v2")

    return index, encoder

In the `retrieve_cache` function, the .json file is loaded from disk in case the cache needs to be reused across sessions.

def retrieve_cache(json_file):
    try:
        with open(json_file, "r") as file:
            cache = json.load(file)
    except FileNotFoundError:
        cache = {"questions": [], "embeddings": [], "answers": [], "response_text": []}

    return cache

The `store_cache` function saves the file containing the cache data to disk.

def store_cache(json_file, cache):
    with open(json_file, "w") as file:
        json.dump(cache, file)

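These functions are used in the `semantic_cache` class, which includes the search function and its initialization function.

Although the `ask` function contains a substantial amount of code, its purpose is straightforward. It looks in the cache for the question closest to the one the user has just asked.

It then checks whether that question is within the specified threshold. If so, it returns the response directly from the cache; otherwise, it calls the `query_database` function to retrieve the data from ChromaDB.

I used Euclidean distance instead of cosine distance, which is widely used in vector comparisons. The choice is based on Euclidean distance being the default metric used by Faiss. Cosine distance can also be computed, but doing so adds complexity that may not contribute significantly to the final result.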
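For unit-length vectors the two measures are closely tied: the squared Euclidean distance equals 2 minus twice the cosine similarity, so both rank neighbors in the same order. The sketch below checks that relation on a small example in plain Python. Whether the encoder's embeddings are unit length is an assumption to verify for the model in use.

```python
import math
from typing import List


def normalize(vector: List[float]) -> List[float]:
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector]


def squared_l2(a: List[float], b: List[float]) -> float:
    """Squared Euclidean distance."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Cosine of the angle between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


a = normalize([3.0, 4.0])
b = normalize([4.0, 3.0])
# For unit vectors: squared L2 distance == 2 - 2 * cosine similarity
print(squared_l2(a, b), 2 - 2 * cosine_similarity(a, b))
```

Under that assumption, and given that Faiss reports squared distances for `IndexFlatL2`, a threshold of 0.35 corresponds to a cosine similarity of at least 0.825.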

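I have included a FIFO eviction policy in the `semantic_cache` class to improve its efficiency and flexibility. With an eviction policy, the user can control how the cache behaves when it reaches its maximum capacity. This is important for maintaining cache performance and for handling situations where available memory is constrained.

Given the structure of the cache, implementing FIFO is straightforward. Whenever a new question-answer pair is added to the cache, it is appended to the end of the lists, so the oldest (first in) items sit at the front. When the cache reaches its maximum size and an item must be evicted, the first item is removed (popped) from each list. This is the FIFO eviction policy.

Another eviction policy is Least Recently Used (LRU), which is more complex because it requires knowing when each item in the cache was last accessed. This policy is not available yet and will be implemented later.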
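To give an idea of what LRU involves, here is a minimal sketch built on `collections.OrderedDict`. It is not part of the `semantic_cache` class and not the author's planned implementation: the `LRUStore` name and its exact-match keys are assumptions made for illustration.

```python
from collections import OrderedDict
from typing import Optional


class LRUStore:
    """Hypothetical LRU container keyed by the cached question text."""

    def __init__(self, max_items: int):
        self.max_items = max_items
        self.items: "OrderedDict[str, str]" = OrderedDict()

    def get(self, question: str) -> Optional[str]:
        """Return the cached response and mark it as most recently used."""
        if question not in self.items:
            return None
        self.items.move_to_end(question)
        return self.items[question]

    def put(self, question: str, response: str) -> None:
        """Store a response, evicting the least recently used entries if full."""
        self.items[question] = response
        self.items.move_to_end(question)
        while len(self.items) > self.max_items:
            self.items.popitem(last=False)  # the front holds the least recently used entry


store = LRUStore(max_items=2)
store.put("q1", "a1")
store.put("q2", "a2")
store.get("q1")        # q1 becomes the most recently used entry
store.put("q3", "a3")  # q2 is evicted, not q1
print(list(store.items))
```

In a semantic cache, an access would be recorded whenever a nearest-neighbor search resolves to an entry, and the vector index would also have to be kept in step with the evictions.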

class semantic_cache:
    def __init__(self, json_file="cache_file.json", thresold=0.35, max_response=100, eviction_policy=None):
        """Initializes the semantic cache.

        Args:
        json_file (str): The name of the JSON file where the cache is stored.
        thresold (float): The threshold for the Euclidean distance to determine if a question is similar.
        max_response (int): The maximum number of responses the cache can store.
        eviction_policy (str): The policy for evicting items from the cache.
                                This can be any policy, but 'FIFO' (First In First Out) has been implemented for now.
                                If None, no eviction policy will be applied.
        """

        # Initialize Faiss index with Euclidean distance
        self.index, self.encoder = init_cache()

        # Set Euclidean distance threshold.
        # A distance of 0 means identical sentences.
        # Only cached entries at or below this threshold are returned.
        self.euclidean_threshold = thresold

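        self.json_file = json_file
        self.cache = retrieve_cache(self.json_file)
        self.max_response = max_response
        self.eviction_policy = eviction_policy

        # Re-add embeddings loaded from disk so index rows match cache rows
        if self.cache["embeddings"]:
            self.index.add(np.array(self.cache["embeddings"], dtype="float32"))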

    def evict(self):
        """Evicts the oldest items once the cache exceeds max_response (FIFO only)."""
        if self.eviction_policy == "FIFO" and len(self.cache["questions"]) > self.max_response:
            excess = len(self.cache["questions"]) - self.max_response
            for key in ("questions", "embeddings", "answers", "response_text"):
                del self.cache[key][:excess]
            # Rebuild the Faiss index so its row ids match the trimmed lists
            self.index.reset()
            if self.cache["embeddings"]:
                self.index.add(np.array(self.cache["embeddings"], dtype="float32"))

    def ask(self, question: str) -> str:
        # Method to retrieve an answer from the cache or generate a new one
        start_time = time.time()
        try:
            # First we obtain the embeddings corresponding to the user question
            embedding = self.encoder.encode([question])

            # Search for the nearest neighbor in the index
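            # IndexFlatL2 searches exhaustively, so nprobe (an IVF setting) is not needed.
            # Faiss returns squared L2 distances, and index -1 when nothing is found.
            D, I = self.index.search(embedding, 1)

            if D[0][0] >= 0: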
                if I[0][0] >= 0 and D[0][0] <= self.euclidean_threshold:
                    row_id = int(I[0][0])

                    print("Answer recovered from Cache. ")
                    print(f"{D[0][0]:.3f} smaller than {self.euclidean_threshold}")
                    print(f"Found cache in row: {row_id} with score {D[0][0]:.3f}")
                    print(f"response_text: " + self.cache["response_text"][row_id])

                    end_time = time.time()
                    elapsed_time = end_time - start_time
                    print(f"Time taken: {elapsed_time:.3f} seconds")
                    return self.cache["response_text"][row_id]

            # Handle the case when there are not enough results
            # or Euclidean distance is not met, asking to chromaDB.
            answer = query_database([question], 1)
            response_text = answer["documents"][0][0]

            self.cache["questions"].append(question)
            self.cache["embeddings"].append(embedding[0].tolist())
            self.cache["answers"].append(answer)
            self.cache["response_text"].append(response_text)

            print("Answer recovered from ChromaDB. ")
            print(f"response_text: {response_text}")

            self.index.add(embedding)

            self.evict()

            store_cache(self.json_file, self.cache)

            end_time = time.time()
            elapsed_time = end_time - start_time
            print(f"Time taken: {elapsed_time:.3f} seconds")

            return response_text
        except Exception as e:
            raise RuntimeError(f"Error during 'ask' method: {e}")

Test the semantic_cache class.

>>> # Initialize the cache.
>>> cache = semantic_cache("4cache.json")
Index trained
>>> results = cache.ask("How do vaccines work?")
Answer recovered from ChromaDB. 
response_text: Summary : Shots may hurt a little, but the diseases they can prevent are a lot worse. Some are even life-threatening. Immunization shots, or vaccinations, are essential. They protect against things like measles, mumps, rubella, hepatitis B, polio, tetanus, diphtheria, and pertussis (whooping cough). Immunizations are important for adults as well as children.    Your immune system helps your body fight germs by producing substances to combat them. Once it does, the immune system "remembers" the germ and can fight it again. Vaccines contain germs that have been killed or weakened. When given to a healthy person, the vaccine triggers the immune system to respond and thus build immunity.     Before vaccines, people became immune only by actually getting a disease and surviving it. Immunizations are an easier and less risky way to become immune.     NIH: National Institute of Allergy and Infectious Diseases
Time taken: 0.057 seconds

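As expected, this response was obtained from ChromaDB. The class then stores it in the cache.

Now, if we send a second, quite different question, the response should also be retrieved from ChromaDB. The previously stored question is so different from the new one that their Euclidean distance exceeds the specified threshold.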

>>> results = cache.ask("Explain briefly what is a Sydenham chorea")
Answer recovered from ChromaDB. 
response_text: Sydenham chorea (SD) is a neurological disorder of childhood resulting from infection via Group A beta-hemolytic streptococcus (GABHS), the bacterium that causes rheumatic fever. SD is characterized by rapid, irregular, and aimless involuntary movements of the arms and legs, trunk, and facial muscles. It affects girls more often than boys and typically occurs between 5 and 15 years of age. Some children will have a sore throat several weeks before the symptoms begin, but the disorder can also strike up to 6 months after the fever or infection has cleared. Symptoms can appear gradually or all at once, and also may include uncoordinated movements, muscular weakness, stumbling and falling, slurred speech, difficulty concentrating and writing, and emotional instability. The symptoms of SD can vary from a halting gait and slight grimacing to involuntary movements that are frequent and severe enough to be incapacitating. The random, writhing movements of chorea are caused by an auto-immune reaction to the bacterium that interferes with the normal function of a part of the brain (the basal ganglia) that controls motor movements. Due to better sanitary conditions and the use of antibiotics to treat streptococcal infections, rheumatic fever, and consequently SD, are rare in North America and Europe. The disease can still be found in developing nations.
Time taken: 0.082 seconds

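Perfect, the semantic cache system is behaving as expected.

Let's continue by testing a question very similar to the one we just asked.

In this case, the response should come directly from the cache, without accessing the ChromaDB database.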

>>> results = cache.ask("Briefly explain me what is a Sydenham chorea.")
Answer recovered from Cache. 
0.028 smaller than 0.35
Found cache in row: 1 with score 0.028
response_text: Sydenham chorea (SD) is a neurological disorder of childhood resulting from infection via Group A beta-hemolytic streptococcus (GABHS), the bacterium that causes rheumatic fever. SD is characterized by rapid, irregular, and aimless involuntary movements of the arms and legs, trunk, and facial muscles. It affects girls more often than boys and typically occurs between 5 and 15 years of age. Some children will have a sore throat several weeks before the symptoms begin, but the disorder can also strike up to 6 months after the fever or infection has cleared. Symptoms can appear gradually or all at once, and also may include uncoordinated movements, muscular weakness, stumbling and falling, slurred speech, difficulty concentrating and writing, and emotional instability. The symptoms of SD can vary from a halting gait and slight grimacing to involuntary movements that are frequent and severe enough to be incapacitating. The random, writhing movements of chorea are caused by an auto-immune reaction to the bacterium that interferes with the normal function of a part of the brain (the basal ganglia) that controls motor movements. Due to better sanitary conditions and the use of antibiotics to treat streptococcal infections, rheumatic fever, and consequently SD, are rare in North America and Europe. The disease can still be found in developing nations.
Time taken: 0.019 seconds

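The two questions are so similar that their Euclidean distance is very small, almost as if they were identical.

Now let's try another question, this time a bit more different, and observe how the system behaves.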

>>> question_def = "Write in 20 words what is a Sydenham chorea."
>>> results = cache.ask(question_def)
Answer recovered from Cache. 
0.228 smaller than 0.35
Found cache in row: 1 with score 0.228
response_text: Sydenham chorea (SD) is a neurological disorder of childhood resulting from infection via Group A beta-hemolytic streptococcus (GABHS), the bacterium that causes rheumatic fever. SD is characterized by rapid, irregular, and aimless involuntary movements of the arms and legs, trunk, and facial muscles. It affects girls more often than boys and typically occurs between 5 and 15 years of age. Some children will have a sore throat several weeks before the symptoms begin, but the disorder can also strike up to 6 months after the fever or infection has cleared. Symptoms can appear gradually or all at once, and also may include uncoordinated movements, muscular weakness, stumbling and falling, slurred speech, difficulty concentrating and writing, and emotional instability. The symptoms of SD can vary from a halting gait and slight grimacing to involuntary movements that are frequent and severe enough to be incapacitating. The random, writhing movements of chorea are caused by an auto-immune reaction to the bacterium that interferes with the normal function of a part of the brain (the basal ganglia) that controls motor movements. Due to better sanitary conditions and the use of antibiotics to treat streptococcal infections, rheumatic fever, and consequently SD, are rare in North America and Europe. The disease can still be found in developing nations.
Time taken: 0.016 seconds

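We observe that the Euclidean distance has increased, but it is still within the specified threshold. The response is therefore still returned directly from the cache.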
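The threshold decides how much rewording the cache tolerates. The sketch below reuses the two distances printed in the runs above and compares the default threshold of 0.35 with a stricter, hypothetical value of 0.10; `cache_hits` is an illustrative helper, not part of the class.

```python
# Distances printed by the two cache hits above
observed_distances = {
    "Briefly explain me what is a Sydenham chorea.": 0.028,
    "Write in 20 words what is a Sydenham chorea.": 0.228,
}


def cache_hits(distances, threshold):
    """Return the questions whose distance is within the threshold."""
    return [question for question, distance in distances.items() if distance <= threshold]


print(cache_hits(observed_distances, threshold=0.35))  # both questions
print(cache_hits(observed_distances, threshold=0.10))  # only the close paraphrase
```

With the stricter value, the "20 words" question would have gone to ChromaDB instead of the cache. A lower threshold trades cache hits for a smaller risk of returning context for a different question.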

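Load the model and create the prompt

Time to use the transformers library, Hugging Face's best-known library for working with language models.

We are importing:

  • AutoTokenizer: a utility class for tokenizing text inputs that is compatible with various pre-trained language models.
  • AutoModelForCausalLM: provides an interface to pre-trained language models designed for language generation tasks using causal language modeling (for example, GPT models), or the model used in this notebook, Gemma-2b-it.

Feel free to test different models; look for NLP models trained for text generation.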

!pip install torch
import torch
from torch import cuda

# On Apple Silicon (Mac) the device must be 'mps'
# device = torch.device('mps')  # to use with Apple Silicon
device = f"cuda:{cuda.current_device()}" if cuda.is_available() else "cpu"
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "google/gemma-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, torch_dtype=torch.bfloat16)

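Create the extended prompt

To create the prompt, we use the result of querying the `semantic_cache` class and the question entered by the user.

The prompt has two parts: the relevant context, which is the information recovered from the database, and the user's question.

We only need to put the two parts together to create the prompt and then send it to the model.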

prompt_template = f"Relevant context: {results}\n\n The user's question: {question_def}"
prompt_template
input_ids = tokenizer(prompt_template, return_tensors="pt").to(device)
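When several questions are processed, it can help to wrap the prompt assembly in a small function. A minimal sketch that reproduces the same two-part layout; the name `build_prompt` is hypothetical.

```python
def build_prompt(context: str, question: str) -> str:
    """Combine retrieved context and the user's question into one prompt."""
    return f"Relevant context: {context}\n\n The user's question: {question}"


print(build_prompt("Paris is the capital of France.", "What is the capital of France?"))
```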

Now all that remains is to send the prompt to the model and wait for its response!

>>> outputs = model.generate(**input_ids, max_new_tokens=256)
>>> print(tokenizer.decode(outputs[0]))
Relevant context: Sydenham chorea (SD) is a neurological disorder of childhood resulting from infection via Group A beta-hemolytic streptococcus (GABHS), the bacterium that causes rheumatic fever. SD is characterized by rapid, irregular, and aimless involuntary movements of the arms and legs, trunk, and facial muscles. It affects girls more often than boys and typically occurs between 5 and 15 years of age. Some children will have a sore throat several weeks before the symptoms begin, but the disorder can also strike up to 6 months after the fever or infection has cleared. Symptoms can appear gradually or all at once, and also may include uncoordinated movements, muscular weakness, stumbling and falling, slurred speech, difficulty concentrating and writing, and emotional instability. The symptoms of SD can vary from a halting gait and slight grimacing to involuntary movements that are frequent and severe enough to be incapacitating. The random, writhing movements of chorea are caused by an auto-immune reaction to the bacterium that interferes with the normal function of a part of the brain (the basal ganglia) that controls motor movements. Due to better sanitary conditions and the use of antibiotics to treat streptococcal infections, rheumatic fever, and consequently SD, are rare in North America and Europe. The disease can still be found in developing nations.

 The user's question: Write in 20 words what is a Sydenham chorea.

Sure, here is a 20-word answer:

Sydenham chorea is a neurological disorder of childhood resulting from infection via Group A beta-hemolytic streptococcus (GABHS).

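Conclusion.

In the runs above, a lookup took 0.057 to 0.082 seconds when it went to ChromaDB and 0.016 to 0.019 seconds when it was served from the cache, a reduction of more than 50% in data retrieval time. In larger projects the difference can grow, leading to improvements of 90-95%.

We have very little data in Chroma and only one instance of the cache class. Typically, the data behind the cache system is much larger, and retrieving it may involve more than querying a vector database, drawing on several sources.

It is common to have multiple instances of the cache class, usually split by user type, since questions repeat more often among users who share common characteristics.

In summary, we have built a very simple RAG (Retrieval-Augmented Generation) system and enhanced it with a semantic cache layer between the user's question and the retrieval of the information needed to build the enriched prompt.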
