開源 AI 食譜文件

使用 🤗 Transformers、🤗 Datasets 和 FAISS 嵌入多模態資料以進行相似性搜尋

Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

Open In Colab

使用 🤗 Transformers、🤗 Datasets 和 FAISS 嵌入多模態資料以進行相似性搜尋

作者:Merve Noyan

嵌入是資訊的語義有意義的壓縮。它們可以用於相似性搜尋、零樣本分類或簡單地訓練新模型。相似性搜尋的用例包括在電子商務中搜索類似產品、社交媒體中的內容搜尋等等。本筆記本將引導您使用 🤗transformers、🤗datasets 和 FAISS 從特徵提取模型建立和索引嵌入,以便稍後將它們用於相似性搜尋。讓我們安裝必要的庫。

!pip install -q datasets faiss-gpu transformers sentencepiece

在本教程中,我們將使用CLIP 模型來提取特徵。CLIP 是一個革命性的模型,它引入了文字編碼器和影像編碼器的聯合訓練,以連線兩種模態。

import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer
import faiss
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = AutoModel.from_pretrained("openai/clip-vit-base-patch16").to(device)
processor = AutoImageProcessor.from_pretrained("openai/clip-vit-base-patch16")
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch16")

載入資料集。為了讓本筆記本保持輕量級,我們將使用一個小型字幕資料集:jmhessel/newyorker_caption_contest

from datasets import load_dataset

ds = load_dataset("jmhessel/newyorker_caption_contest", "explanation")

檢視一個示例。

>>> ds["train"][0]["image"]
ds["train"][0]["image_description"]

我們不必編寫任何函式來嵌入示例或建立索引。🤗 Datasets 庫的 FAISS 整合抽象了這些過程。我們可以簡單地使用資料集的 `map` 方法來建立一個新列,其中包含每個示例的嵌入,如下所示。讓我們為提示列中的文字特徵建立一個。

dataset = ds["train"]
ds_with_embeddings = dataset.map(
    lambda example: {
        "embeddings": model.get_text_features(
            **tokenizer([example["image_description"]], truncation=True, return_tensors="pt").to("cuda")
        )[0]
        .detach()
        .cpu()
        .numpy()
    }
)

我們可以做同樣的事情並獲得影像嵌入。

ds_with_embeddings = ds_with_embeddings.map(
    lambda example: {
        "image_embeddings": model.get_image_features(**processor([example["image"]], return_tensors="pt").to("cuda"))[
            0
        ]
        .detach()
        .cpu()
        .numpy()
    }
)

現在,我們為每一列建立一個索引。

# create FAISS index for text embeddings
ds_with_embeddings.add_faiss_index(column="embeddings")
# create FAISS index for image embeddings
ds_with_embeddings.add_faiss_index(column="image_embeddings")

使用文字提示查詢資料

我們現在可以使用文字或影像查詢資料集,從中獲取類似的專案。

prmt = "a snowy day"
prmt_embedding = (
    model.get_text_features(**tokenizer([prmt], return_tensors="pt", truncation=True).to("cuda"))[0]
    .detach()
    .cpu()
    .numpy()
)
scores, retrieved_examples = ds_with_embeddings.get_nearest_examples("embeddings", prmt_embedding, k=1)
>>> def downscale_images(image):
...     width = 200
...     ratio = width / float(image.size[0])
...     height = int((float(image.size[1]) * float(ratio)))
...     img = image.resize((width, height), Image.Resampling.LANCZOS)
...     return img


>>> images = [downscale_images(image) for image in retrieved_examples["image"]]
>>> # see the closest text and image
>>> print(retrieved_examples["image_description"])
>>> display(images[0])
['A man is in the snow. A boy with a huge snow shovel is there too. They are outside a house.']

使用影像提示查詢資料

影像相似性推理與此類似,您只需呼叫 `get_image_features`。

>>> import requests

>>> # image of a beaver
>>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/beaver.png"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> display(downscale_images(image))

搜尋相似影像。

img_embedding = (
    model.get_image_features(**processor([image], return_tensors="pt", truncation=True).to("cuda"))[0]
    .detach()
    .cpu()
    .numpy()
)
scores, retrieved_examples = ds_with_embeddings.get_nearest_examples("image_embeddings", img_embedding, k=1)

顯示與海狸影像最相似的影像。

>>> images = [downscale_images(image) for image in retrieved_examples["image"]]
>>> # see the closest text and image
>>> print(retrieved_examples["image_description"])
>>> display(images[0])
['Salmon swim upstream but they see a grizzly bear and are in shock. The bear has a smug look on his face when he sees the salmon.']

儲存、推送和載入嵌入

我們可以使用 `save_faiss_index` 儲存帶嵌入的資料集。

ds_with_embeddings.save_faiss_index("embeddings", "embeddings/embeddings.faiss")
ds_with_embeddings.save_faiss_index("image_embeddings", "embeddings/image_embeddings.faiss")

將嵌入儲存在資料集倉庫中是一個好習慣,因此我們將建立一個,然後將我們的索引推送到那裡,之後使用 `snapshot_download` 載入。

from huggingface_hub import HfApi, notebook_login, snapshot_download

notebook_login()
from huggingface_hub import HfApi

api = HfApi()
api.create_repo("merve/faiss_embeddings", repo_type="dataset")
api.upload_folder(
    folder_path="./embeddings",
    repo_id="merve/faiss_embeddings",
    repo_type="dataset",
)
snapshot_download(repo_id="merve/faiss_embeddings", repo_type="dataset", local_dir="downloaded_embeddings")

我們可以使用 `load_faiss_index` 將嵌入載入到不帶嵌入的資料集中。

ds = ds["train"]
ds.load_faiss_index("embeddings", "./downloaded_embeddings/embeddings.faiss")
# infer again
prmt = "people under the rain"
prmt_embedding = (
    model.get_text_features(**tokenizer([prmt], return_tensors="pt", truncation=True).to("cuda"))[0]
    .detach()
    .cpu()
    .numpy()
)

scores, retrieved_examples = ds.get_nearest_examples("embeddings", prmt_embedding, k=1)
>>> display(retrieved_examples["image"][0])
< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.