backbone.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061
  1. # The implementation here is modified based on timm,
  2. # originally Apache 2.0 License and publicly available at
  3. # https://github.com/naver-ai/vidt/blob/vidt-plus/methods/swin_w_ram.py
  4. import math
  5. import os
  6. import numpy as np
  7. import torch
  8. import torch.nn as nn
  9. import torch.nn.functional as F
  10. import torch.utils.checkpoint as checkpoint
  11. from timm.models.layers import DropPath, to_2tuple, trunc_normal_
  12. class Mlp(nn.Module):
  13. """ Multilayer perceptron."""
  14. def __init__(self,
  15. in_features,
  16. hidden_features=None,
  17. out_features=None,
  18. act_layer=nn.GELU,
  19. drop=0.):
  20. super().__init__()
  21. out_features = out_features or in_features
  22. hidden_features = hidden_features or in_features
  23. self.fc1 = nn.Linear(in_features, hidden_features)
  24. self.act = act_layer()
  25. self.fc2 = nn.Linear(hidden_features, out_features)
  26. self.drop = nn.Dropout(drop)
  27. def forward(self, x):
  28. x = self.fc1(x)
  29. x = self.act(x)
  30. x = self.drop(x)
  31. x = self.fc2(x)
  32. x = self.drop(x)
  33. return x
  34. def masked_sin_pos_encoding(x,
  35. mask,
  36. num_pos_feats,
  37. temperature=10000,
  38. scale=2 * math.pi):
  39. """ Masked Sinusoidal Positional Encoding
  40. Args:
  41. x: [PATCH] tokens
  42. mask: the padding mask for [PATCH] tokens
  43. num_pos_feats: the size of channel dimension
  44. temperature: the temperature value
  45. scale: the normalization scale
  46. Returns:
  47. pos: Sinusoidal positional encodings
  48. """
  49. num_pos_feats = num_pos_feats // 2
  50. not_mask = ~mask
  51. y_embed = not_mask.cumsum(1, dtype=torch.float32)
  52. x_embed = not_mask.cumsum(2, dtype=torch.float32)
  53. eps = 1e-6
  54. y_embed = y_embed / (y_embed[:, -1:, :] + eps) * scale
  55. x_embed = x_embed / (x_embed[:, :, -1:] + eps) * scale
  56. dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=x.device)
  57. dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats)
  58. pos_x = x_embed[:, :, :, None] / dim_t
  59. pos_y = y_embed[:, :, :, None] / dim_t
  60. pos_x = torch.stack(
  61. (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
  62. dim=4).flatten(3)
  63. pos_y = torch.stack(
  64. (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
  65. dim=4).flatten(3)
  66. pos = torch.cat((pos_y, pos_x), dim=3)
  67. return pos
  68. def window_partition(x, window_size):
  69. """
  70. Args:
  71. x: (B, H, W, C)
  72. window_size (int): window size
  73. Returns:
  74. windows: (num_windows*B, window_size, window_size, C)
  75. """
  76. B, H, W, C = x.shape
  77. x = x.view(B, H // window_size, window_size, W // window_size, window_size,
  78. C)
  79. windows = x.permute(0, 1, 3, 2, 4,
  80. 5).contiguous().view(-1, window_size, window_size, C)
  81. return windows
  82. def window_reverse(windows, window_size, H, W):
  83. """
  84. Args:
  85. windows: (num_windows*B, window_size, window_size, C)
  86. window_size (int): Window size
  87. H (int): Height of image
  88. W (int): Width of image
  89. Returns:
  90. x: (B, H, W, C)
  91. """
  92. B = int(windows.shape[0] / (H * W / window_size / window_size))
  93. x = windows.view(B, H // window_size, W // window_size, window_size,
  94. window_size, -1)
  95. x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
  96. return x
class ReconfiguredAttentionModule(nn.Module):
    """ Window based multi-head self attention (W-MSA) with relative position bias,
    extended with the Reconfigured Attention Module (RAM) of ViDT.

    Supports both shifted and non-shifted windows. The constructor and its
    parameters match the original Swin window attention; only ``forward`` is
    reconfigured so that [PATCH] and [DET] tokens share one qkv / output
    projection while attending differently:
      * [PATCH x PATCH]: windowed self-attention with relative position bias;
      * [DET x DET]: global self-attention among [DET] tokens, or, when
        ``cross_attn`` is set, [DET x [PATCH, DET]] attention.
    See https://arxiv.org/pdf/2110.03921.pdf (Section 3).

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self,
                 dim,
                 window_size,
                 num_heads,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # note: `or` means a qk_scale of 0 also falls back to the default
        self.scale = qk_scale or head_dim**-0.5
        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
                        num_heads))  # 2*Wh-1 * 2*Ww-1, nH
        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :,
                                         None] - coords_flatten[:,
                                                                None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(
            1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :,
                        0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        # turn the 2-D (dy, dx) offset into a flat row-major index into the table
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer('relative_position_index',
                             relative_position_index)
        # one shared projection produces q, k, v for [PATCH] and [DET] tokens alike
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self,
                x,
                det,
                mask=None,
                cross_attn=False,
                cross_attn_mask=None):
        """ Forward function.

        RAM receives [PATCH] and [DET] tokens and returns calibrated versions of both.

        Args:
            x: [PATCH] tokens of shape (B, H, W, C) when ``cross_attn`` is False.
                When ``cross_attn`` is True, a pair ``(patch_tokens, cross_tokens)``:
                ``patch_tokens`` (B, H, W, C) go through windowed attention, and
                ``cross_tokens`` (B, ori_H, ori_W, C) supply the keys/values of the
                [DET x PATCH] cross-attention.
            det: [DET] tokens, concatenated with the patches along dim 1, so
                (B, num_det, C).
            mask: additive mask of shape (num_windows, Wh*Ww, Wh*Ww) for shifted
                window attention, or None. (BasicLayer in this file builds it
                with 0 / -100 values.)
            cross_attn: whether [DET] tokens attend to [PATCH, DET] (selective
                cross-attention) instead of only [DET].
            cross_attn_mask: additive mask broadcast onto the [DET] attention
                logits, or None.

        Returns:
            patch_x: the calibrated [PATCH] tokens, (B, H, W, C)
            det_x: the calibrated [DET] tokens, (B, num_det, C)
        """
        # only square windows are supported by this forward
        assert self.window_size[0] == self.window_size[1]
        window_size = self.window_size[0]
        local_map_size = window_size * window_size
        # projection before window partitioning: one qkv matmul over all tokens
        if not cross_attn:
            B, H, W, C = x.shape
            N = H * W
            x = x.view(B, N, C)
            x = torch.cat([x, det], dim=1)
            full_qkv = self.qkv(x)
            patch_qkv, det_qkv = full_qkv[:, :N, :], full_qkv[:, N:, :]
        else:
            B, H, W, C = x[0].shape
            N = H * W
            _, ori_H, ori_W, _ = x[1].shape
            ori_N = ori_H * ori_W
            shifted_x = x[0].view(B, N, C)
            cross_x = x[1].view(B, ori_N, C)
            x = torch.cat([shifted_x, cross_x, det], dim=1)
            full_qkv = self.qkv(x)
            # split back into [windowed PATCH | cross PATCH | DET] segments
            patch_qkv, cross_patch_qkv, det_qkv = \
                full_qkv[:, :N, :], full_qkv[:, N:N + ori_N, :], full_qkv[:, N + ori_N:, :]
        patch_qkv = patch_qkv.view(B, H, W, -1)
        # window partitioning for [PATCH] tokens
        patch_qkv = window_partition(
            patch_qkv, window_size)  # nW*B, window_size, window_size, C
        B_ = patch_qkv.shape[0]
        patch_qkv = patch_qkv.reshape(B_, window_size * window_size, 3,
                                      self.num_heads, C // self.num_heads)
        # -> 3, nW*B, nH, Wh*Ww, head_dim
        _patch_qkv = patch_qkv.permute(2, 0, 3, 1, 4)
        patch_q, patch_k, patch_v = _patch_qkv[0], _patch_qkv[1], _patch_qkv[2]
        # [PATCH x PATCH] self-attention using window partitions
        patch_q = patch_q * self.scale
        patch_attn = (patch_q @ patch_k.transpose(-2, -1))
        # add relative pos bias for [patch x patch] self-attention
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1],
                self.window_size[0] * self.window_size[1],
                -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(
            2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        patch_attn = patch_attn + relative_position_bias.unsqueeze(0)
        # if shifted window is used, it needs to apply the mask
        if mask is not None:
            nW = mask.shape[0]
            # expose the window axis so the per-window mask broadcasts over batch and heads
            tmp0 = patch_attn.view(B_ // nW, nW, self.num_heads,
                                   local_map_size, local_map_size)
            tmp1 = mask.unsqueeze(1).unsqueeze(0)
            patch_attn = tmp0 + tmp1
            patch_attn = patch_attn.view(-1, self.num_heads, local_map_size,
                                         local_map_size)
        patch_attn = self.softmax(patch_attn)
        patch_attn = self.attn_drop(patch_attn)
        patch_x = (patch_attn @ patch_v).transpose(1, 2).reshape(
            B_, window_size, window_size, C)
        # extract qkv for [DET] tokens: each -> B, nH, num_det, head_dim
        det_qkv = det_qkv.view(B, -1, 3, self.num_heads, C // self.num_heads)
        det_qkv = det_qkv.permute(2, 0, 3, 1, 4)
        det_q, det_k, det_v = det_qkv[0], det_qkv[1], det_qkv[2]
        # if cross-attention is activated
        if cross_attn:
            # reconstruct the spatial form of [PATCH] tokens for global [DET x PATCH] attention
            cross_patch_qkv = cross_patch_qkv.view(B, ori_H, ori_W, 3,
                                                   self.num_heads,
                                                   C // self.num_heads)
            # index 1: on the qkv axis keeps only k and v -> 2, B, nH, ori_H, ori_W, head_dim
            patch_kv = cross_patch_qkv[:, :, :,
                                       1:, :, :].permute(3, 0, 4, 1, 2,
                                                         5).contiguous()
            patch_kv = patch_kv.view(2, B, self.num_heads, ori_H * ori_W, -1)
            # extract "key and value" of [PATCH] tokens for cross-attention
            cross_patch_k, cross_patch_v = patch_kv[0], patch_kv[1]
            # bind key and value of [PATCH] and [DET] tokens for [DET X [PATCH, DET]] attention
            det_k, det_v = torch.cat([cross_patch_k, det_k],
                                     dim=2), torch.cat([cross_patch_v, det_v],
                                                       dim=2)
        # [DET x DET] self-attention or binded [DET x [PATCH, DET]] attention
        det_q = det_q * self.scale
        det_attn = (det_q @ det_k.transpose(-2, -1))
        # apply cross-attention mask if available
        if cross_attn_mask is not None:
            det_attn = det_attn + cross_attn_mask
        det_attn = self.softmax(det_attn)
        det_attn = self.attn_drop(det_attn)
        det_x = (det_attn @ det_v).transpose(1, 2).reshape(B, -1, C)
        # reverse window for [PATCH] tokens <- the output of [PATCH x PATCH] self attention
        patch_x = window_reverse(patch_x, window_size, H, W)
        # shared output projection for outputs from multi-head
        x = torch.cat([patch_x.view(B, H * W, C), det_x], dim=1)
        x = self.proj(x)
        x = self.proj_drop(x)
        # decompose the projected tokens back into [PATCH] and [DET] tokens
        patch_x = x[:, :H * W, :].view(B, H, W, C)
        det_x = x[:, H * W:, :]
        return patch_x, det_x
  269. class SwinTransformerBlock(nn.Module):
  270. """ Swin Transformer Block.
  271. Args:
  272. dim (int): Number of input channels.
  273. num_heads (int): Number of attention heads.
  274. window_size (int): Window size.
  275. shift_size (int): Shift size for SW-MSA.
  276. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  277. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  278. qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
  279. drop (float, optional): Dropout rate. Default: 0.0
  280. attn_drop (float, optional): Attention dropout rate. Default: 0.0
  281. drop_path (float, optional): Stochastic depth rate. Default: 0.0
  282. act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
  283. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
  284. """
  285. def __init__(self,
  286. dim,
  287. num_heads,
  288. window_size=7,
  289. shift_size=0,
  290. mlp_ratio=4.,
  291. qkv_bias=True,
  292. qk_scale=None,
  293. drop=0.,
  294. attn_drop=0.,
  295. drop_path=0.,
  296. act_layer=nn.GELU,
  297. norm_layer=nn.LayerNorm):
  298. super().__init__()
  299. self.dim = dim
  300. self.num_heads = num_heads
  301. self.window_size = window_size
  302. self.shift_size = shift_size
  303. self.mlp_ratio = mlp_ratio
  304. assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size'
  305. self.norm1 = norm_layer(dim)
  306. self.attn = ReconfiguredAttentionModule(
  307. dim,
  308. window_size=to_2tuple(self.window_size),
  309. num_heads=num_heads,
  310. qkv_bias=qkv_bias,
  311. qk_scale=qk_scale,
  312. attn_drop=attn_drop,
  313. proj_drop=drop)
  314. self.drop_path = DropPath(
  315. drop_path) if drop_path > 0. else nn.Identity()
  316. self.norm2 = norm_layer(dim)
  317. mlp_hidden_dim = int(dim * mlp_ratio)
  318. self.mlp = Mlp(
  319. in_features=dim,
  320. hidden_features=mlp_hidden_dim,
  321. act_layer=act_layer,
  322. drop=drop)
  323. self.H = None
  324. self.W = None
  325. def forward(self, x, mask_matrix, pos, cross_attn, cross_attn_mask):
  326. """ Forward function.
  327. Args:
  328. x: Input feature, tensor size (B, H*W + DET, C). i.e., binded [PATCH, DET] tokens
  329. H, W: Spatial resolution of the input feature.
  330. mask_matrix: Attention mask for cyclic shift.
  331. "additional inputs'
  332. pos: (patch_pos, det_pos)
  333. cross_attn: whether to use cross attn [det x [det + patch]]
  334. cross_attn_mask: attention mask for cross-attention
  335. Returns:
  336. x: calibrated & binded [PATCH, DET] tokens
  337. """
  338. B, L, C = x.shape
  339. H, W = self.H, self.W
  340. assert L == H * W + self.det_token_num, 'input feature has wrong size'
  341. shortcut = x
  342. x = self.norm1(x)
  343. x, det = x[:, :H * W, :], x[:, H * W:, :]
  344. x = x.view(B, H, W, C)
  345. orig_x = x
  346. # pad feature maps to multiples of window size
  347. pad_l = pad_t = 0
  348. pad_r = (self.window_size - W % self.window_size) % self.window_size
  349. pad_b = (self.window_size - H % self.window_size) % self.window_size
  350. x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
  351. _, Hp, Wp, _ = x.shape
  352. # projection for det positional encodings: make the channel size suitable for the current layer
  353. patch_pos, det_pos = pos
  354. det_pos = self.det_pos_linear(det_pos)
  355. # cyclic shift
  356. if self.shift_size > 0:
  357. shifted_x = torch.roll(
  358. x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
  359. attn_mask = mask_matrix
  360. else:
  361. shifted_x = x
  362. attn_mask = None
  363. # prepare cross-attn and add positional encodings
  364. if cross_attn:
  365. # patch token (for cross-attention) + Sinusoidal pos encoding
  366. cross_patch = orig_x + patch_pos
  367. # det token + learnable pos encoding
  368. det = det + det_pos
  369. shifted_x = (shifted_x, cross_patch)
  370. else:
  371. # it cross_attn is deactivated, only [PATCH] and [DET] self-attention are performed
  372. det = det + det_pos
  373. shifted_x = shifted_x
  374. # W-MSA/SW-MSA
  375. shifted_x, det = self.attn(
  376. shifted_x,
  377. mask=attn_mask,
  378. # additional args
  379. det=det,
  380. cross_attn=cross_attn,
  381. cross_attn_mask=cross_attn_mask)
  382. # reverse cyclic shift
  383. if self.shift_size > 0:
  384. x = torch.roll(
  385. shifted_x,
  386. shifts=(self.shift_size, self.shift_size),
  387. dims=(1, 2))
  388. else:
  389. x = shifted_x
  390. if pad_r > 0 or pad_b > 0:
  391. x = x[:, :H, :W, :].contiguous()
  392. x = x.view(B, H * W, C)
  393. x = torch.cat([x, det], dim=1)
  394. # FFN
  395. x = shortcut + self.drop_path(x)
  396. x = x + self.drop_path(self.mlp(self.norm2(x)))
  397. return x
  398. class PatchMerging(nn.Module):
  399. """ Patch Merging Layer
  400. Args:
  401. dim (int): Number of input channels.
  402. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
  403. """
  404. def __init__(self, dim, norm_layer=nn.LayerNorm, expand=True):
  405. super().__init__()
  406. self.dim = dim
  407. # if expand is True, the channel size will be expanded, otherwise, return 256 size of channel
  408. expand_dim = 2 * dim if expand else 256
  409. self.reduction = nn.Linear(4 * dim, expand_dim, bias=False)
  410. self.norm = norm_layer(4 * dim)
  411. # added for detection token [please ignore, not used for training]
  412. # not implemented yet.
  413. self.expansion = nn.Linear(dim, expand_dim, bias=False)
  414. self.norm2 = norm_layer(dim)
  415. def forward(self, x, H, W):
  416. """ Forward function.
  417. Args:
  418. x: Input feature, tensor size (B, H*W, C), i.e., binded [PATCH, DET] tokens
  419. H, W: Spatial resolution of the input feature.
  420. Returns:
  421. x: merged [PATCH, DET] tokens;
  422. only [PATCH] tokens are reduced in spatial dim, while [DET] tokens is fix-scale
  423. """
  424. B, L, C = x.shape
  425. assert L == H * W + self.det_token_num, 'input feature has wrong size'
  426. x, det = x[:, :H * W, :], x[:, H * W:, :]
  427. x = x.view(B, H, W, C)
  428. # padding
  429. pad_input = (H % 2 == 1) or (W % 2 == 1)
  430. if pad_input:
  431. x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
  432. x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
  433. x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
  434. x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
  435. x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
  436. x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
  437. x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
  438. # simply repeating for DET tokens
  439. det = det.repeat(1, 1, 4)
  440. x = torch.cat([x, det], dim=1)
  441. x = self.norm(x)
  442. x = self.reduction(x)
  443. return x
  444. class BasicLayer(nn.Module):
  445. """ A basic Swin Transformer layer for one stage.
  446. Args:
  447. dim (int): Number of feature channels
  448. depth (int): Depths of this stage.
  449. num_heads (int): Number of attention head.
  450. window_size (int): Local window size. Default: 7.
  451. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
  452. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  453. qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
  454. drop (float, optional): Dropout rate. Default: 0.0
  455. attn_drop (float, optional): Attention dropout rate. Default: 0.0
  456. drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
  457. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
  458. downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
  459. use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
  460. """
  461. def __init__(self,
  462. dim,
  463. depth,
  464. num_heads,
  465. window_size=7,
  466. mlp_ratio=4.,
  467. qkv_bias=True,
  468. qk_scale=None,
  469. drop=0.,
  470. attn_drop=0.,
  471. drop_path=0.,
  472. norm_layer=nn.LayerNorm,
  473. downsample=None,
  474. last=False,
  475. use_checkpoint=False):
  476. super().__init__()
  477. self.window_size = window_size
  478. self.shift_size = window_size // 2
  479. self.depth = depth
  480. self.dim = dim
  481. self.use_checkpoint = use_checkpoint
  482. # build blocks
  483. self.blocks = nn.ModuleList([
  484. SwinTransformerBlock(
  485. dim=dim,
  486. num_heads=num_heads,
  487. window_size=window_size,
  488. shift_size=0 if (i % 2 == 0) else window_size // 2,
  489. mlp_ratio=mlp_ratio,
  490. qkv_bias=qkv_bias,
  491. qk_scale=qk_scale,
  492. drop=drop,
  493. attn_drop=attn_drop,
  494. drop_path=drop_path[i]
  495. if isinstance(drop_path, list) else drop_path,
  496. norm_layer=norm_layer) for i in range(depth)
  497. ])
  498. # patch merging layer
  499. if downsample is not None:
  500. self.downsample = downsample(
  501. dim=dim, norm_layer=norm_layer, expand=(not last))
  502. else:
  503. self.downsample = None
  504. def forward(self, x, H, W, det_pos, input_mask, cross_attn=False):
  505. """ Forward function.
  506. Args:
  507. x: Input feature, tensor size (B, H*W, C).
  508. H, W: Spatial resolution of the input feature.
  509. det_pos: pos encoding for det token
  510. input_mask: padding mask for inputs
  511. cross_attn: whether to use cross attn [det x [det + patch]]
  512. """
  513. B = x.shape[0]
  514. # calculate attention mask for SW-MSA
  515. Hp = int(np.ceil(H / self.window_size)) * self.window_size
  516. Wp = int(np.ceil(W / self.window_size)) * self.window_size
  517. img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
  518. h_slices = (slice(0, -self.window_size),
  519. slice(-self.window_size,
  520. -self.shift_size), slice(-self.shift_size, None))
  521. w_slices = (slice(0, -self.window_size),
  522. slice(-self.window_size,
  523. -self.shift_size), slice(-self.shift_size, None))
  524. cnt = 0
  525. for h in h_slices:
  526. for w in w_slices:
  527. img_mask[:, h, w, :] = cnt
  528. cnt += 1
  529. # mask for cyclic shift
  530. mask_windows = window_partition(img_mask, self.window_size)
  531. mask_windows = mask_windows.view(-1,
  532. self.window_size * self.window_size)
  533. attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
  534. attn_mask = attn_mask.masked_fill(attn_mask != 0,
  535. float(-100.0)).masked_fill(
  536. attn_mask == 0, float(0.0))
  537. # compute sinusoidal pos encoding and cross-attn mask here to avoid redundant computation
  538. if cross_attn:
  539. _H, _W = input_mask.shape[1:]
  540. if not (_H == H and _W == W):
  541. input_mask = F.interpolate(
  542. input_mask[None].float(), size=(H, W)).to(torch.bool)[0]
  543. # sinusoidal pos encoding for [PATCH] tokens used in cross-attention
  544. patch_pos = masked_sin_pos_encoding(x, input_mask, self.dim)
  545. # attention padding mask due to the zero padding in inputs
  546. # the zero (padded) area is masked by 1.0 in 'input_mask'
  547. cross_attn_mask = input_mask.float()
  548. cross_attn_mask = cross_attn_mask.masked_fill(cross_attn_mask != 0.0, float(-100.0)). \
  549. masked_fill(cross_attn_mask == 0.0, float(0.0))
  550. # pad for detection token (this padding is required to process the binded [PATCH, DET] attention
  551. cross_attn_mask = cross_attn_mask.view(
  552. B, H * W).unsqueeze(1).unsqueeze(2)
  553. cross_attn_mask = F.pad(
  554. cross_attn_mask, (0, self.det_token_num), value=0)
  555. else:
  556. patch_pos = None
  557. cross_attn_mask = None
  558. # zip pos encodings
  559. pos = (patch_pos, det_pos)
  560. for n_blk, blk in enumerate(self.blocks):
  561. blk.H, blk.W = H, W
  562. # for selective cross-attention
  563. if cross_attn:
  564. _cross_attn = True
  565. _cross_attn_mask = cross_attn_mask
  566. _pos = pos # i.e., (patch_pos, det_pos)
  567. else:
  568. _cross_attn = False
  569. _cross_attn_mask = None
  570. _pos = (None, det_pos)
  571. if self.use_checkpoint:
  572. x = checkpoint.checkpoint(
  573. blk,
  574. x,
  575. attn_mask,
  576. # additional inputs
  577. pos=_pos,
  578. cross_attn=_cross_attn,
  579. cross_attn_mask=_cross_attn_mask)
  580. else:
  581. x = blk(
  582. x,
  583. attn_mask,
  584. # additional inputs
  585. pos=_pos,
  586. cross_attn=_cross_attn,
  587. cross_attn_mask=_cross_attn_mask)
  588. # reduce the number of patch tokens, but maintaining a fixed-scale det tokens
  589. # meanwhile, the channel dim increases by a factor of 2
  590. if self.downsample is not None:
  591. x_down = self.downsample(x, H, W)
  592. Wh, Ww = (H + 1) // 2, (W + 1) // 2
  593. return x, H, W, x_down, Wh, Ww
  594. else:
  595. return x, H, W, x, H, W
  596. class PatchEmbed(nn.Module):
  597. """ Image to Patch Embedding
  598. Args:
  599. patch_size (int): Patch token size. Default: 4.
  600. in_chans (int): Number of input image channels. Default: 3.
  601. embed_dim (int): Number of linear projection output channels. Default: 96.
  602. norm_layer (nn.Module, optional): Normalization layer. Default: None
  603. """
  604. def __init__(self,
  605. patch_size=4,
  606. in_chans=3,
  607. embed_dim=96,
  608. norm_layer=None):
  609. super().__init__()
  610. patch_size = to_2tuple(patch_size)
  611. self.patch_size = patch_size
  612. self.in_chans = in_chans
  613. self.embed_dim = embed_dim
  614. self.proj = nn.Conv2d(
  615. in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
  616. if norm_layer is not None:
  617. self.norm = norm_layer(embed_dim)
  618. else:
  619. self.norm = None
  620. def forward(self, x):
  621. """Forward function."""
  622. # padding
  623. _, _, H, W = x.size()
  624. if W % self.patch_size[1] != 0:
  625. x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
  626. if H % self.patch_size[0] != 0:
  627. x = F.pad(x,
  628. (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
  629. x = self.proj(x) # B C Wh Ww
  630. if self.norm is not None:
  631. Wh, Ww = x.size(2), x.size(3)
  632. x = x.flatten(2).transpose(1, 2)
  633. x = self.norm(x)
  634. x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
  635. return x
  636. class SwinTransformer(nn.Module):
  637. """ Swin Transformer backbone.
  638. A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
  639. https://arxiv.org/pdf/2103.14030
  640. Args:
  641. pretrain_img_size (int): Input image size for training the pretrained model,
  642. used in absolute position embedding. Default 224.
  643. patch_size (int | tuple(int)): Patch size. Default: 4.
  644. in_chans (int): Number of input image channels. Default: 3.
  645. embed_dim (int): Number of linear projection output channels. Default: 96.
  646. depths (tuple[int]): Depths of each Swin Transformer stage.
  647. num_heads (tuple[int]): Number of attention head of each stage.
  648. window_size (int): Window size. Default: 7.
  649. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
  650. qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
  651. qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
  652. drop_rate (float): Dropout rate.
  653. attn_drop_rate (float): Attention dropout rate. Default: 0.
  654. drop_path_rate (float): Stochastic depth rate. Default: 0.2.
  655. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
  656. ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
  657. patch_norm (bool): If True, add normalization after patch embedding. Default: True.
  658. out_indices (Sequence[int]): Output from which stages.
  659. frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
  660. -1 means not freezing any args.
  661. use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
  662. """
  def __init__(
          self,
          pretrain_img_size=224,
          patch_size=4,
          in_chans=3,
          embed_dim=96,
          depths=(2, 2, 6, 2),
          num_heads=(3, 6, 12, 24),
          window_size=7,
          mlp_ratio=4.,
          qkv_bias=True,
          qk_scale=None,
          drop_rate=0.,
          attn_drop_rate=0.,
          drop_path_rate=0.2,
          norm_layer=nn.LayerNorm,
          ape=False,
          patch_norm=True,
          out_indices=(1, 2, 3),  # not used in the current version, please ignore.
          frozen_stages=-1,
          use_checkpoint=False):
      """Build the patch embedding, the Swin stages and the legacy per-output norms.

      See the class docstring for the meaning of each argument. The sequence
      arguments ``depths``, ``num_heads`` and ``out_indices`` default to tuples
      so the defaults are immutable; any sequence (e.g. a list) is accepted.
      """
      super().__init__()
      self.pretrain_img_size = pretrain_img_size
      self.num_layers = len(depths)
      self.embed_dim = embed_dim
      self.ape = ape
      self.patch_norm = patch_norm
      self.out_indices = out_indices
      self.frozen_stages = frozen_stages

      # split image into non-overlapping patches
      self.patch_embed = PatchEmbed(
          patch_size=patch_size,
          in_chans=in_chans,
          embed_dim=embed_dim,
          norm_layer=norm_layer if self.patch_norm else None)

      # absolute position embedding
      # NOTE(review): forward() in this class never adds absolute_pos_embed to
      # the tokens, so ape=True only creates (and possibly freezes) the parameter.
      if self.ape:
          pretrain_img_size = to_2tuple(pretrain_img_size)
          patch_size = to_2tuple(patch_size)
          patches_resolution = [
              pretrain_img_size[0] // patch_size[0],
              pretrain_img_size[1] // patch_size[1],
          ]
          self.absolute_pos_embed = nn.Parameter(
              torch.zeros(1, embed_dim, patches_resolution[0],
                          patches_resolution[1]))
          trunc_normal_(self.absolute_pos_embed, std=.02)

      self.pos_drop = nn.Dropout(p=drop_rate)

      # stochastic depth decay rule: drop-path rate rises linearly over all blocks
      dpr = [
          x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
      ]

      # build layers
      self.layers = nn.ModuleList()
      for i_layer in range(self.num_layers):
          layer = BasicLayer(
              dim=int(embed_dim * 2**i_layer),
              depth=depths[i_layer],
              num_heads=num_heads[i_layer],
              window_size=window_size,
              mlp_ratio=mlp_ratio,
              qkv_bias=qkv_bias,
              qk_scale=qk_scale,
              drop=drop_rate,
              attn_drop=attn_drop_rate,
              drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
              norm_layer=norm_layer,
              # modified by ViDT: every stage, including the last one, gets a
              # PatchMerging (the former `i_layer < self.num_layers` guard was
              # always true). finetune_det(method='vidt_wo_neck') removes the
              # last stage's downsampler again.
              downsample=PatchMerging,
              # None for every stage except the final one, which receives True.
              last=None if (i_layer < self.num_layers - 1) else True,
              use_checkpoint=use_checkpoint)
          self.layers.append(layer)

      num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
      self.num_features = num_features

      # add a norm layer for each output
      # Not used in the current version -> please ignore. this error will be fixed later
      # we leave this lines to load the pre-trained model ...
      # (registered as norm{i} so checkpoints that contain these keys still load)
      for i_layer in out_indices:
          layer = norm_layer(num_features[i_layer])
          layer_name = f'norm{i_layer}'
          self.add_module(layer_name, layer)

      self._freeze_stages()
  748. def _freeze_stages(self):
  749. if self.frozen_stages >= 0:
  750. self.patch_embed.eval()
  751. for param in self.patch_embed.parameters():
  752. param.requires_grad = False
  753. if self.frozen_stages >= 1 and self.ape:
  754. self.absolute_pos_embed.requires_grad = False
  755. if self.frozen_stages >= 2:
  756. self.pos_drop.eval()
  757. for i in range(0, self.frozen_stages - 1):
  758. m = self.layers[i]
  759. m.eval()
  760. for param in m.parameters():
  761. param.requires_grad = False
  @torch.jit.ignore
  def no_weight_decay(self):
      """Return the names of parameters the optimizer should not weight-decay.

      Covers the learnable [DET] tokens and their positional encoding.
      """
      skip_decay = {'det_token', 'det_pos_embed'}
      return skip_decay
  def finetune_det(self,
                   method,
                   det_token_num=100,
                   pos_dim=256,
                   cross_indices=(3,)):
      """Add the learnable variables Swin Transformer needs for object detection (ViDT).

      Args:
          method: 'vidt' or 'vidt_wo_neck'.
          det_token_num: the number of objects to detect, i.e. the number of
              [DET] tokens (object queries).
          pos_dim: channel dimension of the positional encodings for the [DET]
              and [PATCH] tokens.
          cross_indices: the stages that use [DET x PATCH] cross-attention,
              chosen from [0, 1, 2, 3]; 3 is Stage 4 in the ViDT paper. Any
              sequence is accepted (the default is an immutable tuple).
      """
      # which method?
      self.method = method

      # [DET] tokens: one shared set, expanded over the batch in forward()
      self.det_token_num = det_token_num
      self.det_token = nn.Parameter(
          torch.zeros(1, det_token_num, self.num_features[0]))
      # initialised in place, the same way __init__ treats absolute_pos_embed
      trunc_normal_(self.det_token, std=.02)

      # learnable positional encoding for [DET] tokens
      self.pos_dim = pos_dim
      det_pos_embed = torch.zeros(1, det_token_num, pos_dim)
      trunc_normal_(det_pos_embed, std=.02)
      self.det_pos_embed = torch.nn.Parameter(det_pos_embed)

      # info for detection: channel sizes of stages 1..N-1, plus pos_dim for 'vidt'
      self.num_channels = list(self.num_features[1:])
      if method == 'vidt':
          self.num_channels.append(
              self.pos_dim)  # default: 256 (same to the default pos_dim)
      self.cross_indices = cross_indices

      # divisor to reduce the spatial size of the mask used by cross-attention.
      # NOTE(review): 2 ** (num_stages - len(cross_indices)) presumably equals the
      # patch-grid reduction at the first cross-attention stage only when
      # cross_indices are the trailing stages -- confirm before using other sets.
      self.mask_divisor = 2**(len(self.layers) - len(self.cross_indices))

      # every stage, its downsampler, and each of its blocks need the [DET] count;
      # each block also gets its own projection of the [DET] pos encoding
      for layer in self.layers:
          layer.det_token_num = det_token_num
          if layer.downsample is not None:
              layer.downsample.det_token_num = det_token_num
          for block in layer.blocks:
              block.det_token_num = det_token_num
              block.det_pos_linear = nn.Linear(pos_dim, block.dim)

      # the neck-free model does not downsample at the last stage.
      if method == 'vidt_wo_neck':
          self.layers[-1].downsample = None
  def forward(self, x, mask):
      """Run the ViDT Swin backbone on a batch of padded images.

      ``finetune_det()`` must have been called first: it creates ``det_token``,
      ``det_pos_embed``, ``det_token_num``, ``cross_indices`` and
      ``mask_divisor``, all of which are read here.

      Args:
          x: input RGB images, indexed as (batch, channels, height, width).
          mask: input padding mask [0: rgb values, 1: padded values];
              presumably (batch, height, width) -- TODO confirm with the caller.

      Returns:
          ``(features_0, features_1, features_2, features_3, det_tgt, det_pos)``
          features_0..features_3: multi-scale [PATCH] token maps, each shaped
              (B, C, H, W); the first input of the neck decoder.
          det_tgt: final [DET] tokens of the last stage, (B, C, det_token_num);
              the second input of the neck decoder.
          det_pos: the learnable [DET] positional encoding,
              (1, pos_dim, det_token_num); used to generate reference points in
              deformable attention.

      NOTE(review): ``absolute_pos_embed`` (``ape=True``) is never added here.
      """
      batch = x.shape[0]

      # patch embedding; remember the patch-grid size before flattening to tokens
      x = self.patch_embed(x)
      grid_h, grid_w = x.size(2), x.size(3)
      tokens = x.flatten(2).transpose(1, 2)
      x = self.pos_drop(tokens)

      n_det = self.det_token_num
      # view the shared [DET] tokens once per image in the batch
      det_token = self.det_token.expand(batch, -1, -1)
      # [DET] pos encoding -> will be projected in each block
      det_pos = self.det_pos_embed

      # shrink the padding mask to the resolution used by cross-attention
      mask_size = (grid_h // self.mask_divisor, grid_w // self.mask_divisor)
      mask = F.interpolate(mask[None].float(), size=mask_size)
      mask = mask.to(torch.bool)[0]

      patch_outs = []
      for stage in range(self.num_layers):
          stage_layer = self.layers[stage]
          # [PATCH] tokens first, [DET] tokens appended at the end
          joined = torch.cat([x, det_token], dim=1)
          x_out, H, W, x, grid_h, grid_w = stage_layer(
              joined,
              grid_h,
              grid_w,
              # additional inputs for ViDT
              input_mask=mask,
              det_pos=det_pos,
              cross_attn=stage in self.cross_indices)
          # split the reduced output back into [PATCH] and [DET] tokens
          det_token = x[:, -n_det:, :]
          x = x[:, :-n_det, :]
          # from the second stage on, the stage's x_out [PATCH] tokens at
          # (H, W) become a channel-first feature map
          if stage > 0:
              stage_map = x_out[:, :-n_det, :].view(batch, H, W, -1)
              patch_outs.append(stage_map.permute(0, 3, 1, 2))

      # last feature map: [PATCH] tokens reduced from the last stage output
      patch_outs.append(x.view(batch, grid_h, grid_w, -1).permute(0, 3, 1, 2))
      # [DET] tokens of the last stage and their pos encoding, channel-first
      det_tgt = x_out[:, -n_det:, :].permute(0, 2, 1)
      det_pos = det_pos.permute(0, 2, 1)

      features_0, features_1, features_2, features_3 = patch_outs
      return features_0, features_1, features_2, features_3, det_tgt, det_pos
  def train(self, mode=True):
      """Set training/eval mode, then re-apply freezing so frozen stages stay in eval mode.

      Args:
          mode (bool): True for training mode, False for evaluation mode.

      Returns:
          The module itself, like ``nn.Module.train``, so calls can be chained
          (e.g. ``model = model.train()``); the previous override returned None.
      """
      super().train(mode)
      # nn.Module.train flips every submodule back to `mode`; undo it for the
      # stages configured as frozen.
      self._freeze_stages()
      return self
  # Backbone FLOPs estimate built from the sub-modules' own flops() methods.
  def flops(self):
      """Estimate backbone FLOPs as the sum of the sub-module estimates.

      Only the patch embedding and the Swin stages are counted. This backbone
      defines no classification head, and the per-output ``norm{i}`` layers are
      unused in ``forward()``, so neither contributes a term.

      NOTE(review): requires ``self.patch_embed`` and every stage in
      ``self.layers`` to implement ``flops()`` -- confirm that PatchEmbed and
      BasicLayer in this file do. Whether [DET]-token work is included depends
      on those implementations.

      Returns:
          The summed FLOP count reported by the sub-modules.
      """
      total = self.patch_embed.flops()
      total += sum(layer.flops() for layer in self.layers)
      return total