使用 Hugging Face Datasets 和 Transformers 實現影像相似性
在這篇文章中,您將學習如何使用 🤗 Transformers 構建影像相似性系統。找出查詢影像和潛在候選影像之間的相似性是資訊檢索系統(例如反向影像搜尋)的一個重要用例。該系統試圖回答的問題是:給定一個查詢影像和一組候選影像,哪些影像與查詢影像最相似。
我們將利用 🤗 datasets
庫,因為它無縫支援並行處理,這在構建此係統時會派上用場。
儘管本文使用基於 ViT 的模型(nateraw/vit-base-beans
)和特定資料集(Beans),但它可以擴充套件到使用其他支援視覺模態的模型和其他影像資料集。您可以嘗試的一些著名模型包括:
此外,本文中介紹的方法也有可能擴充套件到其他模態。
要研究完全可用的影像相似性系統,您可以參考開頭連結的 Colab Notebook。
我們如何定義相似性?
為了構建這個系統,我們首先需要定義如何計算兩幅影像之間的相似性。一種廣泛流行的方法是計算給定影像的密集表示(嵌入),然後使用餘弦相似度度量來確定兩幅影像的相似程度。
在本文中,我們將使用“嵌入”來表示向量空間中的影像。這為我們提供了一種很好的方式,可以將影像的高維畫素空間(例如 224 x 224 x 3)有意義地壓縮到更低的維度(例如 768)。這樣做主要優點是減少了後續步驟中的計算時間。

計算嵌入
為了從影像中計算嵌入,我們將使用一個視覺模型,該模型對如何在向量空間中表示輸入影像有一定的理解。這種型別的模型也通常被稱為影像編碼器。
為了載入模型,我們利用 AutoModel
類。它為我們提供了一個介面,用於從 Hugging Face Hub 載入任何相容的模型檢查點。除了模型,我們還載入了與模型關聯的處理器用於資料預處理。
from transformers import AutoImageProcessor, AutoModel
model_ckpt = "nateraw/vit-base-beans"
processor = AutoImageProcessor.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)
在這種情況下,檢查點是透過在 beans
資料集上微調 基於 Vision Transformer 的模型獲得的。
這裡可能會出現一些問題
問題 1:為什麼我們不使用 AutoModelForImageClassification
?
這是因為我們想要獲得影像的密集表示,而不是離散類別,而這正是 AutoModelForImageClassification
所能提供的。
問題 2:為什麼是這個特定的檢查點?
如前所述,我們正在使用特定資料集來構建系統。因此,與其使用通用模型(例如在 ImageNet-1k 資料集上訓練的模型),不如使用已在所用資料集上微調過的模型。這樣,底層模型能更好地理解輸入影像。
請注意,您也可以使用透過自監督預訓練獲得的檢查點。該檢查點不一定非要來自監督學習。事實上,如果預訓練得當,自監督模型可以產生令人印象深刻的檢索效能。
現在我們有了一個用於計算嵌入的模型,我們需要一些候選影像來進行查詢。
載入候選影像資料集
過一段時間,我們將構建雜湊表,將候選影像對映到雜湊值。在查詢時,我們將使用這些雜湊表。我們將在相應的章節中詳細討論雜湊表,但目前,為了獲得一組候選影像,我們將使用 beans
資料集的 train
分割。
from datasets import load_dataset
dataset = load_dataset("beans")
這是訓練分割中的一個樣本:

該資料集有三個特徵:
dataset["train"].features
>>> {'image_file_path': Value(dtype='string', id=None),
'image': Image(decode=True, id=None),
'labels': ClassLabel(names=['angular_leaf_spot', 'bean_rust', 'healthy'], id=None)}
為了演示影像相似性系統,我們將使用候選影像資料集中的 100 個樣本,以縮短總體執行時間。
num_samples = 100
seed = 42
candidate_subset = dataset["train"].shuffle(seed=seed).select(range(num_samples))
查詢相似影像的過程
下面是獲取相似影像過程的圖示概覽。

將上圖分解一下,我們有:
- 從候選影像(
candidate_subset
)中提取嵌入,並將其儲存在一個矩陣中。 - 獲取查詢影像並提取其嵌入。
- 迭代嵌入矩陣(在步驟 1 中計算),並計算查詢嵌入和當前候選嵌入之間的相似性得分。我們通常維護一個類似字典的對映,以保持候選影像的某個識別符號和相似性得分之間的對應關係。
- 根據相似性得分對對映結構進行排序,並返回底層識別符號。我們使用這些識別符號來獲取候選樣本。
我們可以編寫一個簡單的實用程式並將其 map()
到我們的候選影像資料集上,以高效地計算嵌入。
import torch
def extract_embeddings(model: torch.nn.Module):
"""Utility to compute embeddings."""
device = model.device
def pp(batch):
images = batch["image"]
# `transformation_chain` is a compostion of preprocessing
# transformations we apply to the input images to prepare them
# for the model. For more details, check out the accompanying Colab Notebook.
image_batch_transformed = torch.stack(
[transformation_chain(image) for image in images]
)
new_batch = {"pixel_values": image_batch_transformed.to(device)}
with torch.no_grad():
embeddings = model(**new_batch).last_hidden_state[:, 0].cpu()
return {"embeddings": embeddings}
return pp
我們可以像這樣對映 extract_embeddings()
:
device = "cuda" if torch.cuda.is_available() else "cpu"
extract_fn = extract_embeddings(model.to(device))
candidate_subset_emb = candidate_subset.map(extract_fn, batched=True, batch_size=batch_size)
接下來,為了方便起見,我們建立一個包含候選影像識別符號的列表。
candidate_ids = []
for id in tqdm(range(len(candidate_subset_emb))):
label = candidate_subset_emb[id]["labels"]
# Create a unique indentifier.
entry = str(id) + "_" + str(label)
candidate_ids.append(entry)
我們將使用所有候選影像的嵌入矩陣來計算與查詢影像的相似性得分。我們已經計算了候選影像的嵌入。在下一個單元格中,我們只是將它們收集到一個矩陣中。
all_candidate_embeddings = np.array(candidate_subset_emb["embeddings"])
all_candidate_embeddings = torch.from_numpy(all_candidate_embeddings)
我們將使用餘弦相似度來計算兩個嵌入向量之間的相似度分數。然後,我們將使用它來根據給定的查詢樣本獲取相似的候選樣本。
def compute_scores(emb_one, emb_two):
"""Computes cosine similarity between two vectors."""
scores = torch.nn.functional.cosine_similarity(emb_one, emb_two)
return scores.numpy().tolist()
def fetch_similar(image, top_k=5):
"""Fetches the `top_k` similar images with `image` as the query."""
# Prepare the input query image for embedding computation.
image_transformed = transformation_chain(image).unsqueeze(0)
new_batch = {"pixel_values": image_transformed.to(device)}
# Comute the embedding.
with torch.no_grad():
query_embeddings = model(**new_batch).last_hidden_state[:, 0].cpu()
# Compute similarity scores with all the candidate images at one go.
# We also create a mapping between the candidate image identifiers
# and their similarity scores with the query image.
sim_scores = compute_scores(all_candidate_embeddings, query_embeddings)
similarity_mapping = dict(zip(candidate_ids, sim_scores))
# Sort the mapping dictionary and return `top_k` candidates.
similarity_mapping_sorted = dict(
sorted(similarity_mapping.items(), key=lambda x: x[1], reverse=True)
)
id_entries = list(similarity_mapping_sorted.keys())[:top_k]
ids = list(map(lambda x: int(x.split("_")[0]), id_entries))
labels = list(map(lambda x: int(x.split("_")[-1]), id_entries))
return ids, labels
執行查詢
有了所有這些實用工具,我們就可以進行相似性搜尋了。讓我們從 beans
資料集的 test
分割中獲取一個查詢影像:
test_idx = np.random.choice(len(dataset["test"]))
test_sample = dataset["test"][test_idx]["image"]
test_label = dataset["test"][test_idx]["labels"]
sim_ids, sim_labels = fetch_similar(test_sample)
print(f"Query label: {test_label}")
print(f"Top 5 candidate labels: {sim_labels}")
導致
Query label: 0
Top 5 candidate labels: [0, 0, 0, 0, 0]
看來我們的系統找到了正確的相似影像集。當可視化時,我們會得到:

進一步的擴充套件和結論
我們現在有了一個可用的影像相似性系統。但在現實中,您將處理更多的候選影像。考慮到這一點,我們目前的程式有幾個缺點:
- 如果我們將嵌入原樣儲存,記憶體需求會迅速飆升,尤其是在處理數百萬張候選影像時。在我們的案例中,嵌入是 768 維的,在大規模場景下仍然相對較高。
- 高維嵌入對檢索部分涉及的後續計算有直接影響。
如果我們能夠在不干擾嵌入含義的情況下降低其維度,我們仍然可以在速度和檢索質量之間保持良好的權衡。本文的配套 Colab Notebook實現了並演示了使用隨機投影和區域性敏感雜湊實現這一點的實用工具。
🤗 Datasets 提供與 FAISS 的直接整合,這進一步簡化了構建相似性系統的過程。假設您已經提取了候選影像(beans
資料集)的嵌入並將其儲存在一個名為 embeddings
的特徵中。您現在可以輕鬆使用資料集的 add_faiss_index()
來構建一個密集索引:
dataset_with_embeddings.add_faiss_index(column="embeddings")
一旦索引構建完成,dataset_with_embeddings
就可以用於使用 get_nearest_examples()
獲取給定查詢嵌入的最近鄰示例。
scores, retrieved_examples = dataset_with_embeddings.get_nearest_examples(
"embeddings", qi_embedding, k=top_k
)
該方法返回得分和相應的候選示例。要了解更多資訊,您可以檢視官方文件和此筆記本。
最後,您可以嘗試以下 Space,它構建了一個迷你影像相似性應用程式:
在這篇文章中,我們快速介紹了構建影像相似性系統。如果您覺得這篇文章很有趣,我們強烈建議您在此基礎上進行構建,以便您能更熟悉其內部工作原理。
還在尋找更多學習資料?以下是一些對您可能有用的額外資源: