| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061 |
- # The implementation here is modified based on timm,
- # originally Apache 2.0 License and publicly available at
- # https://github.com/naver-ai/vidt/blob/vidt-plus/methods/swin_w_ram.py
- import math
- import os
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.utils.checkpoint as checkpoint
- from timm.models.layers import DropPath, to_2tuple, trunc_normal_
class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> Linear.

    The same dropout module is applied after the activation and again after
    the second linear layer.

    Args:
        in_features: channel size of the input.
        hidden_features: channel size of the hidden layer; ``in_features``
            is used when this is falsy (e.g. ``None``).
        out_features: channel size of the output; ``in_features`` is used
            when this is falsy (e.g. ``None``).
        act_layer: zero-argument factory producing the activation module.
        drop: dropout probability.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        hidden_dim = hidden_features if hidden_features else in_features
        out_dim = out_features if out_features else in_features
        # Attribute names and creation order are kept stable so that
        # state-dict keys and random initialisation do not change.
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply fc1 -> act -> dropout -> fc2 -> dropout to ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
def masked_sin_pos_encoding(x,
                            mask,
                            num_pos_feats,
                            temperature=10000,
                            scale=2 * math.pi):
    """ Masked Sinusoidal Positional Encoding

    Positions are counted only over non-masked entries (cumulative sums of
    ``~mask`` along dims 1 and 2), normalised by the last count along each
    axis and multiplied by ``scale``.

    Args:
        x: [PATCH] tokens; only ``x.device`` is read here
        mask: boolean padding mask for [PATCH] tokens, indexed as a 3-D
            tensor (True marks padded positions)
        num_pos_feats: the size of channel dimension; half of it is used for
            the row encoding and half for the column encoding
        temperature: the temperature value
        scale: the normalization scale

    Returns:
        pos: Sinusoidal positional encodings; the last dimension holds the
            row encoding followed by the column encoding
    """
    half_feats = num_pos_feats // 2
    eps = 1e-6

    valid = ~mask
    row_pos = valid.cumsum(1, dtype=torch.float32)
    col_pos = valid.cumsum(2, dtype=torch.float32)
    row_pos = row_pos / (row_pos[:, -1:, :] + eps) * scale
    col_pos = col_pos / (col_pos[:, :, -1:] + eps) * scale

    # consecutive channel pairs share one frequency
    freq = torch.arange(half_feats, dtype=torch.float32, device=x.device)
    freq = temperature**(2 * (freq // 2) / half_feats)

    def _interleave_sin_cos(embed):
        # sin on even channels, cos on odd channels, interleaved pairwise
        angles = embed[:, :, :, None] / freq
        pair = (angles[:, :, :, 0::2].sin(), angles[:, :, :, 1::2].cos())
        return torch.stack(pair, dim=4).flatten(3)

    enc_x = _interleave_sin_cos(col_pos)
    enc_y = _interleave_sin_cos(row_pos)
    return torch.cat((enc_y, enc_x), dim=3)
def window_partition(x, window_size):
    """Split a feature map into non-overlapping square windows.

    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    rows, cols = H // window_size, W // window_size
    # expose the window grid, then bring the two grid axes next to each other
    grid = x.view(B, rows, window_size, cols, window_size, C)
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(-1, window_size, window_size, C)
def window_reverse(windows, window_size, H, W):
    """Reassemble windows produced by window partitioning into a feature map.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    rows, cols = H // window_size, W // window_size
    # Exact integer arithmetic: the batch size is the number of windows
    # divided by the windows per image (no float division / truncation).
    B = windows.shape[0] // (rows * cols)
    x = windows.view(B, rows, cols, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
class ReconfiguredAttentionModule(nn.Module):
    """ Window based multi-head self attention (W-MSA) with relative position bias, extended with RAM.

    It supports both shifted and non-shifted windows.

    IMPORTANT: the original attention module of Swin is replaced with the
    reconfigured attention module (RAM). The constructor arguments are the
    same as in Swin; only the forward function differs.
    See https://arxiv.org/pdf/2110.03921.pdf (Section 3).

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self,
                 dim,
                 window_size,
                 num_heads,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        self.scale = qk_scale or (dim // num_heads)**-0.5

        win_h, win_w = window_size[0], window_size[1]

        # learnable table with one bias per head for every relative offset
        num_offsets = (2 * win_h - 1) * (2 * win_w - 1)
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(num_offsets, num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # pair-wise relative position index for each token inside the window
        axes = [torch.arange(win_h), torch.arange(win_w)]
        grid = torch.stack(torch.meshgrid(axes))  # 2, Wh, Ww
        flat = torch.flatten(grid, 1)  # 2, Wh*Ww
        offsets = flat[:, :, None] - flat[:, None, :]  # 2, Wh*Ww, Wh*Ww
        offsets = offsets.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        offsets[:, :, 0] += win_h - 1  # shift to start from 0
        offsets[:, :, 1] += win_w - 1
        offsets[:, :, 0] *= 2 * win_w - 1  # row-major flattening of (dh, dw)
        self.register_buffer('relative_position_index',
                             offsets.sum(-1))  # Wh*Ww, Wh*Ww

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self,
                x,
                det,
                mask=None,
                cross_attn=False,
                cross_attn_mask=None):
        """ Forward function.

        RAM receives [PATCH] and [DET] tokens and returns their calibrated ones.

        Args:
            x: [PATCH] tokens of shape (B, H, W, C). When ``cross_attn`` is
                set, a pair ``(shifted_patches, cross_patches)`` where the
                second entry supplies keys/values for [DET x PATCH] attention.
            det: [DET] tokens, concatenated after the patches along dim 1
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None -> mask for shifted window attention
            cross_attn: whether to use cross-attention [det x patch] (for selective cross-attention)
            cross_attn_mask: mask added to the [DET] attention logits, or None

        Returns:
            patch_x: the calibrated [PATCH] tokens, (B, H, W, C)
            det_x: the calibrated [DET] tokens
        """
        assert self.window_size[0] == self.window_size[1]
        ws = self.window_size[0]
        tokens_per_window = ws * ws
        heads = self.num_heads

        # one shared qkv projection over all tokens, before window partitioning
        if cross_attn:
            shifted, reference = x[0], x[1]
            B, H, W, C = shifted.shape
            N = H * W
            _, ori_H, ori_W, _ = reference.shape
            ori_N = ori_H * ori_W
            joint = torch.cat([
                shifted.view(B, N, C),
                reference.view(B, ori_N, C), det
            ],
                              dim=1)
            full_qkv = self.qkv(joint)
            patch_qkv = full_qkv[:, :N, :]
            cross_patch_qkv = full_qkv[:, N:N + ori_N, :]
            det_qkv = full_qkv[:, N + ori_N:, :]
        else:
            B, H, W, C = x.shape
            N = H * W
            full_qkv = self.qkv(torch.cat([x.view(B, N, C), det], dim=1))
            patch_qkv = full_qkv[:, :N, :]
            det_qkv = full_qkv[:, N:, :]

        head_dim = C // heads

        # window partitioning for [PATCH] tokens
        patch_qkv = window_partition(
            patch_qkv.view(B, H, W, -1),
            ws)  # nW*B, window_size, window_size, 3*C
        B_ = patch_qkv.shape[0]
        patch_qkv = patch_qkv.reshape(B_, tokens_per_window, 3, heads,
                                      head_dim).permute(2, 0, 3, 1, 4)
        patch_q, patch_k, patch_v = patch_qkv[0], patch_qkv[1], patch_qkv[2]

        # [PATCH x PATCH] self-attention inside each window
        patch_attn = (patch_q * self.scale) @ patch_k.transpose(-2, -1)

        # relative position bias for [patch x patch] self-attention
        bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)]
        bias = bias.view(tokens_per_window, tokens_per_window,
                         -1)  # Wh*Ww, Wh*Ww, nH
        bias = bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        patch_attn = patch_attn + bias.unsqueeze(0)

        # shifted windows need the per-window mask
        if mask is not None:
            nW = mask.shape[0]
            per_window = patch_attn.view(B_ // nW, nW, heads,
                                         tokens_per_window, tokens_per_window)
            patch_attn = per_window + mask.unsqueeze(1).unsqueeze(0)
            patch_attn = patch_attn.view(-1, heads, tokens_per_window,
                                         tokens_per_window)

        patch_attn = self.attn_drop(self.softmax(patch_attn))
        patch_x = (patch_attn @ patch_v).transpose(1, 2).reshape(
            B_, ws, ws, C)

        # qkv for [DET] tokens
        det_qkv = det_qkv.view(B, -1, 3, heads, head_dim).permute(
            2, 0, 3, 1, 4)
        det_q, det_k, det_v = det_qkv[0], det_qkv[1], det_qkv[2]

        if cross_attn:
            # restore the spatial form of the cross [PATCH] tokens and keep
            # only their key/value slices (index 1: on the qkv axis)
            cross_patch_qkv = cross_patch_qkv.view(B, ori_H, ori_W, 3, heads,
                                                   head_dim)
            patch_kv = cross_patch_qkv[:, :, :, 1:, :, :].permute(
                3, 0, 4, 1, 2, 5).contiguous()
            patch_kv = patch_kv.view(2, B, heads, ori_H * ori_W, -1)
            # bind key and value of [PATCH] and [DET] tokens for
            # [DET x [PATCH, DET]] attention
            det_k = torch.cat([patch_kv[0], det_k], dim=2)
            det_v = torch.cat([patch_kv[1], det_v], dim=2)

        # [DET x DET] self-attention or bound [DET x [PATCH, DET]] attention
        det_attn = (det_q * self.scale) @ det_k.transpose(-2, -1)
        if cross_attn_mask is not None:
            det_attn = det_attn + cross_attn_mask
        det_attn = self.attn_drop(self.softmax(det_attn))
        det_x = (det_attn @ det_v).transpose(1, 2).reshape(B, -1, C)

        # undo the window partitioning of the [PATCH] outputs
        patch_x = window_reverse(patch_x, ws, H, W)

        # shared output projection over [PATCH] and [DET] tokens
        merged = torch.cat([patch_x.view(B, H * W, C), det_x], dim=1)
        merged = self.proj_drop(self.proj(merged))

        # split back into [PATCH] and [DET] tokens
        return merged[:, :H * W, :].view(B, H, W, C), merged[:, H * W:, :]
class SwinTransformerBlock(nn.Module):
    """ Swin Transformer Block operating on bound [PATCH, DET] tokens.

    ``forward`` reads ``self.H`` / ``self.W`` (assigned by ``BasicLayer.forward``
    before each call) as well as ``self.det_token_num`` and
    ``self.det_pos_linear`` (attached by ``SwinTransformer.finetune_det``).

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self,
                 dim,
                 num_heads,
                 window_size=7,
                 shift_size=0,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size'

        self.norm1 = norm_layer(dim)
        self.attn = ReconfiguredAttentionModule(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop)

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop)

        # spatial resolution of the incoming patches; set before each forward
        self.H = None
        self.W = None

    def forward(self, x, mask_matrix, pos, cross_attn, cross_attn_mask):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W + DET, C). i.e., bound [PATCH, DET] tokens
            mask_matrix: Attention mask for cyclic shift (used only when shift_size > 0).
            pos: (patch_pos, det_pos); patch_pos is read only when cross_attn is set
            cross_attn: whether to use cross attn [det x [det + patch]]
            cross_attn_mask: attention mask for cross-attention

        Returns:
            x: calibrated & bound [PATCH, DET] tokens
        """
        B, L, C = x.shape
        H, W = self.H, self.W
        num_patches = H * W
        assert L == num_patches + self.det_token_num, 'input feature has wrong size'

        shortcut = x
        normed = self.norm1(x)
        det = normed[:, num_patches:, :]
        patches = normed[:, :num_patches, :].view(B, H, W, C)
        # unpadded, unshifted patches: source of keys/values for cross-attention
        orig_x = patches

        # pad feature maps (right / bottom) to multiples of window size
        ws = self.window_size
        pad_r = (ws - W % ws) % ws
        pad_b = (ws - H % ws) % ws
        patches = F.pad(patches, (0, 0, 0, pad_r, 0, pad_b))

        # project det positional encodings to the channel size of this layer
        patch_pos, det_pos = pos
        det_pos = self.det_pos_linear(det_pos)

        # cyclic shift
        use_shift = self.shift_size > 0
        if use_shift:
            offset = self.shift_size
            shifted_x = torch.roll(
                patches, shifts=(-offset, -offset), dims=(1, 2))
            attn_mask = mask_matrix
        else:
            shifted_x = patches
            attn_mask = None

        # det token + learnable pos encoding (with or without cross-attention)
        det = det + det_pos
        if cross_attn:
            # patch token (for cross-attention) + sinusoidal pos encoding
            attn_input = (shifted_x, orig_x + patch_pos)
        else:
            # only [PATCH] and [DET] self-attention are performed
            attn_input = shifted_x

        # W-MSA/SW-MSA
        shifted_x, det = self.attn(
            attn_input,
            mask=attn_mask,
            det=det,
            cross_attn=cross_attn,
            cross_attn_mask=cross_attn_mask)

        # reverse cyclic shift
        if use_shift:
            offset = self.shift_size
            x = torch.roll(shifted_x, shifts=(offset, offset), dims=(1, 2))
        else:
            x = shifted_x

        # crop away the padding added above
        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = torch.cat([x.view(B, H * W, C), det], dim=1)

        # residual connections around attention and FFN
        x = shortcut + self.drop_path(x)
        return x + self.drop_path(self.mlp(self.norm2(x)))
class PatchMerging(nn.Module):
    """ Patch Merging Layer for bound [PATCH, DET] tokens.

    ``forward`` reads ``self.det_token_num``, which is not set in ``__init__``
    and must be attached to the module beforehand.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        expand (bool): if True the output has 2*dim channels, otherwise 256.
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm, expand=True):
        super().__init__()
        self.dim = dim

        out_dim = 2 * dim if expand else 256
        self.reduction = nn.Linear(4 * dim, out_dim, bias=False)
        self.norm = norm_layer(4 * dim)

        # added for detection token [please ignore, not used for training]
        # these modules are never called in forward
        self.expansion = nn.Linear(dim, out_dim, bias=False)
        self.norm2 = norm_layer(dim)

    def forward(self, x, H, W):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W + DET, C), i.e., bound [PATCH, DET] tokens
            H, W: Spatial resolution of the input feature.

        Returns:
            x: merged [PATCH, DET] tokens;
                only [PATCH] tokens are reduced in spatial dim, while the
                number of [DET] tokens stays fixed
        """
        B, L, C = x.shape
        num_patches = H * W
        assert L == num_patches + self.det_token_num, 'input feature has wrong size'

        det = x[:, num_patches:, :]
        grid = x[:, :num_patches, :].view(B, H, W, C)

        # pad odd heights / widths by one row / column
        if H % 2 == 1 or W % 2 == 1:
            grid = F.pad(grid, (0, 0, 0, W % 2, 0, H % 2))

        # gather the four members of every 2x2 neighbourhood in the order
        # (row 0, col 0), (row 1, col 0), (row 0, col 1), (row 1, col 1)
        neighbours = [
            grid[:, row::2, col::2, :] for col in (0, 1) for row in (0, 1)
        ]  # each: B H/2 W/2 C
        merged = torch.cat(neighbours, -1).view(B, -1, 4 * C)  # B H/2*W/2 4*C

        # DET tokens are simply repeated to reach the same channel size
        tokens = torch.cat([merged, det.repeat(1, 1, 4)], dim=1)
        return self.reduction(self.norm(tokens))
class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    ``forward`` reads ``self.det_token_num`` when cross-attention is enabled;
    that attribute is attached by ``SwinTransformer.finetune_det``.

    Args:
        dim (int): Number of feature channels
        depth (int): Depths of this stage.
        num_heads (int): Number of attention head.
        window_size (int): Local window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float] | tuple[float], optional): Stochastic depth rate
            (one value per block when a sequence is given). Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        last (bool | None): the downsample layer is built with ``expand=(not last)``.
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self,
                 dim,
                 depth,
                 num_heads,
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 last=False,
                 use_checkpoint=False):
        super().__init__()
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.depth = depth
        self.dim = dim
        self.use_checkpoint = use_checkpoint

        per_block_drop_path = isinstance(drop_path, (list, tuple))

        # build blocks: even blocks use regular windows, odd blocks shifted ones
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if per_block_drop_path else drop_path,
                norm_layer=norm_layer) for i in range(depth)
        ])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(
                dim=dim, norm_layer=norm_layer, expand=(not last))
        else:
            self.downsample = None

    def forward(self, x, H, W, det_pos, input_mask, cross_attn=False):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W + DET, C).
            H, W: Spatial resolution of the input feature.
            det_pos: pos encoding for det token
            input_mask: padding mask for inputs (read only when cross_attn is set)
            cross_attn: whether to use cross attn [det x [det + patch]]

        Returns:
            (x, H, W, x_down, Wh, Ww): the stage output with its resolution,
            followed by the downsampled output and its resolution. Without a
            downsample layer the second triple repeats the first.
        """
        B = x.shape[0]

        # calculate attention mask for SW-MSA on the padded resolution
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        # label each of the 3x3 regions with a distinct id
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        # mask for cyclic shift: token pairs from different regions get -100
        mask_windows = window_partition(img_mask, self.window_size)
        mask_windows = mask_windows.view(-1,
                                         self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0,
                                          float(-100.0)).masked_fill(
                                              attn_mask == 0, float(0.0))

        # compute sinusoidal pos encoding and cross-attn mask once per stage
        # to avoid redundant computation inside the blocks
        if cross_attn:
            _H, _W = input_mask.shape[1:]
            if not (_H == H and _W == W):
                input_mask = F.interpolate(
                    input_mask[None].float(), size=(H, W)).to(torch.bool)[0]

            # sinusoidal pos encoding for [PATCH] tokens used in cross-attention
            patch_pos = masked_sin_pos_encoding(x, input_mask, self.dim)

            # attention padding mask due to the zero padding in inputs
            # the zero (padded) area is masked by 1.0 in 'input_mask'
            cross_attn_mask = input_mask.float()
            cross_attn_mask = cross_attn_mask.masked_fill(cross_attn_mask != 0.0, float(-100.0)). \
                masked_fill(cross_attn_mask == 0.0, float(0.0))

            # extend with zeros for the detection tokens (required to process
            # the bound [PATCH, DET] attention)
            cross_attn_mask = cross_attn_mask.view(
                B, H * W).unsqueeze(1).unsqueeze(2)
            cross_attn_mask = F.pad(
                cross_attn_mask, (0, self.det_token_num), value=0)
        else:
            patch_pos = None
            cross_attn_mask = None

        # zip pos encodings; patch_pos is None when cross-attention is off
        pos = (patch_pos, det_pos)
        use_cross_attn = bool(cross_attn)

        for blk in self.blocks:
            blk.H, blk.W = H, W
            if self.use_checkpoint:
                # Arguments are passed positionally (in the order of the
                # block's forward signature) because reentrant checkpointing
                # rejects keyword arguments for the wrapped function.
                x = checkpoint.checkpoint(blk, x, attn_mask, pos,
                                          use_cross_attn, cross_attn_mask)
            else:
                x = blk(
                    x,
                    attn_mask,
                    pos=pos,
                    cross_attn=use_cross_attn,
                    cross_attn_mask=cross_attn_mask)

        # reduce the number of patch tokens while keeping fixed-scale det
        # tokens; the channel size is changed by the downsample layer
        if self.downsample is not None:
            x_down = self.downsample(x, H, W)
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x, H, W, x_down, Wh, Ww
        return x, H, W, x, H, W
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding

    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self,
                 patch_size=4,
                 in_chans=3,
                 embed_dim=96,
                 norm_layer=None):
        super().__init__()
        self.patch_size = to_2tuple(patch_size)
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        # non-overlapping patches: kernel size equals stride
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x):
        """Forward function.

        Pads the right / bottom of ``x`` up to a multiple of the patch size,
        projects it with the strided convolution and, when a norm layer is
        configured, normalises over the channel dimension.
        """
        _, _, H, W = x.size()
        patch_h, patch_w = self.patch_size[0], self.patch_size[1]

        # padding
        if W % patch_w != 0:
            x = F.pad(x, (0, patch_w - W % patch_w))
        if H % patch_h != 0:
            x = F.pad(x, (0, 0, 0, patch_h - H % patch_h))

        x = self.proj(x)  # B C Wh Ww
        if self.norm is None:
            return x

        # the norm layer works on the last dim: go to (B, Wh*Ww, C) and back
        Wh, Ww = x.size(2), x.size(3)
        tokens = self.norm(x.flatten(2).transpose(1, 2))
        return tokens.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
class SwinTransformer(nn.Module):
    """ Swin Transformer backbone.

    A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
    https://arxiv.org/pdf/2103.14030

    ``finetune_det`` must be called before ``forward``: it creates the [DET]
    tokens and attaches the detection-specific attributes that ``forward``
    and the sub-modules read.

    Args:
        pretrain_img_size (int): Input image size for training the pretrained model,
            used in absolute position embedding. Default 224.
        patch_size (int | tuple(int)): Patch size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        depths (Sequence[int] | None): Depths of each Swin Transformer stage.
            ``None`` selects [2, 2, 6, 2].
        num_heads (Sequence[int] | None): Number of attention head of each stage.
            ``None`` selects [3, 6, 12, 24].
        window_size (int): Window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
        drop_rate (float): Dropout rate.
        attn_drop_rate (float): Attention dropout rate. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, create an absolute position embedding parameter. Default: False.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True.
        out_indices (Sequence[int] | None): Output from which stages. ``None``
            selects [1, 2, 3]. Not used in the current version, please ignore.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any args.
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(
            self,
            pretrain_img_size=224,
            patch_size=4,
            in_chans=3,
            embed_dim=96,
            depths=None,
            num_heads=None,
            window_size=7,
            mlp_ratio=4.,
            qkv_bias=True,
            qk_scale=None,
            drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.2,
            norm_layer=nn.LayerNorm,
            ape=False,
            patch_norm=True,
            out_indices=None,  # not used in the current version, please ignore.
            frozen_stages=-1,
            use_checkpoint=False):
        super().__init__()

        # None sentinels avoid sharing one mutable default list across instances
        depths = [2, 2, 6, 2] if depths is None else depths
        num_heads = [3, 6, 12, 24] if num_heads is None else num_heads
        out_indices = [1, 2, 3] if out_indices is None else out_indices

        self.pretrain_img_size = pretrain_img_size
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # absolute position embedding
        if self.ape:
            pretrain_img_size = to_2tuple(pretrain_img_size)
            patch_size = to_2tuple(patch_size)
            patches_resolution = [
                pretrain_img_size[0] // patch_size[0],
                pretrain_img_size[1] // patch_size[1]
            ]

            self.absolute_pos_embed = nn.Parameter(
                torch.zeros(1, embed_dim, patches_resolution[0],
                            patches_resolution[1]))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule: linearly increasing over all blocks
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2**i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                # modified by ViDT: every stage (including the last one)
                # gets a patch merging layer; the last one does not expand
                downsample=PatchMerging if
                (i_layer < self.num_layers) else None,
                last=None if (i_layer < self.num_layers - 1) else True,
                use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
        self.num_features = num_features

        # add a norm layer for each output
        # Not used in the current version -> please ignore.
        # These modules are kept so that pre-trained weights can be loaded.
        for i_layer in out_indices:
            layer = norm_layer(num_features[i_layer])
            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)

        self._freeze_stages()

    def _freeze_stages(self):
        """Put the stages selected by ``frozen_stages`` in eval mode and stop their gradients."""
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        if self.frozen_stages >= 1 and self.ape:
            self.absolute_pos_embed.requires_grad = False

        if self.frozen_stages >= 2:
            self.pos_drop.eval()
            for i in range(0, self.frozen_stages - 1):
                m = self.layers[i]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the names of the parameters to exclude from weight decay."""
        return {'det_pos_embed', 'det_token'}

    def finetune_det(self,
                     method,
                     det_token_num=100,
                     pos_dim=256,
                     cross_indices=None):
        """ A function to add necessary (learnable) variables to Swin Transformer for object detection

        Args:
            method: vidt or vidt_wo_neck
            det_token_num: the number of object to detect, i.e., number of object queries
            pos_dim: the channel dimension of positional encodings for [DET] and [PATCH] tokens
            cross_indices: the indices where to use the [DET X PATCH] cross-attention
                there are four possible stages in [0, 1, 2, 3]. 3 indicates Stage 4 in the ViDT paper.
                ``None`` selects [3].
        """
        # None sentinel avoids a shared mutable default list
        cross_indices = [3] if cross_indices is None else cross_indices

        # which method?
        self.method = method

        # how many object we detect?
        self.det_token_num = det_token_num
        self.det_token = nn.Parameter(
            torch.zeros(1, det_token_num, self.num_features[0]))
        self.det_token = trunc_normal_(self.det_token, std=.02)

        # dim size of pos encoding
        self.pos_dim = pos_dim

        # learnable positional encoding for detection tokens
        det_pos_embed = torch.zeros(1, det_token_num, pos_dim)
        det_pos_embed = trunc_normal_(det_pos_embed, std=.02)
        self.det_pos_embed = torch.nn.Parameter(det_pos_embed)

        # info for detection: channel sizes of the multi-scale outputs
        self.num_channels = [
            self.num_features[i + 1]
            for i in range(len(self.num_features) - 1)
        ]
        if method == 'vidt':
            self.num_channels.append(
                self.pos_dim)  # default: 256 (same to the default pos_dim)
        self.cross_indices = cross_indices
        # divisor to reduce the spatial size of the mask
        self.mask_divisor = 2**(len(self.layers) - len(self.cross_indices))

        # attach the token count to every stage / merging layer / block and
        # give each block a projection for the det pos encoding
        for layer in self.layers:
            layer.det_token_num = det_token_num
            if layer.downsample is not None:
                layer.downsample.det_token_num = det_token_num
            for block in layer.blocks:
                block.det_token_num = det_token_num
                block.det_pos_linear = nn.Linear(pos_dim, block.dim)

        # neck-free model do not require downsampling at the last stage.
        if method == 'vidt_wo_neck':
            self.layers[-1].downsample = None

    def forward(self, x, mask):
        """ Forward function.

        Args:
            x: input rgb images
            mask: input padding masks [0: rgb values, 1: padded values]

        Returns:
            features_0 .. features_3: multi-scale [PATCH] tokens (four scales are used)
                these tokens are the first input of the neck decoder
            det_tgt: final [DET] tokens obtained at the last stage
                this tokens are the second input of the neck decoder
            det_pos: the learnable pos encoding for [DET] tokens.
                these encodings are used to generate reference points in deformable attention
        """
        B = x.shape[0]

        # patch embedding
        x = self.patch_embed(x)

        Wh, Ww = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)
        x = self.pos_drop(x)

        # expand det_token for all examples in the batch
        det_token = self.det_token.expand(B, -1, -1)

        # det pos encoding -> will be projected in each block
        det_pos = self.det_pos_embed

        # prepare a mask for cross attention
        mask = F.interpolate(
            mask[None].float(),
            size=(Wh // self.mask_divisor,
                  Ww // self.mask_divisor)).to(torch.bool)[0]

        patch_outs = []
        for stage in range(self.num_layers):
            layer = self.layers[stage]

            # whether to use cross-attention
            cross_attn = stage in self.cross_indices

            # concat input
            x = torch.cat([x, det_token], dim=1)

            # inference
            x_out, H, W, x, Wh, Ww = layer(
                x,
                Wh,
                Ww,
                # additional input for VIDT
                input_mask=mask,
                det_pos=det_pos,
                cross_attn=cross_attn)

            # split the (possibly downsampled) output for the next stage
            x, det_token = (x[:, :-self.det_token_num, :],
                            x[:, -self.det_token_num:, :])

            # Aggregate intermediate outputs
            if stage > 0:
                patch_out = x_out[:, :-self.det_token_num, :].view(
                    B, H, W, -1).permute(0, 3, 1, 2)
                patch_outs.append(patch_out)

        # patch token reduced from last stage output
        patch_outs.append(x.view(B, Wh, Ww, -1).permute(0, 3, 1, 2))

        # det token
        det_tgt = x_out[:, -self.det_token_num:, :].permute(0, 2, 1)

        # det token pos encoding
        det_pos = det_pos.permute(0, 2, 1)

        # exactly four feature maps are expected here
        features_0, features_1, features_2, features_3 = patch_outs
        return features_0, features_1, features_2, features_3, det_tgt, det_pos

    def train(self, mode=True):
        """Convert the model into training mode while keeping the frozen stages frozen."""
        super(SwinTransformer, self).train(mode)
        self._freeze_stages()

    # not working in the current version: it relies on `flops()` methods and
    # on `patches_resolution` / `num_classes` attributes that this file does
    # not define
    def flops(self):
        flops = 0
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        flops += self.num_features * self.patches_resolution[
            0] * self.patches_resolution[1] // (2**self.num_layers)
        flops += self.num_features * self.num_classes
        return flops
|