Transformers 文件

關鍵點檢測

Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

關鍵點檢測

關鍵點檢測用於識別和定點陣圖像中特定的興趣點。這些關鍵點,也稱為地標,代表了物體的有意義特徵,例如面部特徵或物體部位。這些模型接收影像輸入並返回以下輸出:

  • 關鍵點和分數:興趣點及其置信度分數。
  • 描述符:圍繞每個關鍵點的影像區域的表示,捕獲其紋理、梯度、方向和其他屬性。

在本指南中,我們將演示如何從影像中提取關鍵點。

在本教程中,我們將使用 SuperPoint,這是一個用於關鍵點檢測的基礎模型。

from transformers import AutoImageProcessor, SuperPointForKeypointDetection
processor = AutoImageProcessor.from_pretrained("magic-leap-community/superpoint")
model = SuperPointForKeypointDetection.from_pretrained("magic-leap-community/superpoint")

讓我們在下面的影像上測試模型。

Bee Cats
import torch
from PIL import Image
import requests
import cv2


url_image_1 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
image_1 = Image.open(requests.get(url_image_1, stream=True).raw)
url_image_2 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cats.png"
image_2 = Image.open(requests.get(url_image_2, stream=True).raw)

images = [image_1, image_2]

現在我們可以處理輸入並進行推理。

inputs = processor(images,return_tensors="pt").to(model.device, model.dtype)
outputs = model(**inputs)

模型輸出包含批次中每個專案的相對關鍵點、描述符、掩碼和分數。掩碼突出顯示影像中存在關鍵點的區域。

SuperPointKeypointDescriptionOutput(loss=None, keypoints=tensor([[[0.0437, 0.0167],
         [0.0688, 0.0167],
         [0.0172, 0.0188],
         ...,
         [0.5984, 0.9812],
         [0.6953, 0.9812]]]), 
         scores=tensor([[0.0056, 0.0053, 0.0079,  ..., 0.0125, 0.0539, 0.0377],
        [0.0206, 0.0058, 0.0065,  ..., 0.0000, 0.0000, 0.0000]],
       grad_fn=<CopySlices>), descriptors=tensor([[[-0.0807,  0.0114, -0.1210,  ..., -0.1122,  0.0899,  0.0357],
         [-0.0807,  0.0114, -0.1210,  ..., -0.1122,  0.0899,  0.0357],
         [-0.0807,  0.0114, -0.1210,  ..., -0.1122,  0.0899,  0.0357],
         ...],
       grad_fn=<CopySlices>), mask=tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0]], dtype=torch.int32), hidden_states=None)

要在影像中繪製實際關鍵點,我們需要對輸出進行後處理。為此,我們必須將實際影像大小與輸出一起傳遞給 post_process_keypoint_detection

image_sizes = [(image.size[1], image.size[0]) for image in images]
outputs = processor.post_process_keypoint_detection(outputs, image_sizes)

現在輸出是字典列表,每個字典都是關鍵點、分數和描述符的已處理輸出。

[{'keypoints': tensor([[ 226,   57],
          [ 356,   57],
          [  89,   64],
          ...,
          [3604, 3391]], dtype=torch.int32),
  'scores': tensor([0.0056, 0.0053, ...], grad_fn=<IndexBackward0>),
  'descriptors': tensor([[-0.0807,  0.0114, -0.1210,  ..., -0.1122,  0.0899,  0.0357],
          [-0.0807,  0.0114, -0.1210,  ..., -0.1122,  0.0899,  0.0357]],
         grad_fn=<IndexBackward0>)},
    {'keypoints': tensor([[ 46,   6],
          [ 78,   6],
          [422,   6],
          [206, 404]], dtype=torch.int32),
  'scores': tensor([0.0206, 0.0058, 0.0065, 0.0053, 0.0070, ...,grad_fn=<IndexBackward0>),
  'descriptors': tensor([[-0.0525,  0.0726,  0.0270,  ...,  0.0389, -0.0189, -0.0211],
          [-0.0525,  0.0726,  0.0270,  ...,  0.0389, -0.0189, -0.0211]}]

我們可以使用這些來繪製關鍵點。

import matplotlib.pyplot as plt
import torch

for i in range(len(images)):
  keypoints = outputs[i]["keypoints"]
  scores = outputs[i]["scores"]
  descriptors = outputs[i]["descriptors"]
  keypoints = outputs[i]["keypoints"].detach().numpy()
  scores = outputs[i]["scores"].detach().numpy()
  image = images[i]
  image_width, image_height = image.size

  plt.axis('off')
  plt.imshow(image)
  plt.scatter(
      keypoints[:, 0],
      keypoints[:, 1],
      s=scores * 100,
      c='cyan',
      alpha=0.4
  )
  plt.show()

您可以在下面看到輸出。

Bee Cats
< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.