社群計算機視覺課程文件
DEtection TRansformer (DETR)
並獲得增強的文件體驗
開始使用
DEtection TRansformer (DETR)
架構概述
DETR主要用於目標檢測任務,即在影像中檢測物體的過程。例如,模型的輸入可以是一張道路圖片,模型的輸出可以是[('car',X1,Y1,W1,H1),('pedestrian',X2,Y2,W2,H2)],其中X、Y、W、H分別表示邊界框位置的x、y座標以及邊界框的寬度和高度。傳統的物體檢測模型如YOLO包含手工設計的特徵,例如錨框先驗,這需要對物體位置和形狀進行初始猜測,影響下游訓練。然後使用後處理步驟來去除重疊的邊界框,這需要仔細選擇其過濾啟發式方法。DEtection TRansformer,簡稱DETR,透過在特徵提取骨幹之後使用編碼器-解碼器Transformer來簡化檢測器,以並行方式直接預測邊界框,需要最少的後處理。
DETR的模型架構始於一個CNN骨幹,類似於其他基於影像的網路,其輸出經過處理並輸入到Transformer編碼器中,生成N個嵌入。編碼器嵌入與學習到的位置嵌入(稱為物件查詢)相加,並用於Transformer解碼器中,生成另外N個嵌入。最後一步,每個N個嵌入都透過單獨的前饋層來預測邊界框的寬度、高度、座標以及物件類別(或是否存在物件)。
主要特點
編碼器-解碼器
與其他Transformer一樣,Transformer編碼器需要CNN骨幹的輸出是一個序列。因此,大小為[dimension, height, width]的特徵圖被縮小並展平為[dimension, less than height x width]。
左圖:可視化了特徵圖中256個維度中的12個。每個維度提取原始貓影像的一些特徵,同時縮小了原始影像。一些維度更關注貓的圖案;一些維度更關注床單。 右圖:保持原始特徵維度256,寬度和高度進一步縮小並展平為850。
由於Transformer是置換不變的,因此在編碼器和解碼器中都添加了位置嵌入,以提醒模型嵌入在影像中的位置。在編碼器中,使用固定位置編碼,而在解碼器中,使用學習到的位置編碼(物件查詢)。固定編碼類似於原始Transformer論文中使用的編碼,其中編碼由在不同特徵維度上具有不同頻率的正弦函式定義。它在沒有任何學習引數的情況下提供了位置感,透過影像上的位置進行索引。學習到的編碼也透過位置進行索引,但每個位置都有一個單獨的編碼,該編碼在整個訓練過程中學習,以模型理解的方式表示位置。
基於集合的全域性損失函式
在流行的目標檢測模型YOLO中,損失函式包括邊界框、物件存在性(即物件存在於感興趣區域的機率)和類別損失。損失是在每個網格單元的多個邊界框上計算的,這是一個固定數量。另一方面,在DETR中,該架構預計以置換不變的方式生成唯一的邊界框(即,輸出中檢測的順序無關緊要,並且邊界框必須不同,不能全部相同)。因此,需要匹配來評估預測的質量。
二分匹配
二分匹配是一種計算真實邊界框與預測框之間一對一匹配的方法。它尋找真實與預測邊界框以及類別之間相似度最高的匹配。這確保了最接近的預測將與相應的真實值匹配,以便在損失函式中正確調整邊界框和類別。如果不進行匹配,與真實值順序不一致的預測即使是正確的也會被標記為不正確。
使用DETR檢測物件
要檢視如何使用Hugging Face Transformer執行DETR推理的示例,請參閱DETR.ipynb。
DETR的演變
可變形DETR
DETR的兩個主要問題是收斂過程緩慢和小型物體檢測次優。可變形注意力
第一個問題透過使用可變形注意力解決,它減少了需要關注的取樣點數量。傳統注意力由於全域性注意力而效率低下,並嚴重限制了影像可以具有的解析度。該模型只關注每個參考點周圍固定數量的取樣點,並且參考點由模型根據輸入學習。例如,在一張狗的影像中,參考點可能在狗的中心,取樣點靠近耳朵、嘴巴、尾巴等。
多尺度可變形注意力模組
第二個問題與YOLOv3的解決方法類似,即引入了多尺度特徵圖。在卷積神經網路中,早期層提取較小的細節(例如線條),而後期層提取較大的細節(例如輪子、耳朵)。以類似的方式,可變形注意力的不同層導致不同級別的解析度。透過將編碼器中一些這些層的輸出連線到解碼器,它可以使模型檢測多種大小的物件。
條件DETR
條件DETR也旨在解決原始DETR中訓練收斂緩慢的問題,導致收斂速度提高了6.7倍以上。作者發現物件查詢是通用的,並且不特定於輸入影像。在解碼器中使用條件交叉注意力,查詢可以更好地定位邊界框迴歸區域。
左圖:DETR解碼器層。右圖:條件DETR解碼器層
上圖中比較了原始DETR和條件DETR解碼器層,主要區別在於交叉注意力塊的查詢輸入。作者區分了內容查詢cq(解碼器自注意力輸出)和空間查詢pq。原始DETR只是簡單地將它們相加。在條件DETR中,它們被連線起來,cq關注物件的內容,pq關注邊界框區域。
空間查詢pq是解碼器嵌入和物件查詢都投影到相同空間(分別變為T和ps)並相乘的結果。前一層解碼器嵌入包含邊界框區域的資訊,物件查詢包含每個邊界框的學習參考點資訊。因此,它們的投影結合成一個表示,允許交叉注意力測量它們與編碼器輸入和正弦位置嵌入的相似性。這比只使用物件查詢和固定參考點的DETR更有效。
DETR推理
您可以使用Hugging Face Hub上現有的DETR模型進行推理,如下所示
from transformers import DetrImageProcessor, DetrForObjectDetection
import torch
from PIL import Image
import requests
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
# initialize the model
processor = DetrImageProcessor.from_pretrained(
"facebook/detr-resnet-101", revision="no_timm"
)
model = DetrForObjectDetection.from_pretrained(
"facebook/detr-resnet-101", revision="no_timm"
)
# preprocess the inputs and infer
inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)
# convert outputs (bounding boxes and class logits) to COCO API
# non max supression above 0.9
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process_object_detection(
outputs, target_sizes=target_sizes, threshold=0.9
)[0]
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
box = [round(i, 2) for i in box.tolist()]
print(
f"Detected {model.config.id2label[label.item()]} with confidence "
f"{round(score.item(), 3)} at location {box}"
)輸出如下。
Detected cat with confidence 0.998 at location [344.06, 24.85, 640.34, 373.74]
Detected remote with confidence 0.997 at location [328.13, 75.93, 372.81, 187.66]
Detected remote with confidence 0.997 at location [39.34, 70.13, 175.56, 118.78]
Detected cat with confidence 0.998 at location [15.36, 51.75, 316.89, 471.16]
Detected couch with confidence 0.995 at location [-0.19, 0.71, 639.73, 474.17]DETR的PyTorch實現
原始論文中DETR的實現如下所示
import torch
from torch import nn
from torchvision.models import resnet50
class DETR(nn.Module):
def __init__(
self, num_classes, hidden_dim, nheads, num_encoder_layers, num_decoder_layers
):
super().__init__()
self.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])
self.conv = nn.Conv2d(2048, hidden_dim, 1)
self.transformer = nn.Transformer(
hidden_dim, nheads, num_encoder_layers, num_decoder_layers
)
self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
self.linear_bbox = nn.Linear(hidden_dim, 4)
self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))
self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
def forward(self, inputs):
x = self.backbone(inputs)
h = self.conv(x)
H, W = h.shape[-2:]
pos = (
torch.cat(
[
self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
],
dim=-1,
)
.flatten(0, 1)
.unsqueeze(1)
)
h = self.transformer(
pos + h.flatten(2).permute(2, 0, 1), self.query_pos.unsqueeze(1)
)
return self.linear_class(h), self.linear_bbox(h).sigmoid()逐行閱讀前向函式:
骨幹
輸入影像首先透過ResNet骨幹網路,然後透過卷積層,將維度降低到hidden_dim。
x = self.backbone(inputs) h = self.conv(x)
它們在__init__函式中宣告。
self.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])
self.conv = nn.Conv2d(2048, hidden_dim, 1)位置嵌入
雖然在論文中,固定嵌入和訓練嵌入分別用於編碼器和解碼器,但作者為了簡化,在實現中兩者都使用了訓練嵌入。
pos = (
torch.cat(
[
self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
],
dim=-1,
)
.flatten(0, 1)
.unsqueeze(1)
)它們在這裡宣告為nn.Parameter。行和列嵌入結合起來表示影像中的位置。
self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))調整大小
在進入Transformer之前,大小為(batch size, hidden_dim, H, W)的特徵被重塑為(hidden_dim, batch size, H*W)。這使它們成為Transformer的序列輸入。
h.flatten(2).permute(2, 0, 1)Transformernn.Transformer函式將第一個引數作為編碼器的輸入,第二個引數作為編碼器的輸入。正如您所看到的,編碼器接收到調整大小的特徵與位置嵌入相加,而解碼器接收到query_pos,即解碼器位置嵌入。
h = self.transformer(pos + h.flatten(2).permute(2, 0, 1), self.query_pos.unsqueeze(1))前饋網路
最後,輸出(一個大小為(query_pos_dim, batch size, hidden_dim)的張量)透過兩個線性層。
return self.linear_class(h), self.linear_bbox(h).sigmoid()其中第一個預測類別。為“無物件”類別添加了一個額外的類別。
self.linear_class = nn.Linear(hidden_dim, num_classes + 1)第二個線性層預測邊界框,輸出大小為4,用於xy座標、高度和寬度。
self.linear_bbox = nn.Linear(hidden_dim, 4)