Transformers 文件
掩碼生成
並獲得增強的文件體驗
開始使用
掩碼生成
掩碼生成是將語義有意義的掩碼生成影像的任務。此任務與影像分割非常相似,但存在許多差異。影像分割模型在標記資料集上進行訓練,並受限於其在訓練期間所見的類別;給定影像後,它們返回一組掩碼和相應的類別。
掩碼生成模型在大量資料上進行訓練,並以兩種模式執行。
- 提示模式:在此模式下,模型接收影像和提示,其中提示可以是影像中物件內的二維點位置(XY 座標)或圍繞物件的邊界框。在提示模式下,模型僅返回提示指向的物件上的掩碼。
- 全部分割模式:在全部分割模式下,給定影像後,模型生成影像中的每個掩碼。為此,會生成點網格並將其疊加到影像上進行推理。
掩碼生成任務由Segment Anything Model (SAM)支援。它是一個強大的模型,由基於 Vision Transformer 的影像編碼器、提示編碼器和雙向 Transformer 掩碼解碼器組成。影像和提示被編碼,解碼器接收這些嵌入並生成有效掩碼。

SAM 作為分割的強大基礎模型,因為它具有廣泛的資料覆蓋範圍。它在SA-1B上進行訓練,該資料集包含 100 萬張影像和 11 億個掩碼。
在本指南中,您將學習如何
- 在全部分割模式下進行推理(帶批處理),
- 在點提示模式下進行推理,
- 在框提示模式下進行推理。
首先,讓我們安裝transformers
pip install -q transformers
掩碼生成管線
推斷掩碼生成模型最簡單的方法是使用mask-generation
管線。
>>> from transformers import pipeline
>>> checkpoint = "facebook/sam-vit-base"
>>> mask_generator = pipeline(model=checkpoint, task="mask-generation")
讓我們看看這張圖片。
from PIL import Image
import requests
img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

讓我們分割所有內容。`points-per-batch` 可以在全部分割模式下並行推理點。這可以加快推理速度,但會消耗更多記憶體。此外,SAM 只支援對點進行批處理,不支援對影像進行批處理。`pred_iou_thresh` 是 IoU 置信度閾值,只有高於該閾值的掩碼才會被返回。
masks = mask_generator(image, points_per_batch=128, pred_iou_thresh=0.88)
masks
看起來像這樣
{'masks': [array([[False, False, False, ..., True, True, True],
[False, False, False, ..., True, True, True],
[False, False, False, ..., True, True, True],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]]),
array([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
'scores': tensor([0.9972, 0.9917,
...,
}
我們可以這樣視覺化它們
import matplotlib.pyplot as plt
plt.imshow(image, cmap='gray')
for i, mask in enumerate(masks["masks"]):
plt.imshow(mask, cmap='viridis', alpha=0.1, vmin=0, vmax=1)
plt.axis('off')
plt.show()
下面是原始影像的灰度圖,上面覆蓋著彩色地圖。非常令人印象深刻。

模型推理
點提示
您也可以不使用管線來使用模型。為此,請初始化模型和處理器。
from transformers import SamModel, SamProcessor
import torch
from accelerate.test_utils.testing import get_backend
# automatically detects the underlying device type (CUDA, CPU, XPU, MPS, etc.)
device, _, _ = get_backend()
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
要進行點提示,請將輸入點傳遞給處理器,然後將處理器輸出傳遞給模型進行推理。要對模型輸出進行後處理,請傳遞輸出以及我們從處理器初始輸出中獲取的`original_sizes`和`reshaped_input_sizes`。我們需要傳遞這些引數,因為處理器會調整影像大小,並且需要對輸出進行外推。
input_points = [[[2592, 1728]]] # point location of the bee
inputs = processor(image, input_points=input_points, return_tensors="pt").to(device)
with torch.no_grad():
outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
我們可以將masks
輸出中的三個掩碼視覺化。
import matplotlib.pyplot as plt
import numpy as np
fig, axes = plt.subplots(1, 4, figsize=(15, 5))
axes[0].imshow(image)
axes[0].set_title('Original Image')
mask_list = [masks[0][0][0].numpy(), masks[0][0][1].numpy(), masks[0][0][2].numpy()]
for i, mask in enumerate(mask_list, start=1):
overlayed_image = np.array(image).copy()
overlayed_image[:,:,0] = np.where(mask == 1, 255, overlayed_image[:,:,0])
overlayed_image[:,:,1] = np.where(mask == 1, 0, overlayed_image[:,:,1])
overlayed_image[:,:,2] = np.where(mask == 1, 0, overlayed_image[:,:,2])
axes[i].imshow(overlayed_image)
axes[i].set_title(f'Mask {i}')
for ax in axes:
ax.axis('off')
plt.show()

框提示
您也可以以類似於點提示的方式進行框提示。您只需將輸入框以列表[x_min, y_min, x_max, y_max]
格式與影像一起傳遞給processor
。獲取處理器輸出並直接將其傳遞給模型,然後再次對輸出進行後處理。
# bounding box around the bee
box = [2350, 1600, 2850, 2100]
inputs = processor(
image,
input_boxes=[[[box]]],
return_tensors="pt"
).to("cuda")
with torch.no_grad():
outputs = model(**inputs)
mask = processor.image_processor.post_process_masks(
outputs.pred_masks.cpu(),
inputs["original_sizes"].cpu(),
inputs["reshaped_input_sizes"].cpu()
)[0][0][0].numpy()
您可以視覺化蜜蜂周圍的邊界框,如下圖所示。
import matplotlib.patches as patches
fig, ax = plt.subplots()
ax.imshow(image)
rectangle = patches.Rectangle((2350, 1600), 500, 500, linewidth=2, edgecolor='r', facecolor='none')
ax.add_patch(rectangle)
ax.axis("off")
plt.show()

您可以在下面看到推理輸出。
fig, ax = plt.subplots()
ax.imshow(image)
ax.imshow(mask, cmap='viridis', alpha=0.4)
ax.axis("off")
plt.show()
