開源 AI 食譜文件
由 SQL 和 Jina Reranker v2 支援的 RAG
並獲得增強的文件體驗
開始使用
由 SQL 和 Jina Reranker v2 支援的 RAG
作者:Scott Martens @ Jina AI
本筆記本將向您展示如何製作一個簡單的檢索增強生成 (RAG) 系統,該系統從 SQL 資料庫而不是從文件庫中提取資訊。
工作原理
- 給定一個 SQL 資料庫,我們提取 SQL 表定義 (SQL 轉儲中的 `CREATE` 行) 並存儲它們。在本教程中,我們已經為您完成了這一部分,定義以列表形式儲存在記憶體中。從本示例擴充套件可能需要更復雜的儲存方式。
- 使用者以自然語言輸入一個查詢。
- Jina Reranker v2 (`jinaai/jina-reranker-v2-base-multilingual`),一個來自 Jina AI 的支援 SQL 的重排模型,會根據與使用者查詢的相關性對錶定義進行排序。
- 我們向 Mistral 7B Instruct v0.1 (`mistralai/Mistral-7B-Instruct-v0.1`) 提供一個提示,其中包含使用者的查詢和排名前三的表定義,並請求它編寫一個 SQL 查詢來完成任務。
- Mistral Instruct 生成一個 SQL 查詢,我們對資料庫執行該查詢,檢索結果。
- SQL 查詢結果被轉換為 JSON 格式,並與使用者的原始查詢、SQL 查詢一起在一個新的提示中提供給 Mistral Instruct,並請求它用自然語言為使用者撰寫一個答案。
- Mistral Instruct 的自然語言文字響應將返回給使用者。
資料庫
在本教程中,我們使用一個關於影片遊戲銷售記錄的小型開放訪問資料庫,該資料庫儲存在 GitHub 上。我們將使用 SQLite 版本,因為 SQLite 非常緊湊、跨平臺,並且內建了 Python 支援。
軟硬體要求
我們將在本地執行 Jina Reranker v2 模型。如果您使用 Google Colab 執行此筆記本,請確保您使用的執行時可以訪問 GPU。如果您在本地執行,您將需要 Python 3 (本教程使用 Python 3.11 版本編寫),並且在有支援 CUDA 的 GPU 的情況下執行速度會*快得多*。
在本教程中,我們還將廣泛使用開源的 LlamaIndex RAG 框架,以及 Hugging Face 推理 API 來訪問 Mistral 7B Instruct v0.1。您將需要一個 Hugging Face 賬戶 和一個具有至少 `READ` 訪問許可權的訪問令牌。
如果您使用 Google Colab,SQLite 已經安裝好了。它可能沒有安裝在您的本地計算機上。如果尚未安裝,請按照 SQLite 網站上的說明進行安裝。Python 介面程式碼內置於 Python 中,您無需為其安裝任何 Python 模組。
環境設定
安裝依賴
首先,安裝所需的 Python 模組
!pip install -qU transformers einops llama-index llama-index-postprocessor-jinaai-rerank llama-index-llms-huggingface "huggingface_hub[inference]"
下載資料庫
接下來,從 GitHub 將 SQLite 資料庫 `videogames.db` 下載到本地檔案空間。如果您的系統上沒有 `wget`,請從此連結下載資料庫,並將其放在執行此筆記本的同一目錄中。
!wget https://github.com/bbrumm/databasestar/raw/main/sample_databases/sample_db_videogames/sqlite/videogames.db
下載並執行 Jina Reranker v2
以下程式碼將下載 `jina-reranker-v2-base-multilingual` 模型並在本地執行。
from transformers import AutoModelForSequenceClassification
reranker_model = AutoModelForSequenceClassification.from_pretrained(
"jinaai/jina-reranker-v2-base-multilingual",
torch_dtype="auto",
trust_remote_code=True,
)
reranker_model.to("cuda") # or 'cpu' if no GPU is available
reranker_model.eval()
設定 Mistral Instruct 介面
我們將使用 LlamaIndex 建立一個持有者物件,用於連線到 Hugging Face 推理 API 以及在那裡執行的 `mistralai/Mixtral-8x7B-Instruct-v0.1` 副本。
首先,從您的 Hugging Face 賬戶設定頁面獲取一個 Hugging Face 訪問令牌。
在下方提示時輸入它
import getpass
print("Paste your Hugging Face access token here: ")
hf_token = getpass.getpass()
接下來,初始化一個 LlamaIndex 的 `HuggingFaceInferenceAPI` 類的例項,並將其儲存為 `mistral_llm`。
from llama_index.llms.huggingface import HuggingFaceInferenceAPI
mistral_llm = HuggingFaceInferenceAPI(model_name="mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
使用支援 SQL 的 Jina Reranker v2
我們從位於 GitHub 上的資料庫匯入檔案中提取了八個表定義。執行以下命令將它們放入一個名為 `table_declarations` 的 Python 列表中。
table_declarations = [
"CREATE TABLE platform (\n\tid INTEGER PRIMARY KEY,\n\tplatform_name TEXT DEFAULT NULL\n);",
"CREATE TABLE genre (\n\tid INTEGER PRIMARY KEY,\n\tgenre_name TEXT DEFAULT NULL\n);",
"CREATE TABLE publisher (\n\tid INTEGER PRIMARY KEY,\n\tpublisher_name TEXT DEFAULT NULL\n);",
"CREATE TABLE region (\n\tid INTEGER PRIMARY KEY,\n\tregion_name TEXT DEFAULT NULL\n);",
"CREATE TABLE game (\n\tid INTEGER PRIMARY KEY,\n\tgenre_id INTEGER,\n\tgame_name TEXT DEFAULT NULL,\n\tCONSTRAINT fk_gm_gen FOREIGN KEY (genre_id) REFERENCES genre(id)\n);",
"CREATE TABLE game_publisher (\n\tid INTEGER PRIMARY KEY,\n\tgame_id INTEGER DEFAULT NULL,\n\tpublisher_id INTEGER DEFAULT NULL,\n\tCONSTRAINT fk_gpu_gam FOREIGN KEY (game_id) REFERENCES game(id),\n\tCONSTRAINT fk_gpu_pub FOREIGN KEY (publisher_id) REFERENCES publisher(id)\n);",
"CREATE TABLE game_platform (\n\tid INTEGER PRIMARY KEY,\n\tgame_publisher_id INTEGER DEFAULT NULL,\n\tplatform_id INTEGER DEFAULT NULL,\n\trelease_year INTEGER DEFAULT NULL,\n\tCONSTRAINT fk_gpl_gp FOREIGN KEY (game_publisher_id) REFERENCES game_publisher(id),\n\tCONSTRAINT fk_gpl_pla FOREIGN KEY (platform_id) REFERENCES platform(id)\n);",
"CREATE TABLE region_sales (\n\tregion_id INTEGER DEFAULT NULL,\n\tgame_platform_id INTEGER DEFAULT NULL,\n\tnum_sales REAL,\n CONSTRAINT fk_rs_gp FOREIGN KEY (game_platform_id) REFERENCES game_platform(id),\n\tCONSTRAINT fk_rs_reg FOREIGN KEY (region_id) REFERENCES region(id)\n);",
]
現在,我們定義一個函式,該函式接受自然語言查詢和表定義列表,使用 Jina Reranker v2 對所有表定義進行評分,並按從最高分到最低分的順序返回它們。
from typing import List, Tuple
def rank_tables(query: str, table_specs: List[str], top_n: int = 0) -> List[Tuple[float, str]]:
"""
Get sorted pairs of scores and table specifications, then return the top N,
or all if top_n is 0 or default.
"""
pairs = [[query, table_spec] for table_spec in table_specs]
scores = reranker_model.compute_score(pairs)
scored_tables = [(score, table_spec) for score, table_spec in zip(scores, table_specs)]
scored_tables.sort(key=lambda x: x[0], reverse=True)
if top_n and top_n < len(scored_tables):
return scored_tables[0:top_n]
return scored_tables
Jina Reranker v2 會為我們提供的每個表定義評分,預設情況下,此函式將返回所有表定義及其分數。可選引數 `top_n` 將返回結果的數量限制為使用者定義的數量,從得分最高的結果開始。
來試試看。首先,定義一個查詢。
user_query = "Identify the top 10 platforms by total sales."
執行 `rank_tables` 以返回一個表定義列表。我們將 `top_n` 設定為 3 以限制返回列表的大小,並將其分配給變數 `ranked_tables`,然後檢查結果。
ranked_tables = rank_tables(user_query, table_declarations, top_n=3)
ranked_tables
輸出應包括 `region_sales`、`platform` 和 `game_platform` 這幾張表,它們似乎都是查詢查詢答案的合理位置。
使用 Mistral Instruct 生成 SQL
我們將讓 Mistral Instruct v0.1 根據重排器給出的前三個表的宣告,編寫一個 SQL 查詢來滿足使用者的查詢。
首先,我們使用 LlamaIndex 的 `PromptTemplate` 類為此目的建立一個提示。
from llama_index.core import PromptTemplate
make_sql_prompt_tmpl_text = """
Generate a SQL query to answer the following question from the user:
\"{query_str}\"
The SQL query should use only tables with the following SQL definitions:
Table 1:
{table_1}
Table 2:
{table_2}
Table 3:
{table_3}
Make sure you ONLY output an SQL query and no explanation.
"""
make_sql_prompt_tmpl = PromptTemplate(make_sql_prompt_tmpl_text)
我們使用 `format` 方法來填充模板欄位,包括使用者查詢和來自 Jina Reranker v2 的前三個表宣告。
make_sql_prompt = make_sql_prompt_tmpl.format(
query_str=user_query, table_1=ranked_tables[0][1], table_2=ranked_tables[1][1], table_3=ranked_tables[2][1]
)
您可以看到我們將要傳遞給 Mistral Instruct 的實際文字。
print(make_sql_prompt)
現在,讓我們將提示傳送給 Mistral Instruct 並檢索其響應。
response = mistral_llm.complete(make_sql_prompt)
sql_query = str(response)
print(sql_query)
執行 SQL 查詢
使用內建的 Python SQLite 介面對資料庫 `videogames.db` 執行上述查詢。
import sqlite3
con = sqlite3.connect("videogames.db")
cur = con.cursor()
sql_response = cur.execute(sql_query).fetchall()
有關 SQLite 介面的詳細資訊,請參閱 Python3 文件。
檢查結果
sql_response
您可以透過執行自己的 SQL 查詢來檢查這是否正確。此資料庫中儲存的銷售資料是浮點數形式,大概是成千上萬或數百萬的單位銷量。
獲得自然語言答案
現在,我們將把使用者的查詢、SQL 查詢和結果連同一個新的提示模板一起傳回給 Mistral Instruct。
首先,使用 LlamaIndex 建立新的提示模板,方法同上。
rag_prompt_tmpl_str = """
Use the information in the JSON table to answer the following user query.
Do not explain anything, just answer concisely. Use natural language in your
answer, not computer formatting.
USER QUERY: {query_str}
JSON table:
{json_table}
This table was generated by the following SQL query:
{sql_query}
Answer ONLY using the information in the table and the SQL query, and if the
table does not provide the information to answer the question, answer
"No Information".
"""
rag_prompt_tmpl = PromptTemplate(rag_prompt_tmpl_str)
我們將把 SQL 輸出轉換為 JSON 格式,這是 Mistral Instruct v0.1 能理解的格式。
填充模板欄位
import json
rag_prompt = rag_prompt_tmpl.format(
query_str="Identify the top 10 platforms by total sales", json_table=json.dumps(sql_response), sql_query=sql_query
)
現在從 Mistral Instruct 請求一個自然語言響應
rag_response = mistral_llm.complete(rag_prompt)
print(str(rag_response))
自己動手試試
讓我們把所有這些組織成一個帶有異常捕獲的函式。
def answer_sql(user_query: str) -> str:
try:
ranked_tables = rank_tables(user_query, table_declarations, top_n=3)
except Exception as e:
print(f"Ranking failed.\nUser query:\n{user_query}\n\n")
raise (e)
make_sql_prompt = make_sql_prompt_tmpl.format(
query_str=user_query, table_1=ranked_tables[0][1], table_2=ranked_tables[1][1], table_3=ranked_tables[2][1]
)
try:
response = mistral_llm.complete(make_sql_prompt)
except Exception as e:
print(f"SQL query generation failed\nPrompt:\n{make_sql_prompt}\n\n")
raise (e)
# Backslash removal is a necessary hack because sometimes Mistral puts them
# in its generated code.
sql_query = str(response).replace("\\", "")
try:
sql_response = sqlite3.connect("videogames.db").cursor().execute(sql_query).fetchall()
except Exception as e:
print(f"SQL querying failed. Query:\n{sql_query}\n\n")
raise (e)
rag_prompt = rag_prompt_tmpl.format(query_str=user_query, json_table=json.dumps(sql_response), sql_query=sql_query)
try:
rag_response = mistral_llm.complete(rag_prompt)
return str(rag_response)
except Exception as e:
print(f"Answer generation failed. Prompt:\n{rag_prompt}\n\n")
raise (e)
試試看吧
print(answer_sql("Identify the top 10 platforms by total sales."))
嘗試一些其他查詢
print(answer_sql("Summarize sales by region."))
print(answer_sql("List the publisher with the largest number of published games."))
print(answer_sql("Display the year with most games released."))
print(answer_sql("What is the most popular game genre on the Wii platform?"))
print(answer_sql("What is the most popular game genre of 2012?"))
嘗試你自己的查詢
print(answer_sql("<INSERT QUESTION OR INSTRUCTION HERE>"))
回顧與總結
我們向您展示瞭如何製作一個非常基礎的 RAG (檢索增強生成) 系統,用於自然語言問答,該系統使用 SQL 資料庫作為資訊源。在此實現中,我們使用同一個大型語言模型 (Mistral Instruct v0.1) 來生成 SQL 查詢和構建自然語言響應。
這裡的資料庫是一個非常小的示例,要擴充套件此係統可能需要比僅僅對錶定義列表進行排序更復雜的方法。您可能希望使用一個兩階段過程,其中嵌入模型和向量儲存庫首先檢索更多結果,但重排模型會將其篩選到您能夠放入生成語言模型提示中的數量。
本筆記本假設任何請求都不需要超過三張表來滿足,顯然,在實踐中,這不可能總是成立。Mistral 7B Instruct v0.1 不保證產生正確 (甚至可執行) 的 SQL 輸出。在生產環境中,這樣的系統需要更深入的錯誤處理。
更復雜的錯誤處理、更長的輸入上下文視窗,以及專門用於 SQL 特定任務的生成模型,可能會在實際應用中產生巨大差異。
儘管如此,您可以在這裡看到 RAG 概念如何擴充套件到結構化資料庫,從而極大地擴充套件了其使用範圍。
< > 在 GitHub 上更新