開源 AI 食譜文件
使用向量嵌入和 Qdrant 進行程式碼搜尋
並獲得增強的文件體驗
開始使用
使用向量嵌入和 Qdrant 進行程式碼搜尋
作者:Qdrant 團隊
在本 notebook 中,我們將演示如何使用向量嵌入來瀏覽程式碼庫,並找到相關的程式碼片段。我們將使用自然的語義查詢來搜尋程式碼庫,並基於相似邏輯搜尋程式碼。
您可以檢視此方法的線上部署,它透過一個 Web 介面開放了 Qdrant 程式碼庫的搜尋功能。
方法
我們需要兩個模型來實現我們的目標。
用於自然語言處理 (NLP) 的通用神經編碼器,在我們的例子中是 sentence-transformers/all-MiniLM-L6-v2。我們稱之為 NLP 模型。
用於程式碼到程式碼相似性搜尋的專門嵌入。我們將使用 jinaai/jina-embeddings-v2-base-code 模型來完成此任務。它支援英語和 30 種廣泛使用的程式語言,序列長度為 8192。我們稱之為程式碼模型。
為了讓我們的程式碼適用於 NLP 模型,我們需要將程式碼預處理成一種與自然語言非常相似的格式。程式碼模型支援多種標準程式語言,因此無需對程式碼片段進行預處理。我們可以直接使用程式碼。
安裝依賴
讓我們安裝我們將要使用的包。
- inflection - 一個字串轉換庫。它可以將英文單詞進行單複數轉換,並將駝峰式命名轉換為下劃線命名。
- fastembed - 一個 CPU 優先、輕量級的向量嵌入生成庫。支援 GPU。
- qdrant-client - 與 Qdrant 伺服器互動的官方 Python 庫。
%pip install inflection qdrant-client fastembed
資料準備
將應用程式原始碼分塊成更小的部分是一項不小的任務。一般來說,函式、類方法、結構體、列舉以及所有其他特定於語言的構造都是分塊的理想選擇。它們足夠大,可以包含一些有意義的資訊,但又足夠小,可以被上下文視窗有限的嵌入模型處理。您還可以使用文件字串、註釋和其他元資料來豐富分塊的附加資訊。

基於文字的搜尋是基於函式簽名的,但程式碼搜尋可能會返回更小的片段,例如迴圈。因此,如果我們從 NLP 模型收到一個特定的函式簽名,並從程式碼模型收到其實現的一部分,我們會將結果合併。
解析程式碼庫
本演示將使用 Qdrant 程式碼庫。雖然此程式碼庫使用 Rust,但您也可以將此方法用於任何其他語言。您可以使用語言伺服器協議 (LSP) 工具來構建程式碼庫的圖,然後提取分塊。我們使用了 rust-analyzer 來完成這項工作。我們將解析後的程式碼庫匯出為 LSIF 格式,這是一種程式碼智慧資料的標準。接下來,我們使用 LSIF 資料來瀏覽程式碼庫並提取分塊。
您也可以對其他語言使用相同的方法。有大量的實現可供選擇。
然後,我們將分塊匯出為 JSON 文件,其中不僅包含程式碼本身,還包含程式碼在專案中的位置上下文。
您可以在我們的 Google Cloud Storage 儲存桶中的 structures.jsonl 檔案中檢視解析為 JSON 的 Qdrant 結構。下載它並將其用作我們程式碼搜尋的資料來源。
!wget https://storage.googleapis.com/tutorial-attachments/code-search/structures.jsonl
接下來,載入檔案並將各行解析為一個字典列表。
import json
structures = []
with open("structures.jsonl", "r") as fp:
for i, row in enumerate(fp):
entry = json.loads(row)
structures.append(entry)
讓我們看看其中一個條目的樣子。
structures[0]
{'name': 'InvertedIndexRam',
'signature': '# [doc = " Inverted flatten index from dimension id to posting list"] # [derive (Debug , Clone , PartialEq)] pub struct InvertedIndexRam { # [doc = " Posting lists for each dimension flattened (dimension id -> posting list)"] # [doc = " Gaps are filled with empty posting lists"] pub postings : Vec < PostingList > , # [doc = " Number of unique indexed vectors"] # [doc = " pre-computed on build and upsert to avoid having to traverse the posting lists."] pub vector_count : usize , }',
'code_type': 'Struct',
'docstring': '= " Inverted flatten index from dimension id to posting list"',
'line': 15,
'line_from': 13,
'line_to': 22,
'context': {'module': 'inverted_index',
'file_path': 'lib/sparse/src/index/inverted_index/inverted_index_ram.rs',
'file_name': 'inverted_index_ram.rs',
'struct_name': None,
'snippet': '/// Inverted flatten index from dimension id to posting list\n#[derive(Debug, Clone, PartialEq)]\npub struct InvertedIndexRam {\n /// Posting lists for each dimension flattened (dimension id -> posting list)\n /// Gaps are filled with empty posting lists\n pub postings: Vec<PostingList>,\n /// Number of unique indexed vectors\n /// pre-computed on build and upsert to avoid having to traverse the posting lists.\n pub vector_count: usize,\n}\n'}}
程式碼到自然語言的轉換
每種程式語言都有其自己的語法,這些語法不屬於自然語言。因此,通用模型可能無法直接理解程式碼。然而,我們可以透過移除程式碼特有的部分幷包含額外的上下文(如模組、類、函式和檔名)來對資料進行歸一化。我們採取以下步驟
- 提取函式、方法或其他程式碼構造的簽名。
- 將駝峰式命名和蛇形命名分割成獨立的單詞。
- 提取文件字串、註釋和其他重要的元資料。
- 使用預定義的模板,根據提取的資料構建一個句子。
- 移除特殊字元,並用空格替換。
現在我們可以定義 textify
函式,該函式使用 inflection
庫來執行我們的轉換
import inflection
import re
from typing import Dict, Any
def textify(chunk: Dict[str, Any]) -> str:
# Get rid of all the camel case / snake case
# - inflection.underscore changes the camel case to snake case
# - inflection.humanize converts the snake case to human readable form
name = inflection.humanize(inflection.underscore(chunk["name"]))
signature = inflection.humanize(inflection.underscore(chunk["signature"]))
# Check if docstring is provided
docstring = ""
if chunk["docstring"]:
docstring = f"that does {chunk['docstring']} "
# Extract the location of that snippet of code
context = f"module {chunk['context']['module']} " f"file {chunk['context']['file_name']}"
if chunk["context"]["struct_name"]:
struct_name = inflection.humanize(inflection.underscore(chunk["context"]["struct_name"]))
context = f"defined in struct {struct_name} {context}"
# Combine all the bits and pieces together
text_representation = f"{chunk['code_type']} {name} " f"{docstring}" f"defined as {signature} " f"{context}"
# Remove any special characters and concatenate the tokens
tokens = re.split(r"\W", text_representation)
tokens = filter(lambda x: x, tokens)
return " ".join(tokens)
現在我們可以使用 textify
將所有分塊轉換為文字表示
text_representations = list(map(textify, structures))
讓我們看看我們的一個表示形式是怎樣的
text_representations[1000]
'Function Hnsw discover precision that does Checks discovery search precision when using hnsw index this is different from the tests in defined as Fn hnsw discover precision module integration file hnsw_discover_test rs'
自然語言嵌入
from fastembed import TextEmbedding
batch_size = 5
nlp_model = TextEmbedding("sentence-transformers/all-MiniLM-L6-v2", threads=0)
nlp_embeddings = nlp_model.embed(text_representations, batch_size=batch_size)
程式碼嵌入
code_snippets = [structure["context"]["snippet"] for structure in structures]
code_model = TextEmbedding("jinaai/jina-embeddings-v2-base-code")
code_embeddings = code_model.embed(code_snippets, batch_size=batch_size)
構建 Qdrant 集合
Qdrant 支援多種部署模式。包括用於原型設計的記憶體模式、Docker 和 Qdrant Cloud。您可以參考安裝說明獲取更多資訊。
我們將繼續使用記憶體例項來進行本教程。
記憶體模式只能用於快速原型設計和測試。它是 Qdrant 伺服器方法的 Python 實現。
讓我們建立一個集合來儲存我們的向量。
from qdrant_client import QdrantClient, models
COLLECTION_NAME = "qdrant-sources"
client = QdrantClient(":memory:") # Use in-memory storage
# client = QdrantClient("http://locahost:6333") # For Qdrant server
client.create_collection(
COLLECTION_NAME,
vectors_config={
"text": models.VectorParams(
size=384,
distance=models.Distance.COSINE,
),
"code": models.VectorParams(
size=768,
distance=models.Distance.COSINE,
),
},
)
我們新建立的集合已經準備好接收資料了。讓我們上傳嵌入向量吧。
from tqdm import tqdm
points = []
total = len(structures)
print("Number of points to upload: ", total)
for id, (text_embedding, code_embedding, structure) in tqdm(
enumerate(zip(nlp_embeddings, code_embeddings, structures)), total=total
):
# FastEmbed returns generators. Embeddings are computed as consumed.
points.append(
models.PointStruct(
id=id,
vector={
"text": text_embedding,
"code": code_embedding,
},
payload=structure,
)
)
# Upload points in batches
if len(points) >= batch_size:
client.upload_points(COLLECTION_NAME, points=points, wait=True)
points = []
# Ensure any remaining points are uploaded
if points:
client.upload_points(COLLECTION_NAME, points=points)
print(f"Total points in collection: {client.count(COLLECTION_NAME).count}")
上傳的點立即可用於搜尋。接下來,查詢集合以查詢相關的程式碼片段。
查詢程式碼庫
我們使用其中一個模型透過 Qdrant 的新查詢 API 來搜尋集合。從文字嵌入開始。執行以下查詢:“如何計算集合中的點數?”。檢視結果。
query = "How do I count points in a collection?"
hits = client.query_points(
COLLECTION_NAME,
query=next(nlp_model.query_embed(query)).tolist(),
using="text",
limit=3,
).points
現在,檢視結果。下表列出了模組、檔名和得分。每一行都包含一個指向簽名的連結。
模組 | 檔名 | 得分 | 簽名 |
---|---|---|---|
operations | types.rs | 0.5493385 | pub struct CountRequestInternal |
map_index | types.rs | 0.49973965 | fn get_points_with_value_count |
map_index | mutable_map_index.rs | 0.49941066 | pub fn get_points_with_value_count |
看來我們已經找到了相關的程式碼結構。讓我們用程式碼嵌入再試一次。
hits = client.query_points(
COLLECTION_NAME,
query=next(code_model.query_embed(query)).tolist(),
using="code",
limit=3,
).points
輸出
模組 | 檔名 | 得分 | 簽名 |
---|---|---|---|
field_index | geo_index.rs | 0.7217579 | fn count_indexed_points |
numeric_index | mod.rs | 0.7113214 | fn count_indexed_points |
full_text_index | text_index.rs | 0.6993165 | fn count_indexed_points |
雖然不同模型檢索到的分數不可比較,但我們可以看到結果是不同的。程式碼和文字嵌入可以捕捉到程式碼庫的不同方面。我們可以同時使用兩個模型來查詢集合,然後結合結果以獲得最相關的程式碼片段。
from qdrant_client import models
hits = client.query_points(
collection_name=COLLECTION_NAME,
prefetch=[
models.Prefetch(
query=next(nlp_model.query_embed(query)).tolist(),
using="text",
limit=5,
),
models.Prefetch(
query=next(code_model.query_embed(query)).tolist(),
using="code",
limit=5,
),
],
query=models.FusionQuery(fusion=models.Fusion.RRF),
).points
>>> for hit in hits:
... print(
... "| ",
... hit.payload["context"]["module"],
... " | ",
... hit.payload["context"]["file_path"],
... " | ",
... hit.score,
... " | `",
... hit.payload["signature"],
... "` |",
... )
| operations | lib/collection/src/operations/types.rs | 0.5 | ` # [doc = " Count Request"] # [doc = " Counts the number of points which satisfy the given filter."] # [doc = " If filter is not provided, the count of all points in the collection will be returned."] # [derive (Debug , Deserialize , Serialize , JsonSchema , Validate)] # [serde (rename_all = "snake_case")] pub struct CountRequestInternal { # [doc = " Look only for points which satisfies this conditions"] # [validate] pub filter : Option < Filter > , # [doc = " If true, count exact number of points. If false, count approximate number of points faster."] # [doc = " Approximate count might be unreliable during the indexing process. Default: true"] # [serde (default = "default_exact_count")] pub exact : bool , } ` | | field_index | lib/segment/src/index/field_index/geo_index.rs | 0.5 | ` fn count_indexed_points (& self) -> usize ` | | map_index | lib/segment/src/index/field_index/map_index/mod.rs | 0.33333334 | ` fn get_points_with_value_count < Q > (& self , value : & Q) -> Option < usize > where Q : ? Sized , N : std :: borrow :: Borrow < Q > , Q : Hash + Eq , ` | | numeric_index | lib/segment/src/index/field_index/numeric_index/mod.rs | 0.33333334 | ` fn count_indexed_points (& self) -> usize ` | | fixtures | lib/segment/src/fixtures/payload_context_fixture.rs | 0.25 | ` fn total_point_count (& self) -> usize ` | | map_index | lib/segment/src/index/field_index/map_index/mutable_map_index.rs | 0.25 | ` fn get_points_with_value_count < Q > (& self , value : & Q) -> Option < usize > where Q : ? Sized , N : std :: borrow :: Borrow < Q > , Q : Hash + Eq , ` | | id_tracker | lib/segment/src/id_tracker/simple_id_tracker.rs | 0.2 | ` fn total_point_count (& self) -> usize ` | | map_index | lib/segment/src/index/field_index/map_index/mod.rs | 0.2 | ` fn count_indexed_points (& self) -> usize ` | | map_index | lib/segment/src/index/field_index/map_index/mod.rs | 0.16666667 | ` fn count_indexed_points (& self) -> usize ` | | field_index | lib/segment/src/index/field_index/stat_tools.rs | 0.16666667 | ` fn number_of_selected_points (points : usize , values : usize) -> usize ` |
這是如何融合不同模型結果的一個例子。在實際場景中,您可能需要進行一些重排和去重,以及對結果進行額外的處理。
對結果進行分組
您可以透過根據有效載荷屬性對搜尋結果進行分組來改善搜尋結果。在我們的案例中,我們可以按模組對結果進行分組。如果我們使用程式碼嵌入,我們可以看到來自 map_index
模組的多個結果。讓我們對結果進行分組,並假設每個模組只顯示一個結果
results = client.query_points_groups(
COLLECTION_NAME,
query=next(code_model.query_embed(query)).tolist(),
using="code",
group_by="context.module",
limit=5,
group_size=1,
)
>>> for group in results.groups:
... for hit in group.hits:
... print(
... "| ",
... hit.payload["context"]["module"],
... " | ",
... hit.payload["context"]["file_name"],
... " | ",
... hit.score,
... " | `",
... hit.payload["signature"],
... "` |",
... )
| field_index | geo_index.rs | 0.7217579 | ` fn count_indexed_points (& self) -> usize ` | | numeric_index | mod.rs | 0.7113214 | ` fn count_indexed_points (& self) -> usize ` | | fixtures | payload_context_fixture.rs | 0.6993165 | ` fn total_point_count (& self) -> usize ` | | map_index | mod.rs | 0.68385994 | ` fn count_indexed_points (& self) -> usize ` | | full_text_index | text_index.rs | 0.6660142 | ` fn count_indexed_points (& self) -> usize ` |
我們的教程到此結束。感謝您花時間看到這裡。我們才剛剛開始探索向量嵌入的可能性以及如何改進它。歡迎您自由嘗試;您可能會創造出非常酷的東西!請與我們分享 🙏 我們在這裡。
< > 在 GitHub 上更新