Community Computer Vision Course documentation
Swin Transformer
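The Swin Transformer architecture was introduced in the 2021 paper "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows". It uses shifted windows rather than sliding windows. All queries inside a non-overlapping window share the same set of keys, which is more hardware-friendly than sliding windows, where every query has its own neighborhood. This lowers latency and reduces the number of operations required while keeping performance high. Swin is regarded as a hierarchical backbone for computer vision. It can be used for image classification as well as dense prediction tasks such as object detection and semantic segmentation.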
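In deep learning, the backbone is the part of a neural network that performs feature extraction. Additional layers can be added on top of the backbone to perform different vision tasks. A hierarchical backbone is organized into stages that produce feature maps at different resolutions. This contrasts with the non-hierarchical, plain backbone used in the ViTDet model.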
Key highlights
Shifted windows
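In the original ViT, attention is computed between each patch and every other patch, which is computationally expensive. Swin reduces ViT's usual quadratic complexity with respect to image size to linear complexity. It does this with a CNN-like technique: each patch attends only to the other patches in the same fixed-size window rather than to all patches, and neighboring patches are progressively merged in deeper layers. Because the window size is fixed, the cost per window is constant and the number of windows grows linearly with the number of patches. This progressive merging is what makes Swin a hierarchical model.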
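As a rough check of the quadratic-versus-linear claim, the Swin paper gives the cost of global multi-head self-attention (MSA) and window-based MSA (W-MSA) on an h × w grid of patches with C channels and M × M windows as Ω(MSA) = 4hwC² + 2(hw)²C and Ω(W-MSA) = 4hwC² + 2M²hwC, with the softmax excluded. The sketch below only evaluates these two formulas. The function names are our own, and the example sizes (a 56 × 56 patch grid, C = 96, M = 7) assume the first stage of Swin-T on a 224 × 224 image.

```python
def msa_cost(h: int, w: int, c: int) -> int:
    """Cost of global multi-head self-attention over h*w patches (Swin paper, Eq. 1)."""
    n = h * w
    # 4nC^2: Q, K, V and output projections; 2n^2C: Q @ K^T and attention @ V
    return 4 * n * c**2 + 2 * n**2 * c


def wmsa_cost(h: int, w: int, c: int, m: int) -> int:
    """Cost of window-based self-attention with m x m windows (Swin paper, Eq. 2)."""
    n = h * w
    # each patch attends to only m*m patches, so the second term is linear in n
    return 4 * n * c**2 + 2 * m**2 * n * c


if __name__ == "__main__":
    # Assumed Swin-T stage-1 setting: 224x224 image, 4x4 patches -> 56x56 grid, C=96, M=7
    global_cost = msa_cost(56, 56, 96)
    window_cost = wmsa_cost(56, 56, 96, 7)
    print(global_cost, window_cost, round(global_cost / window_cost, 1))
```

With these assumed sizes, global attention costs roughly 14 times more than window attention. Doubling the image side quadruples the W-MSA cost, whereas the second MSA term grows sixteenfold.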
Image from the Swin Transformer paper
Advantages
Computational efficiency
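Swin is more computationally efficient than approaches in which every patch attends to every other patch, such as ViT, because self-attention is restricted to local windows.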
Large datasets
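SwinV2 was one of the first vision models with 3 billion parameters. As the training scale increases, Swin outperforms CNNs. The large number of parameters increases the model's learning capacity and its ability to represent more complex features.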
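Swin Transformer V2 (paper)
Swin Transformer V2 is a large vision model that supports up to 3 billion parameters and can be trained on high-resolution images. It improves on the original Swin Transformer in three ways. First, it stabilizes training with residual post-normalization and scaled cosine attention. Second, it transfers models pre-trained on low-resolution images to high-resolution tasks using a log-spaced continuous position bias. Third, it uses SimMIM, a self-supervised pre-training method that reduces the number of labeled images needed for training.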
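Applications in image restoration
SwinIR (paper)
SwinIR is a Swin Transformer-based model for image restoration. Its tasks include converting low-resolution images into high-resolution ones (super-resolution), denoising, and JPEG compression artifact reduction.
Swin2SR (paper)
Swin2SR is another image restoration model. It improves on SwinIR by incorporating Swin Transformer V2, bringing over Swin V2's advantages such as training stability and the ability to handle higher image resolutions.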
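Overview of Swin's PyTorch implementation
Below is an overview of the key parts of the Swin implementation released with the original paper.
Swin Transformer class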
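- Initialization parameters: besides the various dropout and normalization parameters, these include:
  - window_size: the size of the windows used for local self-attention.
  - ape (bool): if True, an absolute position embedding is added to the patch embeddings.
  - fused_window_process: an optional speed-up. When True, the window shift and window partition (and their reverse) are fused into a single kernel.
- Apply patch embedding: as in ViT, the image is split into non-overlapping patches. These patches are linearly embedded with a Conv2D layer whose kernel size and stride both equal patch_size.
- Apply position embedding: SwinTransformer can optionally add an absolute position embedding (ape) to the patch embeddings. It is off by default, because every window attention already uses a relative position bias. Absolute position embeddings typically help the model use each patch's position to make better-informed predictions.
- Apply stochastic depth: stochastic depth is a regularization technique that helps prevent overfitting by randomly skipping residual branches during training. This implementation uses a stochastic depth decay rule: the drop-path rate increases linearly from 0 to drop_path_rate across all blocks, so deeper blocks are more likely to be skipped.
- Build layers:
  - The model consists of several stages (BasicLayer), each made up of multiple SwinTransformerBlock layers. Every stage except the last ends with PatchMerging, which downsamples the feature map for hierarchical processing.
  - From one stage to the next, the channel dimension doubles and the feature-map resolution halves.
- Classification head: Swin does not use a [CLS] token as ViT does. Instead, it applies the final layer norm, averages all tokens (self.avgpool), and feeds the result to self.head. This is a single linear layer, or an identity when num_classes is 0.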
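The per-stage shapes and the stochastic depth decay rule described above follow directly from the constructor arguments. The sketch below recomputes them in plain Python for the default settings. stage_layout is a hypothetical helper written for this walkthrough, and it assumes a square input whose patch grid stays divisible by 2 at every stage.

```python
def stage_layout(img_size=224, patch_size=4, embed_dim=96, depths=(2, 2, 6, 2), drop_path_rate=0.1):
    """Per-stage resolution, channel count and drop-path rates implied by SwinTransformer's arguments."""
    grid = img_size // patch_size  # patches per side after patch embedding
    total = sum(depths)
    # Linear stochastic-depth schedule over all blocks, mirroring torch.linspace(0, rate, total)
    rates = [drop_path_rate * i / (total - 1) for i in range(total)] if total > 1 else [0.0]
    stages, start = [], 0
    for i, depth in enumerate(depths):
        res = grid // (2**i)    # each earlier stage's PatchMerging halved H and W
        dim = embed_dim * 2**i  # ...and doubled the channel count
        stages.append({"resolution": (res, res), "dim": dim, "drop_path": rates[start:start + depth]})
        start += depth
    return stages


if __name__ == "__main__":
    for i, stage in enumerate(stage_layout()):
        print(i, stage["resolution"], stage["dim"], [round(r, 3) for r in stage["drop_path"]])
```

For the defaults this gives resolutions of 56, 28, 14 and 7 patches per side with 96, 192, 384 and 768 channels. The last value matches self.num_features, the input size of self.head.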
import torch
import torch.nn as nn
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

# PatchEmbed, BasicLayer, Mlp, window_partition and window_reverse are defined
# elsewhere in the official swin_transformer.py and are not reproduced here.


class SwinTransformer(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=4,
        in_chans=3,
        num_classes=1000,
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.1,
        norm_layer=nn.LayerNorm,
        ape=False,
        patch_norm=True,
        use_checkpoint=False,
        fused_window_process=False,
        **kwargs,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio
        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None,
        )
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution
        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=0.02)
        self.pos_drop = nn.Dropout(p=drop_rate)
        # stochastic depth
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]  # stochastic depth decay rule
        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2**i_layer),
                input_resolution=(
                    patches_resolution[0] // (2**i_layer),
                    patches_resolution[1] // (2**i_layer),
                ),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
                fused_window_process=fused_window_process,
            )
            self.layers.append(layer)
        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = (
            nn.Linear(self.num_features, num_classes)
            if num_classes > 0
            else nn.Identity()
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"absolute_pos_embed"}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {"relative_position_bias_table"}

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
Swin Transformer block
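SwinTransformerBlock encapsulates the core operations of the Swin Transformer: local windowed attention followed by MLP processing. It attends to local patches while keeping the ability to learn global representations through the shifted windows. This is key to how the Swin Transformer processes large images efficiently.
Layer components:
- Normalization layer 1 (self.norm1): applied before the attention mechanism.
- Window attention (self.attn): computes self-attention within local windows.
- Drop path (self.drop_path): implements stochastic depth for regularization.
- Normalization layer 2 (self.norm2): applied before the MLP layer.
- MLP (self.mlp): a multilayer perceptron that processes the features after attention.
- Attention mask (attn_mask, stored with self.register_buffer): used during self-attention to control which elements of the windowed input are allowed to interact, that is, attend to each other. After the cyclic shift, a window can contain patches that were not adjacent in the original feature map, and the mask stops those patches from attending to each other. The shifted-window approach thereby allows some cross-window interaction, which helps the model capture broader context.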
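The drop path component is where stochastic depth actually happens. The block uses the DropPath module, which the official repository imports from the timm library. The function below is only a minimal sketch of the same idea, written for illustration and not timm's source. During training, each sample's residual branch is either zeroed out or scaled by 1 / (1 - drop_prob). At inference it is returned unchanged.

```python
import torch


def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """Minimal stochastic depth: drop the whole residual branch for randomly chosen samples."""
    if drop_prob == 0.0 or not training:
        return x  # identity at inference time
    keep_prob = 1.0 - drop_prob
    # one Bernoulli draw per sample, broadcast over all remaining dimensions
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = torch.empty(mask_shape, dtype=x.dtype, device=x.device).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        mask.div_(keep_prob)  # rescale kept samples so the expected output equals x
    return x * mask


if __name__ == "__main__":
    branch = torch.ones(4, 49, 96)  # (batch, tokens, channels)
    out = drop_path(branch, drop_prob=0.5, training=True)
    print(out.flatten(1)[:, 0])  # each sample is either 0.0 or 2.0
```

In a block it wraps a residual branch, as in x = shortcut + drop_path(branch_output, rate, training). The rate for each block comes from the linear decay schedule built in SwinTransformer.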
Initialization of the Swin Transformer block
class SwinTransformerBlock(nn.Module):
    r"""Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        fused_window_process (bool, optional): If True, use one kernel to fuse window shift & window partition for acceleration, similar for the reversed part. Default: False
    """

    def __init__(
        self,
        dim,
        input_resolution,
        num_heads,
        window_size=7,
        shift_size=0,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        fused_window_process=False,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            w_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
            mask_windows = window_partition(
                img_mask, self.window_size
            )  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
                attn_mask == 0, float(0.0)
            )
        else:
            attn_mask = None
        self.register_buffer("attn_mask", attn_mask)
        self.fused_window_process = fused_window_process
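To see what the attention mask built above contains, the same steps can be run on a tiny example. The sketch below copies the mask computation into a standalone function and applies it to a 4 × 4 feature map with window_size=2 and shift_size=1, toy values chosen only for illustration. shifted_window_mask is a hypothetical name, and window_partition is inlined for square windows.

```python
import torch


def shifted_window_mask(H: int, W: int, window_size: int, shift_size: int) -> torch.Tensor:
    """Standalone copy of the SW-MSA mask computation from SwinTransformerBlock.__init__."""
    img_mask = torch.zeros((1, H, W, 1))
    slices = (
        slice(0, -window_size),
        slice(-window_size, -shift_size),
        slice(-shift_size, None),
    )
    cnt = 0
    for h in slices:
        for w in slices:
            img_mask[:, h, w, :] = cnt  # label each region of the shifted feature map
            cnt += 1
    # inlined window partition: (1, H, W, 1) -> (num_windows, window_size * window_size)
    windows = img_mask.view(1, H // window_size, window_size, W // window_size, window_size, 1)
    windows = windows.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size)
    # token pairs from different regions get -100, pairs from the same region get 0
    attn_mask = windows.unsqueeze(1) - windows.unsqueeze(2)
    return attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)


if __name__ == "__main__":
    mask = shifted_window_mask(4, 4, window_size=2, shift_size=1)
    print(mask.shape)  # (num_windows, tokens per window, tokens per window)
    print(mask[3])
```

An entry of 0 lets two tokens attend to each other, while -100 effectively blocks attention after the softmax. In this example the top-left window covers a single region, so its mask is all zeros. The bottom-right window mixes patches from four different regions, so every off-diagonal pair is masked.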
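Forward pass of the Swin Transformer block
There are four key steps:
- Cyclic shift: if shift_size > 0, the feature map is cyclically shifted with torch.roll and then divided into windows with window_partition. When fused_window_process is True, both operations are done by one fused kernel (WindowProcess). A cyclic shift moves the elements of a sequence to the left or right, wrapping the elements that fall off one end around to the other end. This changes where each element sits but keeps the sequence intact. For example, cyclically shifting the sequence A, B, C, D one position to the right gives D, A, B, C.
  The cyclic shift lets the model capture relationships between neighboring windows, improving its ability to learn spatial context beyond the local extent of a single window.
- Window attention: attention is computed with the window-based multi-head self-attention (W-MSA) module. For shifted windows (SW-MSA), the attention mask is passed in as well.
- Merge windows: the per-window outputs are stitched back into a full feature map with window_reverse. Downsampling by merging patches with PatchMerging happens separately, at the end of each stage, not inside this block.
- Reverse cyclic shift: once attention is done, torch.roll with the opposite shift undoes the cyclic shift, so the feature map returns to its original layout. With fused_window_process, WindowProcessReverse performs the merge and the reverse shift together.
Finally, the result is added to the shortcut (a residual connection with drop path), and a feed-forward MLP with its own pre-norm and residual connection is applied.
    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)
        # cyclic shift
        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = torch.roll(
                    x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
                )
                # partition windows
                x_windows = window_partition(
                    shifted_x, self.window_size
                )  # nW*B, window_size, window_size, C
            else:
                # optional fused kernel: roll + partition in one step
                x_windows = WindowProcess.apply(
                    x, B, H, W, C, -self.shift_size, self.window_size
                )
        else:
            shifted_x = x
            # partition windows
            x_windows = window_partition(
                shifted_x, self.window_size
            )  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(
            -1, self.window_size * self.window_size, C
        )  # nW*B, window_size*window_size, C
        # W-MSA/SW-MSA
        attn_windows = self.attn(
            x_windows, mask=self.attn_mask
        )  # nW*B, window_size*window_size, C
        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        # reverse cyclic shift
        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = window_reverse(
                    attn_windows, self.window_size, H, W
                )  # B H' W' C
                x = torch.roll(
                    shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
                )
            else:
                # optional fused kernel: window reverse + roll back in one step
                x = WindowProcessReverse.apply(
                    attn_windows, B, H, W, C, self.shift_size, self.window_size
                )
        else:
            shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
            x = shifted_x
        x = x.view(B, H * W, C)
        x = shortcut + self.drop_path(x)
        # Feed-forward network (FFN)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x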
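The forward pass calls window_partition and window_reverse, which live elsewhere in the official swin_transformer.py and are not shown here. Below is a sketch of two helpers with the shapes given in the code comments. It assumes H and W are divisible by window_size. The sketch then runs a round trip through cyclic shift, partition, reverse and reverse shift. It also checks the one-dimensional A, B, C, D example with torch.roll, using 0, 1, 2, 3 in place of the letters.

```python
import torch


def window_partition(x: torch.Tensor, window_size: int) -> torch.Tensor:
    """(B, H, W, C) -> (num_windows * B, window_size, window_size, C)."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)


def window_reverse(windows: torch.Tensor, window_size: int, H: int, W: int) -> torch.Tensor:
    """(num_windows * B, window_size, window_size, C) -> (B, H, W, C)."""
    num_windows = (H // window_size) * (W // window_size)
    B = windows.shape[0] // num_windows
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)


# Round trip on a toy feature map: B=2, H=W=8, C=3, window_size=4, shift_size=2
x = torch.arange(2 * 8 * 8 * 3, dtype=torch.float32).view(2, 8, 8, 3)
shift = 2
shifted = torch.roll(x, shifts=(-shift, -shift), dims=(1, 2))  # cyclic shift
windows = window_partition(shifted, window_size=4)             # (8, 4, 4, 3)
merged = window_reverse(windows, 4, 8, 8)                      # back to (2, 8, 8, 3)
restored = torch.roll(merged, shifts=(shift, shift), dims=(1, 2))  # reverse cyclic shift

# 1D cyclic shift: A, B, C, D shifted right by one -> D, A, B, C
print(torch.roll(torch.tensor([0, 1, 2, 3]), shifts=1))
```

Windows are ordered batch first and then row-major over the window grid. Window 1 of the first image is therefore its top-right 4 × 4 block.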
Window attention
WindowAttention is a window-based multi-head self-attention (W-MSA) module with a relative position bias. It can be used for both shifted and non-shifted windows. For shifted windows, the attention mask built in SwinTransformerBlock is passed in through the mask argument. Each head learns a table with (2*Wh - 1) * (2*Ww - 1) bias entries, one per possible relative offset between two tokens in a window. relative_position_index selects the right entry for every pair of tokens.
class WindowAttention(nn.Module):
    """
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH
        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = (
            coords_flatten[:, :, None] - coords_flatten[:, None, :]
        )  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-100) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(
            self.window_size[0] * self.window_size[1],
            self.window_size[0] * self.window_size[1],
            -1,
        )  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(
            2, 0, 1
        ).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
                1
            ).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
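The relative position index is the least obvious part of WindowAttention. The sketch below repeats the index computation for a small 2 × 2 window. relative_position_index is a hypothetical helper name, and a random table stands in for the learned relative_position_bias_table. The sketch shows that every pair of tokens with the same relative offset shares one bias entry per head.

```python
import torch


def relative_position_index(wh: int, ww: int) -> torch.Tensor:
    """Same index computation as WindowAttention.__init__, for a wh x ww window."""
    coords = torch.stack(torch.meshgrid(torch.arange(wh), torch.arange(ww), indexing="ij"))  # 2, wh, ww
    coords_flatten = torch.flatten(coords, 1)  # 2, wh*ww
    rel = (coords_flatten[:, :, None] - coords_flatten[:, None, :]).permute(1, 2, 0).contiguous()
    rel[:, :, 0] += wh - 1      # row offsets now lie in [0, 2*wh - 2]
    rel[:, :, 1] += ww - 1      # column offsets now lie in [0, 2*ww - 2]
    rel[:, :, 0] *= 2 * ww - 1  # row-major position in the (2*wh-1) x (2*ww-1) table
    return rel.sum(-1)          # wh*ww, wh*ww


idx = relative_position_index(2, 2)
num_heads = 3
# random stand-in for the learned table: one row per relative offset, one column per head
table = torch.randn((2 * 2 - 1) * (2 * 2 - 1), num_heads)
bias = table[idx.view(-1)].view(4, 4, num_heads).permute(2, 0, 1)  # nH, N, N
print(idx)
```

Tokens 0 → 1 and 2 → 3 are both one step apart horizontally, so they read the same table row. The diagonal (zero offset) always maps to the centre entry of the table.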
Patch merging layer
The patch merging layer is used for downsampling. It concatenates the features of each 2 × 2 group of neighboring patches (4C channels), normalizes them, and projects them to 2C channels with a linear layer. The resolution is therefore halved in each dimension and the channel count is doubled. Reducing the spatial dimensions of the feature map this way is similar to pooling in conventional convolutional neural networks (CNNs). It helps build hierarchical feature representations by progressively enlarging the receptive field while lowering the spatial resolution.
class PatchMerging(nn.Module):
    r"""Patch Merging Layer.
    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
        x = x.view(B, H, W, C)
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
        x = self.norm(x)
        x = self.reduction(x)
        return x
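The reshaping inside PatchMerging can be checked on a small input. The sketch below reproduces the 2 × 2 gather-and-concatenate step, then applies the same LayerNorm and bias-free linear reduction from 4C to 2C channels. merge_patches is a hypothetical name for the concatenation step, and the sizes B=1, H=W=4, C=8 are arbitrary.

```python
import torch
import torch.nn as nn


def merge_patches(x: torch.Tensor, H: int, W: int) -> torch.Tensor:
    """Concatenation step of PatchMerging: (B, H*W, C) -> (B, H/2*W/2, 4C)."""
    B, L, C = x.shape
    x = x.view(B, H, W, C)
    x0 = x[:, 0::2, 0::2, :]  # top-left patch of every 2x2 group
    x1 = x[:, 1::2, 0::2, :]  # bottom-left
    x2 = x[:, 0::2, 1::2, :]  # top-right
    x3 = x[:, 1::2, 1::2, :]  # bottom-right
    return torch.cat([x0, x1, x2, x3], -1).view(B, -1, 4 * C)


B, H, W, C = 1, 4, 4, 8
tokens = torch.randn(B, H * W, C)
merged = merge_patches(tokens, H, W)  # (1, 4, 32): 2x2 grid of groups, 4C channels
reduce = nn.Sequential(nn.LayerNorm(4 * C), nn.Linear(4 * C, 2 * C, bias=False))
out = reduce(merged)                  # (1, 4, 16): half the resolution, twice the channels
print(merged.shape, out.shape)
```

The first merged token holds the features of patches (0, 0), (1, 0), (0, 1) and (1, 1), in the same x0, x1, x2, x3 order as the original forward method.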
Try it
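You can find the 🤗 documentation for Swin here.
Classification with a pretrained Swin model
Here is how to use a Swin model to classify a cat image into one of the 1,000 ImageNet classes: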
from datasets import load_dataset
from transformers import AutoImageProcessor, SwinForImageClassification
import torch
model = SwinForImageClassification.from_pretrained(
"microsoft/swin-tiny-patch4-window7-224"
)
image_processor = AutoImageProcessor.from_pretrained(
"microsoft/swin-tiny-patch4-window7-224"
)
dataset = load_dataset("huggingface/cats-image")
image = dataset["test"]["image"][0]
inputs = image_processor(image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
predicted_label_id = logits.argmax(-1).item()
predicted_label_text = model.config.id2label[predicted_label_id]
print(predicted_label_text)