Transformers 文件

掩碼生成

Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

掩碼生成

掩碼生成是將語義有意義的掩碼生成影像的任務。此任務與影像分割非常相似,但存在許多差異。影像分割模型在標記資料集上進行訓練,並受限於其在訓練期間所見的類別;給定影像後,它們返回一組掩碼和相應的類別。

掩碼生成模型在大量資料上進行訓練,並以兩種模式執行。

  • 提示模式:在此模式下,模型接收影像和提示,其中提示可以是影像中物件內的二維點位置(XY 座標)或圍繞物件的邊界框。在提示模式下,模型僅返回提示指向的物件上的掩碼。
  • 全部分割模式:在全部分割模式下,給定影像後,模型生成影像中的每個掩碼。為此,會生成點網格並將其疊加到影像上進行推理。

掩碼生成任務由Segment Anything Model (SAM)支援。它是一個強大的模型,由基於 Vision Transformer 的影像編碼器、提示編碼器和雙向 Transformer 掩碼解碼器組成。影像和提示被編碼,解碼器接收這些嵌入並生成有效掩碼。

SAM Architecture

SAM 作為分割的強大基礎模型,因為它具有廣泛的資料覆蓋範圍。它在SA-1B上進行訓練,該資料集包含 100 萬張影像和 11 億個掩碼。

在本指南中,您將學習如何

  • 在全部分割模式下進行推理(帶批處理),
  • 在點提示模式下進行推理,
  • 在框提示模式下進行推理。

首先,讓我們安裝transformers

pip install -q transformers

掩碼生成管線

推斷掩碼生成模型最簡單的方法是使用mask-generation管線。

>>> from transformers import pipeline

>>> checkpoint = "facebook/sam-vit-base"
>>> mask_generator = pipeline(model=checkpoint, task="mask-generation")

讓我們看看這張圖片。

from PIL import Image
import requests

img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
Example Image

讓我們分割所有內容。`points-per-batch` 可以在全部分割模式下並行推理點。這可以加快推理速度,但會消耗更多記憶體。此外,SAM 只支援對點進行批處理,不支援對影像進行批處理。`pred_iou_thresh` 是 IoU 置信度閾值,只有高於該閾值的掩碼才會被返回。

masks = mask_generator(image, points_per_batch=128, pred_iou_thresh=0.88)

masks 看起來像這樣

{'masks': [array([[False, False, False, ...,  True,  True,  True],
         [False, False, False, ...,  True,  True,  True],
         [False, False, False, ...,  True,  True,  True],
         ...,
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False]]),
  array([[False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         ...,
'scores': tensor([0.9972, 0.9917,
        ...,
}

我們可以這樣視覺化它們

import matplotlib.pyplot as plt

plt.imshow(image, cmap='gray')

for i, mask in enumerate(masks["masks"]):
    plt.imshow(mask, cmap='viridis', alpha=0.1, vmin=0, vmax=1)

plt.axis('off')
plt.show()

下面是原始影像的灰度圖,上面覆蓋著彩色地圖。非常令人印象深刻。

Visualized

模型推理

點提示

您也可以不使用管線來使用模型。為此,請初始化模型和處理器。

from transformers import SamModel, SamProcessor
import torch
from accelerate.test_utils.testing import get_backend
# automatically detects the underlying device type (CUDA, CPU, XPU, MPS, etc.)
device, _, _ = get_backend()
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

要進行點提示,請將輸入點傳遞給處理器,然後將處理器輸出傳遞給模型進行推理。要對模型輸出進行後處理,請傳遞輸出以及我們從處理器初始輸出中獲取的`original_sizes`和`reshaped_input_sizes`。我們需要傳遞這些引數,因為處理器會調整影像大小,並且需要對輸出進行外推。

input_points = [[[2592, 1728]]] # point location of the bee

inputs = processor(image, input_points=input_points, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())

我們可以將masks輸出中的三個掩碼視覺化。

import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(1, 4, figsize=(15, 5))

axes[0].imshow(image)
axes[0].set_title('Original Image')
mask_list = [masks[0][0][0].numpy(), masks[0][0][1].numpy(), masks[0][0][2].numpy()]

for i, mask in enumerate(mask_list, start=1):
    overlayed_image = np.array(image).copy()

    overlayed_image[:,:,0] = np.where(mask == 1, 255, overlayed_image[:,:,0])
    overlayed_image[:,:,1] = np.where(mask == 1, 0, overlayed_image[:,:,1])
    overlayed_image[:,:,2] = np.where(mask == 1, 0, overlayed_image[:,:,2])
    
    axes[i].imshow(overlayed_image)
    axes[i].set_title(f'Mask {i}')
for ax in axes:
    ax.axis('off')

plt.show()
Visualized

框提示

您也可以以類似於點提示的方式進行框提示。您只需將輸入框以列表[x_min, y_min, x_max, y_max]格式與影像一起傳遞給processor。獲取處理器輸出並直接將其傳遞給模型,然後再次對輸出進行後處理。

# bounding box around the bee
box = [2350, 1600, 2850, 2100]

inputs = processor(
        image,
        input_boxes=[[[box]]],
        return_tensors="pt"
    ).to("cuda")

with torch.no_grad():
    outputs = model(**inputs)

mask = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu()
)[0][0][0].numpy()

您可以視覺化蜜蜂周圍的邊界框,如下圖所示。

import matplotlib.patches as patches

fig, ax = plt.subplots()
ax.imshow(image)

rectangle = patches.Rectangle((2350, 1600), 500, 500, linewidth=2, edgecolor='r', facecolor='none')
ax.add_patch(rectangle)
ax.axis("off")
plt.show()
Visualized Bbox

您可以在下面看到推理輸出。

fig, ax = plt.subplots()
ax.imshow(image)
ax.imshow(mask, cmap='viridis', alpha=0.4)

ax.axis("off")
plt.show()
Visualized Inference
< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.