社群計算機視覺課程文件
OneFormer:一個 Transformer 統一通用影像分割
並獲得增強的文件體驗
開始使用
OneFormer:一個 Transformer 統一通用影像分割
簡介
OneFormer 是影像分割領域的一項突破性方法,影像分割是一項涉及將影像分割成有意義的片段的計算機視覺任務。傳統方法對不同的分割任務使用獨立的模型和架構,例如識別物件(例項分割)或標記區域(語義分割)。最近的嘗試旨在透過共享架構統一這些任務,但仍然需要為每個任務進行單獨的訓練。
OneFormer 是一款旨在克服這些挑戰的通用影像分割框架。它引入了一種獨特的多工方法,允許單個模型處理語義、例項和全景分割任務,而無需對每個任務進行單獨訓練。其關鍵創新在於任務條件聯合訓練策略,模型透過任務輸入進行引導,使其在訓練和推理過程中都具有動態性和適應性。
這一突破不僅簡化了訓練過程,還在各種資料集上超越了現有模型。OneFormer 透過使用全景標註實現了這一點,統一了所有任務所需的真實資訊。此外,該框架還引入了查詢-文字對比學習,以更好地區分任務並提高整體效能。
OneFormer 背景
圖片來自 OneFormer 論文
為了理解 OneFormer 的重要性,讓我們深入瞭解影像分割的背景。在影像處理中,分割涉及將影像劃分為不同的部分,這對於識別物件和理解場景內容等任務至關重要。傳統上,兩種主要型別的分割任務是語義分割(其中畫素被標記為“道路”或“天空”等類別)和例項分割(識別具有明確邊界的物件)。
隨著時間的推移,研究人員提出了全景分割作為統一語義分割和例項分割任務的方法。然而,即使取得了這些進展,仍然存在挑戰。為全景分割設計的現有模型仍然需要為每個任務進行單獨訓練,充其量也只能算是半通用。
這就是 OneFormer 作為一個顛覆者的出現。它引入了一種新穎的方法——多工通用架構。其理念是隻訓練一次這個框架,使用一個單一的通用架構、一個獨立模型和一個數據集。OneFormer 的神奇之處在於,它在語義、例項和全景分割任務上都超越了專用框架。這一突破不僅僅是為了提高準確性;它是為了使影像分割更通用和高效。有了 OneFormer,對不同任務的大量資源和單獨訓練的需求成為過去。
OneFormer 的核心概念
現在,讓我們分解一下 OneFormer 的主要特點,讓它脫穎而出。
任務動態掩碼
OneFormer 使用了一種名為“任務動態掩碼”的巧妙技巧,以更好地理解和處理不同型別的影像分割任務。因此,當模型遇到影像時,它會使用此“任務動態掩碼”來決定是關注整體場景,識別具有清晰邊界的特定物件,還是兩者兼而有之。
任務條件聯合訓練
OneFormer 的開創性功能之一是其任務條件聯合訓練策略。OneFormer 不再為語義、例項和全景分割進行單獨訓練,而是在訓練期間統一取樣任務。這種策略使模型能夠同時學習和泛化不同的分割任務。透過任務令牌對特定任務進行架構條件設定,OneFormer 統一了訓練過程,減少了對任務特定架構、模型和資料集的需求。這種創新方法顯著簡化了訓練流程和資源要求。
查詢-文字對比損失
最後,我們來談談“查詢-文字對比損失”。可以將其視為 OneFormer 教導自己任務和類別之間差異的一種方式。在訓練過程中,模型將其從影像中提取的特徵(查詢)與相應的文字描述(例如“一輛汽車的照片”)進行比較。這有助於模型理解每個任務的獨特特徵,並減少不同類別之間的混淆。OneFormer 的“任務動態掩碼”使其能夠像多工助手一樣靈活,而“查詢-文字對比損失”透過比較視覺特徵和文字描述來幫助它學習每個任務的具體細節。
透過結合這些核心概念,OneFormer 成為一種智慧高效的影像分割工具,使過程更通用和易於訪問。
結論
圖片來自 OneFormer 論文
總之,OneFormer 框架代表了影像分割領域的一項突破性方法,旨在簡化和統一各個領域的任務。與依賴於每個分割任務的專用架構的傳統方法不同,OneFormer 引入了一種新穎的多工通用架構,只需一個模型,在通用資料集上訓練一次,即可超越現有框架。此外,在訓練過程中引入查詢-文字對比損失增強了模型學習任務間和類間差異的能力。OneFormer 利用基於 Transformer 的架構,受計算機視覺領域最近成功的啟發,並引入了任務引導查詢以提高任務敏感性。結果令人印象深刻,OneFormer 在 ADE20k、Cityscapes 和 COCO 等基準資料集上的語義、例項和全景分割任務中超越了最先進的模型。該框架的效能透過新的 ConvNeXt 和 DiNAT 主幹進一步增強。
總而言之,OneFormer 代表了通用且易於訪問的影像分割邁出了重要一步。透過引入一個能夠處理各種分割任務的單一模型,該框架簡化了分割過程並減少了資源需求。
模型使用示例
讓我們看一個模型的使用示例。Dinat 主幹需要 Natten 庫,安裝可能需要一段時間。
!pip install -q natten
下面我們可以看到根據不同分割型別進行的推理程式碼。
from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation
from PIL import Image
import requests
import matplotlib.pyplot as plt
def run_segmentation(image, task_type):
"""Performs image segmentation based on the given task type.
Args:
image (PIL.Image): The input image.
task_type (str): The type of segmentation to perform ('semantic', 'instance', or 'panoptic').
Returns:
PIL.Image: The segmented image.
Raises:
ValueError: If the task type is invalid.
"""
processor = OneFormerProcessor.from_pretrained(
"shi-labs/oneformer_ade20k_dinat_large"
) # Load once here
model = OneFormerForUniversalSegmentation.from_pretrained(
"shi-labs/oneformer_ade20k_dinat_large"
)
if task_type == "semantic":
inputs = processor(images=image, task_inputs=["semantic"], return_tensors="pt")
outputs = model(**inputs)
predicted_map = processor.post_process_semantic_segmentation(
outputs, target_sizes=[image.size[::-1]]
)[0]
elif task_type == "instance":
inputs = processor(images=image, task_inputs=["instance"], return_tensors="pt")
outputs = model(**inputs)
predicted_map = processor.post_process_instance_segmentation(
outputs, target_sizes=[image.size[::-1]]
)[0]["segmentation"]
elif task_type == "panoptic":
inputs = processor(images=image, task_inputs=["panoptic"], return_tensors="pt")
outputs = model(**inputs)
predicted_map = processor.post_process_panoptic_segmentation(
outputs, target_sizes=[image.size[::-1]]
)[0]["segmentation"]
else:
raise ValueError(
"Invalid task type. Choose from 'semantic', 'instance', or 'panoptic'"
)
return predicted_map
def show_image_comparison(image, predicted_map, segmentation_title):
"""Displays the original image and the segmented image side-by-side.
Args:
image (PIL.Image): The original image.
predicted_map (PIL.Image): The segmented image.
segmentation_title (str): The title for the segmented image.
"""
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(image)
plt.title("Original Image")
plt.axis("off")
plt.subplot(1, 2, 2)
plt.imshow(predicted_map)
plt.title(segmentation_title + " Segmentation")
plt.axis("off")
plt.show()
url = "https://huggingface.co/datasets/shi-labs/oneformer_demo/resolve/main/ade20k.jpeg"
response = requests.get(url, stream=True)
response.raise_for_status() # Check for HTTP errors
image = Image.open(response.raw)
task_to_run = "semantic"
predicted_map = run_segmentation(image, task_to_run)
show_image_comparison(image, predicted_map, task_to_run)