说之前先提一个视频这个视频还是很好的将transformer
机制的变迁及未来的趋势很详细的说明了一下我觉得蛮有感触的,建议可以看看这里首先提一下代码及其对应的论文视频地址。
paper:Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
code: microsoft/Swin-Transformer
可以理解SwinTransformer是新一代的特征提取神器,很多榜单都有它的影子,这里我们可以理解为是一种新的`backbone,如下所示支持多种下游任务。相对比之前说的Transformer 在图像中的运用(一)VIT(Transformers for Image Recognition at Scale)论文及代码解读 之前需要每个像素
一、 原理
在Transformer种,如果图像像素太多则我们需要构建出更多的特征序列,这样就会导致我们的效率降低,所以我们采用了窗口
以及分层
的形式来替代长序列。
1.1 整体网络架构
- 得到各
Patch
特征构建的序列(注意这里先卷积得到特征图,再对特征图进行切分成Patch
) - 分成计算
attention
(逐步下采样过程) - 其中Block是最核心的, 对attention的计算方法进行了改进
由下面的图我们可以看出特征图大小不断减小, 但是特征图的通道数不断增加。
1.1.1 Patch Embedding
下面举一个例子比如输入的图像数据为(224, 224, 3)
, 输出(3136, 96)
相当于序列长度为3136
, 每个向量是96维
特征。这里的卷积核我们使用Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
。所以3136
就是卷积(224 / 4) * (224 / 4)
得到的。
这时候我们得到的输入特征图为(56, 56, 96)
, 如果默认窗口大小为7,所以总共可以分为8 * 8
个窗口。则输出的特征图为(64(8*8), 7, 7, 96)
之前单位是序列
, 现在单位是窗口(工64个窗口)
。
1.1.2 Swin Transformer Block
下面我们来看下上面图中对应的Transformer Blocks
是什么样子, 如下图所示。
上图的两个组合是串联而成的Block,对于左边为基于窗口的注意力计算
W-MSA(multi-head self attention modules with regular)
,对于右边为窗口滑动后重新计算注意力SW-MSA(multi-head self attention modules with shifted windowing)
1. W-MSA(计算每个不同窗口自身
的注意力机制(下面不同颜色的矩形代表不同的窗口))
对得到的窗口,计算各个窗口自己的自注意力得分,
qkv
三个矩阵放在一起得到(3, 64, 3, 49, 32)
。
-
3
个矩阵 -
64
个窗口 -
3
个heads -
7*7
的窗口大小(每个窗口有49个token即49个像素) -
96/3=32
个单head特征
所以attention
结果为(64, 3, 49, 49)
每个头都会得出每个窗口内的自注意力(3为头,这里可以理解为不同窗口不同头对应窗口的不同token之间的注意力)。
通过上面的计算我们可以得到新的特征(64, 49, 96)
, 之后再进行reshape
操作将其还原到(56, 56, 96)
大小特征图目的就是为了还原输入特征图大小(但是其已经计算过了attentation
), 因为再transformer要经过多层输入大小与输出大小一般都是相同的。
下面给出了省出来的计算量。
这里计算量公式可以参考这篇文章Swin-Transformer网络结构详解。
2. SW-MSA(计算不同窗口之间
的注意力机制)
上面W-MSA
是只是知道窗口内部的特征,但是我们不知道窗口之间的特征我们可以用SW-MSA
机制来弥补。这里的主要区别就是S(shift滑动)
,我们如何去做滑动呢?
上图中我们可以看出网格由红色网格(b)移动到了蓝色网格(c),我们需要通过将上方蓝色区域移动到下方,左边红色区域移动到右边。这么做的目的如下:
记住这里是半个窗口, 还有一点记住是向下取整(如窗口大小3, 则移动为1)
说白了就是换一换所有不同窗口的匹配对,使得模型更加健壮,这就是滑动操作。
由于不同Windows之间互不重叠,每次进行自注意力计算时很显然就丢失了Windows之间的信息,那么如何在降低计算量的同时保留全局信息呢?Shifted Window应运而生。
上面这张图可以用如下的示意图理解:
但是还有一个问题原来是4个windows,但是移动之后变成了9个windows,为了能够做到并行计算应该如何解决呢?我们可以做如下偏移方法。
则得到如下效果:
Attention Mask 机制
因为我们区域(5,3) (7,1) (8,6,2,0)本来是之间不想连接的,所以我们要单独计算各自的区域的MSA。我们借用区域(5,3)举例,这篇博客对于这个解释非常棒Swin-Transformer网络结构详解, 如下所示:
这里我们仅仅计算
区域5
的信息而不想引入区域3
的信息,我们通过掩码mask
的方式即可计算。因为本来公式中是一个很小的数字如果我们减去100
, 再经过softmax
可以理解为就是为0了。注意,全部计算完后需要将数据挪回到原来的位置上。下面演示一下整体流程
因为要经过多层transformer通过W-MSA以及SW-MSA输出的大小保持不变(56*56*96)
1.1.2 Relative Position Bias
下面我们看下加相对偏置与不加相对偏置的效果
发现使用
rel.pos
相对位置偏置更加合理。如何将一元坐标转成二元坐标呢?我们看作者如何去做的。
1.1.3 PatchMerging
这里我们就要说到这里
Patch Merging
操作。它的作用可以缩小特征图大小,提升特征图的通道数(这里也可以理解为就是下采样操作)。二、 代码逻辑解读
# file: models/swin_transformer.py
# class: SwinTransformer
class SwinTransformer(nn.Module):
r""" Swin Transformer
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030
Args:
img_size (int | tuple(int)): Input image size. Default 224
patch_size (int | tuple(int)): Patch size. Default: 4
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, **kwargs):
super().__init__()
self.num_classes = num_classes
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
self.mlp_ratio = mlp_ratio
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
# 这里的drop rate是会随着模型不同stage不断提升到我们设定的rate
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), # 我们的深度不断乘上2
input_resolution=(patches_resolution[0] // (2 ** i_layer),
patches_resolution[1] // (2 ** i_layer)),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, # 这里transoformer和patchMerge是连在一起的最后一个没有transformer只有patchMerge
use_checkpoint=use_checkpoint)
self.layers.append(layer)
self.norm = norm_layer(self.num_features)
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
2.1 input embedding
# file: models/swin_transformer.py
# class: SwinTransformer
def forward_features(self, x):
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x)
x = self.norm(x) # B L C
x = self.avgpool(x.transpose(1, 2)) # B C 1
x = torch.flatten(x, 1)
return x
这里我们的输入大小为4(batch), 3(channel), 224(width), 224(height)
, 接着进入到self.patch_embed
操作。
# file: swin_transformer.py
# class: PatchEmbed
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
和以往vit
一样,这里做self.proj
就是进行卷积操作
# 卷积核大小为4, stride也是为4, 这样会导致输出特征图为原来的额1/4 -> (56 * 56)
# 输入输出channel分别为3和96
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
# 这部分flatten操作是将我们的宽度高度展平,输出shape为(4, 3136(56*56), 96)
x = self.proj(x).flatten(2).transpose(1, 2)
在经过self.norm
对应的操作为nn.LayerNorm
。
接着我们会经过我们的self.pos_drop(x)
, 这里的self.pos_drop
为nn.Dropout(p=drop_rate)
操作。
接着进行下面各个层的操作(别忘记此时我们的输入shape为
(4, 3136(56*56), 96))
for layer in self.layers:
x = layer(x)
2.2 Basiclayer
接着上面我们看一下self.layers
是如何构建的
# file: models/swin_transformer.py
# class: SwinTransformer
for i_layer in range(self.num_layers):
layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
input_resolution=(patches_resolution[0] // (2 ** i_layer),
patches_resolution[1] // (2 ** i_layer)),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint)
self.layers.append(layer)
# file: models/swin_transformer.py
# class: BasicLayer
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
num_heads=num_heads, window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
if self.downsample is not None:
x = self.downsample(x)
return x
2.2 SwinTransformerBlock
# file: models/swin_transformer.py
# class: SwinTransformerBlock
class SwinTransformerBlock(nn.Module):
r""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resulotion.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
2.2.1 W-MSA及SW-MSA输入
我们知道输入是先经过W-MSA
再经过SW-MSA
经过W-MSA
是没有做任何处理的即代码中shifted_x = x
, 但是对于W-MSA
是通过torch.roll
的操作进行的,代码如下所示:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
这里有1和2,分别表示要左右上下移动,还有就是这里的self.shift_size为负数,说明移动完处理之后这里还是要复原的
如下代码所示:
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C # 第一个block得到(4, 56, 56, 96)
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
最终得到的shape依然是我们原来的输入(4, 3136, 96)
接着下进入如下操作
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
比如一开始第一个block
我们的得到的第一个输出shape为(256, 7, 7, 96)
然后我们得到第二个windows为(256, 49, 96)
。相当于256
个windows
, 每个windows
由49
个像素
, 每个像素
由96
个维度。
对于上面代码中的window_partition
代码如下:
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
这里的x
shape为(4, 8, 7, 8, 7, 96)
, 我们可以得到windows的数量为
(H/windows_size) * (W/windows_size) * batch
, 这里W
,H
一开始都为56
, windows_size
为7
, 这里设置的batch
为4
, 因此这里我们最终得到的windows
shape为(256 7 7 96)
。
2.2.2 Attention机制
上面的输出之后我们要经过我们的Attention
机制。
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
如果x_windows
是W-MSA
则self.atten_mask
为None
, 否则会加入atten_mask
, 具体代码如下(详细理解可以参考bilibili, 在31分钟左右 说的非常好):
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
对应上述代码简单点就是再不需要做内积的地方填入-100
, 这样经过softmax
的时候就被自动设置为0
了。
下面先给出我们进入attention
的代码。
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def extra_repr(self) -> str:
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
def flops(self, N):
# calculate flops for 1 window with token length of N
flops = 0
# qkv = self.qkv(x)
flops += N * self.dim * 3 * self.dim
# attn = (q @ k.transpose(-2, -1))
flops += self.num_heads * N * (self.dim // self.num_heads) * N
# x = (attn @ v)
flops += self.num_heads * N * N * (self.dim // self.num_heads)
# x = self.proj(x)
flops += N * self.dim * self.dim
return flops
可以看出首先会经过self.qkv
生成我们的q, k, v
矩阵,内部代码就是很简单的nn.Linear
,
# 这里的`dim`, 我们设置为96
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
# 这里我们得到的self.qkv shape 为[3, 256, 3, 49, 32] 这里的3分别对应qkv,
# 256个窗口分别做attention,
# 刚开始head为3,
# 每个窗口有49个元素,
# 32 代表每个头有32个维度
# q, k, v shape分别为[256, 3, 49, 32]
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
接着用得到我们的注意力机制,如下所示,这里的self.scale
可以理解为我们的v
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
最终让我们attention
与position bias
相加, 如下所示获得我们最终的atten
。
attn = attn + relative_position_bias.unsqueeze(0)
这里的position bias
下面解释。
2.2.3 Relative Position Bias Table
我们上面说了相对位置偏置矩阵的大小为(2M-1) * (2M-1)
, 这里的M
为windows-size
大小(详细理解可以参考bilibili, 在56分钟左右 说的非常好)。
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
下面就是之前说的经softmax
, 如果mask
不相同索引的我们设置为-100
, 经过softmax
计算就变成了0.
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
在经过
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
操作之后我们得到的attention之后的向量为(256, 49. 96)
, self.proj_drop为drop_out
。
2.2.4 FFN(残差操作)
最后要做一次残差连接
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
上述说完就完成了我们SwinTransformerBlock
的部分了。
3. Patch Merging
通过结构图我们可以看出经过
Swin Transformer Block
之后会经过Patch Merging
层,原理如下图所示。对应的代码如下:
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
4. 输出层
self.norm = norm_layer(self.num_features)
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
x = self.norm(x) # B L C
x = self.avgpool(x.transpose(1, 2)) # B C 1
x = torch.flatten(x, 1)
经过平均池化将原来shape由(4, 49, 768)
转成(4, 768, 1)
后面再接一下全连接层
nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
即可。
参考:
[1] Swin Transformer
[2] 如何看待swin transformer成为ICCV2021的 best paper?
[3] Swin-Transformer网络结构详解