mvit.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007
  1. # Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
  2. # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
  3. from collections import OrderedDict
  4. from functools import partial
  5. import numpy as np
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. import torch.utils.checkpoint as checkpoint
  10. from timm.models.layers import trunc_normal_
  11. try:
  12. from fairscale.nn.checkpoint import checkpoint_wrapper
  13. except ImportError:
  14. checkpoint_wrapper = None
  15. MViTv2_Base_config = {
  16. 'depth':
  17. 24,
  18. 'dim_mul': [[2, 2.0], [5, 2.0], [21, 2.0]],
  19. 'head_mul': [[2, 2.0], [5, 2.0], [21, 2.0]],
  20. 'pool_q_stride':
  21. [[0, 1, 1, 1], [1, 1, 1, 1], [2, 1, 2, 2], [3, 1, 1, 1], [4, 1, 1, 1],
  22. [5, 1, 2, 2], [6, 1, 1, 1], [7, 1, 1, 1], [8, 1, 1, 1], [9, 1, 1, 1],
  23. [10, 1, 1, 1], [11, 1, 1, 1], [12, 1, 1, 1], [13, 1, 1, 1], [14, 1, 1, 1],
  24. [15, 1, 1, 1], [16, 1, 1, 1], [17, 1, 1, 1], [18, 1, 1, 1], [19, 1, 1, 1],
  25. [20, 1, 1, 1], [21, 1, 2, 2], [22, 1, 1, 1], [23, 1, 1, 1]],
  26. 'pool_kvq_kernel': [3, 3, 3],
  27. 'pool_kv_stride_adaptive': [1, 4, 4],
  28. }
  29. def interpolate_rel_pos_embed(state_dict_origin,
  30. state_dict_model,
  31. temporal=True,
  32. verbose=False):
  33. rel_pos_embed_types = ['rel_pos_h', 'rel_pos_w']
  34. if temporal:
  35. rel_pos_embed_types += ['rel_pos_t']
  36. state_dict_inflated = state_dict_origin.copy()
  37. for k, v2d in state_dict_origin.items():
  38. if any([x in k for x in rel_pos_embed_types]):
  39. v3d = state_dict_model[k]
  40. if v2d.shape[0] != v3d.shape[0]:
  41. rel_pos_resized = F.interpolate(
  42. v2d.reshape(1, v2d.shape[0], -1).permute(0, 2, 1),
  43. size=v3d.shape[0],
  44. mode='linear',
  45. )
  46. v3d = rel_pos_resized.reshape(-1, v3d.shape[0]).permute(1, 0)
  47. if verbose:
  48. print('Inflate {}: {} -> {}: {}'.format(
  49. k, v2d.shape, k, v3d.shape))
  50. else:
  51. v3d = v2d
  52. state_dict_inflated[k] = v3d.clone()
  53. return state_dict_inflated
  54. def _prepare_mvit_configs(cfg):
  55. depth = cfg['depth']
  56. dim_mul, head_mul = torch.ones(depth + 1), torch.ones(depth + 1)
  57. for i in range(len(cfg['dim_mul'])):
  58. dim_mul[cfg['dim_mul'][i][0]] = cfg['dim_mul'][i][1]
  59. for i in range(len(cfg['head_mul'])):
  60. head_mul[cfg['head_mul'][i][0]] = cfg['head_mul'][i][1]
  61. pool_q = [[] for i in range(depth)]
  62. pool_kv = [[] for i in range(depth)]
  63. stride_q = [[] for i in range(depth)]
  64. stride_kv = [[] for i in range(depth)]
  65. for i in range(len(cfg['pool_q_stride'])):
  66. stride_q[cfg['pool_q_stride'][i][0]] = cfg['pool_q_stride'][i][1:]
  67. pool_q[cfg['pool_q_stride'][i][0]] = cfg['pool_kvq_kernel']
  68. if cfg['pool_kv_stride_adaptive'] is not None:
  69. _stride_kv = cfg['pool_kv_stride_adaptive']
  70. cfg['pool_kv_stride'] = []
  71. for i in range(cfg['depth']):
  72. if len(stride_q[i]) > 0:
  73. _stride_kv = [
  74. max(_stride_kv[d] // stride_q[i][d], 1)
  75. for d in range(len(_stride_kv))
  76. ]
  77. cfg['pool_kv_stride'].append([i] + _stride_kv)
  78. for i in range(len(cfg['pool_kv_stride'])):
  79. stride_kv[cfg['pool_kv_stride'][i][0]] = cfg['pool_kv_stride'][i][1:]
  80. pool_kv[cfg['pool_kv_stride'][i][0]] = cfg['pool_kvq_kernel']
  81. return dim_mul, head_mul, pool_q, pool_kv, stride_q, stride_kv
  82. class Mlp(nn.Module):
  83. def __init__(
  84. self,
  85. in_features,
  86. hidden_features=None,
  87. out_features=None,
  88. act_layer=nn.GELU,
  89. drop_rate=0.0,
  90. ):
  91. super().__init__()
  92. self.drop_rate = drop_rate
  93. out_features = out_features or in_features
  94. hidden_features = hidden_features or in_features
  95. self.fc1 = nn.Linear(in_features, hidden_features)
  96. self.act = act_layer()
  97. self.fc2 = nn.Linear(hidden_features, out_features)
  98. if self.drop_rate > 0.0:
  99. self.drop = nn.Dropout(drop_rate)
  100. def forward(self, x):
  101. x = self.fc1(x)
  102. x = self.act(x)
  103. if self.drop_rate > 0.0:
  104. x = self.drop(x)
  105. x = self.fc2(x)
  106. if self.drop_rate > 0.0:
  107. x = self.drop(x)
  108. return x
  109. class Permute(nn.Module):
  110. def __init__(self, dims):
  111. super().__init__()
  112. self.dims = dims
  113. def forward(self, x):
  114. return x.permute(*self.dims)
  115. def drop_path(x, drop_prob: float = 0.0, training: bool = False):
  116. """
  117. Stochastic Depth per sample.
  118. """
  119. if drop_prob == 0.0 or not training:
  120. return x
  121. keep_prob = 1 - drop_prob
  122. shape = (x.shape[0], ) + (1, ) * (
  123. x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  124. mask = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
  125. mask.floor_() # binarize
  126. output = x.div(keep_prob) * mask
  127. return output
  128. class DropPath(nn.Module):
  129. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  130. def __init__(self, drop_prob=None):
  131. super(DropPath, self).__init__()
  132. self.drop_prob = drop_prob
  133. def forward(self, x):
  134. return drop_path(x, self.drop_prob, self.training)
  135. def round_width(width, multiplier, min_width=1, divisor=1, verbose=False):
  136. if not multiplier:
  137. return width
  138. width *= multiplier
  139. min_width = min_width or divisor
  140. if verbose:
  141. print(f'min width {min_width}')
  142. print(f'width {width} divisor {divisor}')
  143. print(f'other {int(width + divisor / 2) // divisor * divisor}')
  144. width_out = max(min_width, int(width + divisor / 2) // divisor * divisor)
  145. if width_out < 0.9 * width:
  146. width_out += divisor
  147. return int(width_out)
  148. class PatchEmbed(nn.Module):
  149. """
  150. PatchEmbed.
  151. """
  152. def __init__(
  153. self,
  154. dim_in=3,
  155. dim_out=768,
  156. kernel=(7, 7),
  157. stride=(4, 4),
  158. padding=(3, 3),
  159. conv2d=False,
  160. ):
  161. super().__init__()
  162. if conv2d:
  163. conv_function = nn.Conv2d
  164. else:
  165. conv_function = nn.Conv3d
  166. self.proj = conv_function(
  167. dim_in,
  168. dim_out,
  169. kernel_size=kernel,
  170. stride=stride,
  171. padding=padding,
  172. )
  173. def forward(self, x):
  174. x = self.proj(x)
  175. # B C H W -> B HW C
  176. return x.flatten(2).transpose(1, 2), x.shape
  177. def attention_pool(tensor, pool, thw_shape, has_cls_embed=True, norm=None):
  178. if pool is None:
  179. return tensor, thw_shape
  180. tensor_dim = tensor.ndim
  181. if tensor_dim == 4:
  182. pass
  183. elif tensor_dim == 3:
  184. tensor = tensor.unsqueeze(1)
  185. else:
  186. raise NotImplementedError(
  187. f'Unsupported input dimension {tensor.shape}')
  188. if has_cls_embed:
  189. cls_tok, tensor = tensor[:, :, :1, :], tensor[:, :, 1:, :]
  190. B, N, L, C = tensor.shape
  191. T, H, W = thw_shape
  192. tensor = (
  193. tensor.reshape(B * N, T, H, W, C).permute(0, 4, 1, 2, 3).contiguous())
  194. tensor = pool(tensor)
  195. thw_shape = [tensor.shape[2], tensor.shape[3], tensor.shape[4]]
  196. L_pooled = tensor.shape[2] * tensor.shape[3] * tensor.shape[4]
  197. tensor = tensor.reshape(B, N, C, L_pooled).transpose(2, 3)
  198. if has_cls_embed:
  199. tensor = torch.cat((cls_tok, tensor), dim=2)
  200. if norm is not None:
  201. tensor = norm(tensor)
  202. # Assert tensor_dim in [3, 4]
  203. if tensor_dim == 4:
  204. pass
  205. else: # tensor_dim == 3:
  206. tensor = tensor.squeeze(1)
  207. return tensor, thw_shape
  208. def get_rel_pos(rel_pos, d):
  209. if isinstance(d, int):
  210. ori_d = rel_pos.shape[0]
  211. if ori_d == d:
  212. return rel_pos
  213. else:
  214. # Interpolate rel pos.
  215. new_pos_embed = F.interpolate(
  216. rel_pos.reshape(1, ori_d, -1).permute(0, 2, 1),
  217. size=d,
  218. mode='linear',
  219. )
  220. return new_pos_embed.reshape(-1, d).permute(1, 0)
  221. def cal_rel_pos_spatial(attn, q, k, has_cls_embed, q_shape, k_shape, rel_pos_h,
  222. rel_pos_w):
  223. """
  224. Decomposed Spatial Relative Positional Embeddings.
  225. """
  226. sp_idx = 1 if has_cls_embed else 0
  227. q_t, q_h, q_w = q_shape
  228. k_t, k_h, k_w = k_shape
  229. dh = int(2 * max(q_h, k_h) - 1)
  230. dw = int(2 * max(q_w, k_w) - 1)
  231. # Scale up rel pos if shapes for q and k are different.
  232. q_h_ratio = max(k_h / q_h, 1.0)
  233. k_h_ratio = max(q_h / k_h, 1.0)
  234. dist_h = (
  235. torch.arange(q_h)[:, None] * q_h_ratio
  236. - torch.arange(k_h)[None, :] * k_h_ratio)
  237. dist_h += (k_h - 1) * k_h_ratio
  238. q_w_ratio = max(k_w / q_w, 1.0)
  239. k_w_ratio = max(q_w / k_w, 1.0)
  240. dist_w = (
  241. torch.arange(q_w)[:, None] * q_w_ratio
  242. - torch.arange(k_w)[None, :] * k_w_ratio)
  243. dist_w += (k_w - 1) * k_w_ratio
  244. # Interpolate rel pos if needed.
  245. rel_pos_h = get_rel_pos(rel_pos_h, dh)
  246. rel_pos_w = get_rel_pos(rel_pos_w, dw)
  247. Rh = rel_pos_h[dist_h.long()]
  248. Rw = rel_pos_w[dist_w.long()]
  249. B, n_head, q_N, dim = q.shape
  250. r_q = q[:, :, sp_idx:].reshape(B, n_head, q_t, q_h, q_w, dim)
  251. rel_h_q = torch.einsum('bythwc,hkc->bythwk', r_q,
  252. Rh) # [B, H, q_t, qh, qw, k_h]
  253. rel_w_q = torch.einsum('bythwc,wkc->bythwk', r_q,
  254. Rw) # [B, H, q_t, qh, qw, k_w]
  255. attn[:, :, sp_idx:, sp_idx:] = (
  256. attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_t, q_h, q_w, k_t, k_h, k_w)
  257. + rel_h_q[:, :, :, :, :, None, :, None]
  258. + rel_w_q[:, :, :, :, :, None, None, :]).view(B, -1, q_t * q_h * q_w,
  259. k_t * k_h * k_w)
  260. return attn
  261. def cal_rel_pos_temporal(attn, q, has_cls_embed, q_shape, k_shape, rel_pos_t):
  262. """
  263. Temporal Relative Positional Embeddings.
  264. """
  265. sp_idx = 1 if has_cls_embed else 0
  266. q_t, q_h, q_w = q_shape
  267. k_t, k_h, k_w = k_shape
  268. dt = int(2 * max(q_t, k_t) - 1)
  269. # Interpolate rel pos if needed.
  270. rel_pos_t = get_rel_pos(rel_pos_t, dt)
  271. # Scale up rel pos if shapes for q and k are different.
  272. q_t_ratio = max(k_t / q_t, 1.0)
  273. k_t_ratio = max(q_t / k_t, 1.0)
  274. dist_t = (
  275. torch.arange(q_t)[:, None] * q_t_ratio
  276. - torch.arange(k_t)[None, :] * k_t_ratio)
  277. dist_t += (k_t - 1) * k_t_ratio
  278. Rt = rel_pos_t[dist_t.long()]
  279. B, n_head, q_N, dim = q.shape
  280. r_q = q[:, :, sp_idx:].reshape(B, n_head, q_t, q_h, q_w, dim)
  281. # [B, H, q_t, q_h, q_w, dim] -> [q_t, B, H, q_h, q_w, dim] -> [q_t, B*H*q_h*q_w, dim]
  282. r_q = r_q.permute(2, 0, 1, 3, 4, 5).reshape(q_t, B * n_head * q_h * q_w,
  283. dim)
  284. # [q_t, B*H*q_h*q_w, dim] * [q_t, dim, k_t] = [q_t, B*H*q_h*q_w, k_t] -> [B*H*q_h*q_w, q_t, k_t]
  285. rel = torch.matmul(r_q, Rt.transpose(1, 2)).transpose(0, 1)
  286. # [B*H*q_h*q_w, q_t, k_t] -> [B, H, q_t, q_h, q_w, k_t]
  287. rel = rel.view(B, n_head, q_h, q_w, q_t, k_t).permute(0, 1, 4, 2, 3, 5)
  288. attn[:, :, sp_idx:, sp_idx:] = (
  289. attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_t, q_h, q_w, k_t, k_h, k_w)
  290. + rel[:, :, :, :, :, :, None, None]).view(B, -1, q_t * q_h * q_w,
  291. k_t * k_h * k_w)
  292. return attn
  293. class MultiScaleAttention(nn.Module):
  294. def __init__(
  295. self,
  296. dim,
  297. dim_out,
  298. input_size,
  299. num_heads=8,
  300. qkv_bias=False,
  301. drop_rate=0.0,
  302. kernel_q=(1, 1, 1),
  303. kernel_kv=(1, 1, 1),
  304. stride_q=(1, 1, 1),
  305. stride_kv=(1, 1, 1),
  306. norm_layer=nn.LayerNorm,
  307. has_cls_embed=True,
  308. # Options include `conv`, `avg`, and `max`.
  309. mode='conv',
  310. # If True, perform pool before projection.
  311. pool_first=False,
  312. rel_pos_spatial=False,
  313. rel_pos_temporal=False,
  314. rel_pos_zero_init=False,
  315. residual_pooling=True,
  316. separate_qkv=False,
  317. ):
  318. super().__init__()
  319. self.pool_first = pool_first
  320. self.separate_qkv = separate_qkv
  321. self.drop_rate = drop_rate
  322. self.num_heads = num_heads
  323. self.dim_out = dim_out
  324. head_dim = dim_out // num_heads
  325. self.scale = head_dim**-0.5
  326. self.has_cls_embed = has_cls_embed
  327. padding_q = [int(q // 2) for q in kernel_q]
  328. padding_kv = [int(kv // 2) for kv in kernel_kv]
  329. if pool_first or separate_qkv:
  330. self.q = nn.Linear(dim, dim_out, bias=qkv_bias)
  331. self.k = nn.Linear(dim, dim_out, bias=qkv_bias)
  332. self.v = nn.Linear(dim, dim_out, bias=qkv_bias)
  333. else:
  334. self.qkv = nn.Linear(dim, dim_out * 3, bias=qkv_bias)
  335. self.proj = nn.Linear(dim_out, dim_out)
  336. if drop_rate > 0.0:
  337. self.proj_drop = nn.Dropout(drop_rate)
  338. # Skip pooling with kernel and stride size of (1, 1, 1).
  339. if np.prod(kernel_q) == 1 and np.prod(stride_q) == 1:
  340. kernel_q = ()
  341. if np.prod(kernel_kv) == 1 and np.prod(stride_kv) == 1:
  342. kernel_kv = ()
  343. self.mode = mode
  344. if mode in ('avg', 'max'):
  345. pool_op = nn.MaxPool3d if mode == 'max' else nn.AvgPool3d
  346. self.pool_q = (
  347. pool_op(kernel_q, stride_q, padding_q, ceil_mode=False)
  348. if len(kernel_q) > 0 else None)
  349. self.pool_k = (
  350. pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False)
  351. if len(kernel_kv) > 0 else None)
  352. self.pool_v = (
  353. pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False)
  354. if len(kernel_kv) > 0 else None)
  355. elif mode == 'conv' or mode == 'conv_unshared':
  356. if pool_first:
  357. dim_conv = dim // num_heads if mode == 'conv' else dim
  358. else:
  359. dim_conv = dim_out // num_heads if mode == 'conv' else dim_out
  360. self.pool_q = (
  361. nn.Conv3d(
  362. dim_conv,
  363. dim_conv,
  364. kernel_q,
  365. stride=stride_q,
  366. padding=padding_q,
  367. groups=dim_conv,
  368. bias=False,
  369. ) if len(kernel_q) > 0 else None)
  370. self.norm_q = norm_layer(dim_conv) if len(kernel_q) > 0 else None
  371. self.pool_k = (
  372. nn.Conv3d(
  373. dim_conv,
  374. dim_conv,
  375. kernel_kv,
  376. stride=stride_kv,
  377. padding=padding_kv,
  378. groups=dim_conv,
  379. bias=False,
  380. ) if len(kernel_kv) > 0 else None)
  381. self.norm_k = norm_layer(dim_conv) if len(kernel_kv) > 0 else None
  382. self.pool_v = (
  383. nn.Conv3d(
  384. dim_conv,
  385. dim_conv,
  386. kernel_kv,
  387. stride=stride_kv,
  388. padding=padding_kv,
  389. groups=dim_conv,
  390. bias=False,
  391. ) if len(kernel_kv) > 0 else None)
  392. self.norm_v = norm_layer(dim_conv) if len(kernel_kv) > 0 else None
  393. else:
  394. raise NotImplementedError(f'Unsupported model {mode}')
  395. self.rel_pos_spatial = rel_pos_spatial
  396. self.rel_pos_temporal = rel_pos_temporal
  397. if self.rel_pos_spatial:
  398. assert input_size[1] == input_size[2]
  399. size = input_size[1]
  400. q_size = size // stride_q[1] if len(stride_q) > 0 else size
  401. kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size
  402. rel_sp_dim = 2 * max(q_size, kv_size) - 1
  403. self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, head_dim))
  404. self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, head_dim))
  405. if not rel_pos_zero_init:
  406. trunc_normal_(self.rel_pos_h, std=0.02)
  407. trunc_normal_(self.rel_pos_w, std=0.02)
  408. if self.rel_pos_temporal:
  409. self.rel_pos_t = nn.Parameter(
  410. torch.zeros(2 * input_size[0] - 1, head_dim))
  411. # if not rel_pos_zero_init:
  412. # trunc_normal_(self.rel_pos_t, std=0.02)
  413. self.residual_pooling = residual_pooling
  414. def forward(self, x, thw_shape):
  415. B, N, _ = x.shape
  416. if self.pool_first:
  417. if self.mode == 'conv_unshared':
  418. fold_dim = 1
  419. else:
  420. fold_dim = self.num_heads
  421. x = x.reshape(B, N, fold_dim, -1).permute(0, 2, 1, 3)
  422. q = k = v = x
  423. else:
  424. assert self.mode != 'conv_unshared'
  425. if not self.separate_qkv:
  426. qkv = (
  427. self.qkv(x).reshape(B, N, 3, self.num_heads,
  428. -1).permute(2, 0, 3, 1, 4))
  429. q, k, v = qkv[0], qkv[1], qkv[2]
  430. else:
  431. q = k = v = x
  432. q = (
  433. self.q(q).reshape(B, N, self.num_heads,
  434. -1).permute(0, 2, 1, 3))
  435. k = (
  436. self.k(k).reshape(B, N, self.num_heads,
  437. -1).permute(0, 2, 1, 3))
  438. v = (
  439. self.v(v).reshape(B, N, self.num_heads,
  440. -1).permute(0, 2, 1, 3))
  441. q, q_shape = attention_pool(
  442. q,
  443. self.pool_q,
  444. thw_shape,
  445. has_cls_embed=self.has_cls_embed,
  446. norm=self.norm_q if hasattr(self, 'norm_q') else None,
  447. )
  448. k, k_shape = attention_pool(
  449. k,
  450. self.pool_k,
  451. thw_shape,
  452. has_cls_embed=self.has_cls_embed,
  453. norm=self.norm_k if hasattr(self, 'norm_k') else None,
  454. )
  455. v, v_shape = attention_pool(
  456. v,
  457. self.pool_v,
  458. thw_shape,
  459. has_cls_embed=self.has_cls_embed,
  460. norm=self.norm_v if hasattr(self, 'norm_v') else None,
  461. )
  462. if self.pool_first:
  463. q_N = (
  464. np.prod(q_shape)
  465. + 1 if self.has_cls_embed else np.prod(q_shape))
  466. k_N = (
  467. np.prod(k_shape)
  468. + 1 if self.has_cls_embed else np.prod(k_shape))
  469. v_N = (
  470. np.prod(v_shape)
  471. + 1 if self.has_cls_embed else np.prod(v_shape))
  472. q = q.permute(0, 2, 1, 3).reshape(B, q_N, -1)
  473. q = (
  474. self.q(q).reshape(B, q_N, self.num_heads,
  475. -1).permute(0, 2, 1, 3))
  476. v = v.permute(0, 2, 1, 3).reshape(B, v_N, -1)
  477. v = (
  478. self.v(v).reshape(B, v_N, self.num_heads,
  479. -1).permute(0, 2, 1, 3))
  480. k = k.permute(0, 2, 1, 3).reshape(B, k_N, -1)
  481. k = (
  482. self.k(k).reshape(B, k_N, self.num_heads,
  483. -1).permute(0, 2, 1, 3))
  484. N = q.shape[2]
  485. attn = (q * self.scale) @ k.transpose(-2, -1)
  486. if self.rel_pos_spatial:
  487. attn = cal_rel_pos_spatial(
  488. attn,
  489. q,
  490. k,
  491. self.has_cls_embed,
  492. q_shape,
  493. k_shape,
  494. self.rel_pos_h,
  495. self.rel_pos_w,
  496. )
  497. if self.rel_pos_temporal:
  498. attn = cal_rel_pos_temporal(
  499. attn,
  500. q,
  501. self.has_cls_embed,
  502. q_shape,
  503. k_shape,
  504. self.rel_pos_t,
  505. )
  506. attn = attn.softmax(dim=-1)
  507. x = attn @ v
  508. if self.residual_pooling:
  509. # Minor Difference
  510. if self.has_cls_embed:
  511. x[:, :, 1:, :] += q[:, :, 1:, :]
  512. else:
  513. x = x + q
  514. x = x.transpose(1, 2).reshape(B, -1, self.dim_out)
  515. x = self.proj(x)
  516. if self.drop_rate > 0.0:
  517. x = self.proj_drop(x)
  518. return x, q_shape
  519. class MultiScaleBlock(nn.Module):
  520. def __init__(
  521. self,
  522. dim,
  523. dim_out,
  524. num_heads,
  525. input_size,
  526. mlp_ratio=4.0,
  527. qkv_bias=False,
  528. qk_scale=None,
  529. drop_rate=0.0,
  530. drop_path=0.0,
  531. act_layer=nn.GELU,
  532. norm_layer=nn.LayerNorm,
  533. up_rate=None,
  534. kernel_q=(1, 1, 1),
  535. kernel_kv=(1, 1, 1),
  536. stride_q=(1, 1, 1),
  537. stride_kv=(1, 1, 1),
  538. mode='conv',
  539. has_cls_embed=True,
  540. pool_first=False,
  541. rel_pos_spatial=False,
  542. rel_pos_temporal=False,
  543. rel_pos_zero_init=False,
  544. residual_pooling=True,
  545. dim_mul_in_att=False,
  546. separate_qkv=False,
  547. use_grad_checkpoint=False,
  548. ):
  549. super().__init__()
  550. self.dim = dim
  551. self.dim_out = dim_out
  552. self.norm1 = norm_layer(dim)
  553. self.dim_mul_in_att = dim_mul_in_att
  554. kernel_skip = [s + 1 if s > 1 else s for s in stride_q]
  555. stride_skip = stride_q
  556. padding_skip = [int(skip // 2) for skip in kernel_skip]
  557. att_dim = dim_out if dim_mul_in_att else dim
  558. self.use_grad_checkpoint = use_grad_checkpoint
  559. self.attn = MultiScaleAttention(
  560. dim,
  561. att_dim,
  562. num_heads=num_heads,
  563. input_size=input_size,
  564. qkv_bias=qkv_bias,
  565. drop_rate=drop_rate,
  566. kernel_q=kernel_q,
  567. kernel_kv=kernel_kv,
  568. stride_q=stride_q,
  569. stride_kv=stride_kv,
  570. norm_layer=norm_layer,
  571. has_cls_embed=has_cls_embed,
  572. mode=mode,
  573. pool_first=pool_first,
  574. rel_pos_spatial=rel_pos_spatial,
  575. rel_pos_temporal=rel_pos_temporal,
  576. rel_pos_zero_init=rel_pos_zero_init,
  577. residual_pooling=residual_pooling,
  578. separate_qkv=separate_qkv,
  579. )
  580. self.drop_path = (
  581. DropPath(drop_path) if drop_path > 0.0 else nn.Identity())
  582. self.norm2 = norm_layer(att_dim)
  583. mlp_hidden_dim = int(att_dim * mlp_ratio)
  584. self.has_cls_embed = has_cls_embed
  585. # TODO: check the use case for up_rate, and merge the following lines
  586. if up_rate is not None and up_rate > 1:
  587. mlp_dim_out = dim * up_rate
  588. else:
  589. mlp_dim_out = dim_out
  590. self.mlp = Mlp(
  591. in_features=att_dim,
  592. hidden_features=mlp_hidden_dim,
  593. out_features=mlp_dim_out,
  594. act_layer=act_layer,
  595. drop_rate=drop_rate,
  596. )
  597. if dim != dim_out:
  598. self.proj = nn.Linear(dim, dim_out)
  599. self.pool_skip = (
  600. nn.MaxPool3d(
  601. kernel_skip, stride_skip, padding_skip, ceil_mode=False)
  602. if len(kernel_skip) > 0 else None)
  603. def forward(self, x, thw_shape):
  604. x_norm = self.norm1(x)
  605. if self.use_grad_checkpoint:
  606. x_block, thw_shape_new = checkpoint.checkpoint(
  607. self.attn, x_norm, thw_shape)
  608. else:
  609. x_block, thw_shape_new = self.attn(x_norm, thw_shape)
  610. if self.dim_mul_in_att and self.dim != self.dim_out:
  611. x = self.proj(x_norm)
  612. x_res, _ = attention_pool(
  613. x, self.pool_skip, thw_shape, has_cls_embed=self.has_cls_embed)
  614. x = x_res + self.drop_path(x_block)
  615. x_norm = self.norm2(x)
  616. if self.use_grad_checkpoint:
  617. x_mlp = checkpoint.checkpoint(self.mlp, x_norm)
  618. else:
  619. x_mlp = self.mlp(x_norm)
  620. if not self.dim_mul_in_att and self.dim != self.dim_out:
  621. x = self.proj(x_norm)
  622. x = x + self.drop_path(x_mlp)
  623. return x, thw_shape_new
  624. class MViTv2(nn.Module):
  625. """
  626. Improved Multiscale Vision Transformers for Classification and Detection
  627. Yanghao Li*, Chao-Yuan Wu*, Haoqi Fan, Karttikeya Mangalam, Bo Xiong, Jitendra Malik,
  628. Christoph Feichtenhofer*
  629. https://arxiv.org/abs/2112.01526
  630. Multiscale Vision Transformers
  631. Haoqi Fan*, Bo Xiong*, Karttikeya Mangalam*, Yanghao Li*, Zhicheng Yan, Jitendra Malik,
  632. Christoph Feichtenhofer*
  633. https://arxiv.org/abs/2104.11227
  634. """
  635. def __init__(
  636. self,
  637. img_size=224,
  638. embed_dim=96,
  639. num_classes=1000,
  640. num_frames=4,
  641. num_heads=1,
  642. depth=24,
  643. patch_kernel=[3, 7, 7],
  644. patch_stride=[2, 4, 4],
  645. patch_padding=[1, 3, 3],
  646. config=None,
  647. dropout_rate=0.,
  648. drop_path_rate=0.,
  649. mlp_ratio=4.,
  650. qkv_bias=True,
  651. mode='conv',
  652. cls_embed_on=True,
  653. use_abs_pos=False,
  654. rel_pos_spatial=True,
  655. rel_pos_temporal=True,
  656. rel_pos_zero_init=False,
  657. residual_pooling=True,
  658. dim_mul_in_att=True,
  659. pool_first=False,
  660. zero_decay_pos_cls=False,
  661. separate_qkv=False,
  662. norm_stem=False,
  663. sep_pos_embed=False,
  664. use_grad_checkpoint=True,
  665. ):
  666. super().__init__()
  667. # Prepare input.
  668. in_chans = 3
  669. self.img_size = img_size
  670. # Prepare output.
  671. self.num_classes = num_classes
  672. self.embed_dim = embed_dim
  673. # MViT params.
  674. self.num_heads = num_heads
  675. self.depth = depth
  676. self.cls_embed_on = cls_embed_on
  677. self.use_abs_pos = use_abs_pos
  678. self.zero_decay_pos_cls = zero_decay_pos_cls
  679. self.use_grad_checkpoint = use_grad_checkpoint
  680. self.sep_pos_embed = sep_pos_embed
  681. self.drop_rate = dropout_rate
  682. norm_layer = partial(nn.LayerNorm, eps=1e-6)
  683. if use_grad_checkpoint:
  684. self.patch_embed = checkpoint_wrapper(
  685. PatchEmbed(
  686. dim_in=in_chans,
  687. dim_out=embed_dim,
  688. kernel=patch_kernel,
  689. stride=patch_stride,
  690. padding=patch_padding,
  691. ))
  692. else:
  693. self.patch_embed = PatchEmbed(
  694. dim_in=in_chans,
  695. dim_out=embed_dim,
  696. kernel=patch_kernel,
  697. stride=patch_stride,
  698. padding=patch_padding,
  699. )
  700. patch_dims = [
  701. num_frames // patch_stride[0],
  702. img_size // patch_stride[1],
  703. img_size // patch_stride[2],
  704. ]
  705. num_patches = np.prod(patch_dims)
  706. dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
  707. ] # stochastic depth decay rule
  708. if self.cls_embed_on:
  709. self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
  710. pos_embed_dim = num_patches + 1
  711. else:
  712. pos_embed_dim = num_patches
  713. if self.use_abs_pos:
  714. self.pos_embed = nn.Parameter(
  715. torch.zeros(1, pos_embed_dim, embed_dim))
  716. if self.use_abs_pos:
  717. if self.sep_pos_embed:
  718. self.pos_embed_spatial = nn.Parameter(
  719. torch.zeros(1, self.patch_dims[1] * self.patch_dims[2],
  720. embed_dim))
  721. self.pos_embed_temporal = nn.Parameter(
  722. torch.zeros(1, self.patch_dims[0], embed_dim))
  723. if self.cls_embed_on:
  724. self.pos_embed_class = nn.Parameter(
  725. torch.zeros(1, 1, embed_dim))
  726. else:
  727. self.pos_embed = nn.Parameter(
  728. torch.zeros(1, pos_embed_dim, embed_dim))
  729. assert config is not None
  730. # MViT backbone configs
  731. dim_mul, head_mul, pool_q, pool_kv, stride_q, stride_kv = _prepare_mvit_configs(
  732. config)
  733. input_size = patch_dims
  734. self.norm_stem = norm_layer(embed_dim) if norm_stem else None
  735. self.blocks = nn.ModuleList()
  736. for i in range(depth):
  737. num_heads = round_width(num_heads, head_mul[i])
  738. if dim_mul_in_att:
  739. dim_out = round_width(
  740. embed_dim,
  741. dim_mul[i],
  742. divisor=round_width(num_heads, head_mul[i]),
  743. )
  744. else:
  745. dim_out = round_width(
  746. embed_dim,
  747. dim_mul[i + 1],
  748. divisor=round_width(num_heads, head_mul[i + 1]),
  749. )
  750. attention_block = MultiScaleBlock(
  751. dim=embed_dim,
  752. dim_out=dim_out,
  753. num_heads=num_heads,
  754. input_size=input_size,
  755. mlp_ratio=mlp_ratio,
  756. qkv_bias=qkv_bias,
  757. drop_rate=self.drop_rate,
  758. drop_path=dpr[i],
  759. norm_layer=norm_layer,
  760. kernel_q=pool_q[i] if len(pool_q) > i else [],
  761. kernel_kv=pool_kv[i] if len(pool_kv) > i else [],
  762. stride_q=stride_q[i] if len(stride_q) > i else [],
  763. stride_kv=stride_kv[i] if len(stride_kv) > i else [],
  764. mode=mode,
  765. has_cls_embed=self.cls_embed_on,
  766. pool_first=pool_first,
  767. rel_pos_spatial=rel_pos_spatial,
  768. rel_pos_temporal=rel_pos_temporal,
  769. rel_pos_zero_init=rel_pos_zero_init,
  770. residual_pooling=residual_pooling,
  771. dim_mul_in_att=dim_mul_in_att,
  772. separate_qkv=separate_qkv,
  773. use_grad_checkpoint=False)
  774. if use_grad_checkpoint:
  775. attention_block = checkpoint_wrapper(
  776. attention_block, offload_to_cpu=False)
  777. self.blocks.append(attention_block)
  778. if len(stride_q[i]) > 0:
  779. input_size = [
  780. size // stride
  781. for size, stride in zip(input_size, stride_q[i])
  782. ]
  783. embed_dim = dim_out
  784. self.norm = norm_layer(embed_dim)
  785. self.head = nn.Identity()
  786. if self.use_abs_pos:
  787. if self.sep_pos_embed:
  788. trunc_normal_(self.pos_embed_spatial, std=0.02)
  789. trunc_normal_(self.pos_embed_temporal, std=0.02)
  790. if self.cls_embed_on:
  791. trunc_normal_(self.pos_embed_class, std=0.02)
  792. else:
  793. trunc_normal_(self.pos_embed, std=0.02)
  794. if self.cls_embed_on:
  795. trunc_normal_(self.cls_token, std=0.02)
  796. self.apply(self._init_weights)
  797. def _init_weights(self, m):
  798. if isinstance(m, nn.Linear):
  799. nn.init.trunc_normal_(m.weight, std=0.02)
  800. if isinstance(m, nn.Linear) and m.bias is not None:
  801. nn.init.constant_(m.bias, 0)
  802. elif isinstance(m, nn.LayerNorm):
  803. nn.init.constant_(m.bias, 0)
  804. nn.init.constant_(m.weight, 1.0)
  805. @torch.jit.ignore
  806. def no_weight_decay(self):
  807. names = []
  808. if self.zero_decay_pos_cls:
  809. if self.use_abs_pos:
  810. if self.sep_pos_embed:
  811. names.extend([
  812. 'pos_embed_spatial',
  813. 'pos_embed_temporal',
  814. 'pos_embed_class',
  815. ])
  816. else:
  817. names.append(['pos_embed'])
  818. if self.rel_pos_spatial:
  819. names.extend(['rel_pos_h', 'rel_pos_w', 'rel_pos_hw'])
  820. if self.rel_pos_temporal:
  821. names.extend(['rel_pos_t'])
  822. if self.cls_embed_on:
  823. names.append('cls_token')
  824. return names
  825. def _get_pos_embed(self, pos_embed, bcthw):
  826. t, h, w = bcthw[-3], bcthw[-2], bcthw[-1]
  827. if self.cls_embed_on:
  828. cls_pos_embed = pos_embed[:, 0:1, :]
  829. pos_embed = pos_embed[:, 1:]
  830. txy_num = pos_embed.shape[1]
  831. p_t, p_h, p_w = self.patch_dims
  832. assert p_t * p_h * p_w == txy_num
  833. if (p_t, p_h, p_w) != (t, h, w):
  834. new_pos_embed = F.interpolate(
  835. pos_embed[:, :, :].reshape(1, p_t, p_h, p_w,
  836. -1).permute(0, 4, 1, 2, 3),
  837. size=(t, h, w),
  838. mode='trilinear',
  839. )
  840. pos_embed = new_pos_embed.reshape(1, -1,
  841. t * h * w).permute(0, 2, 1)
  842. if self.cls_embed_on:
  843. pos_embed = torch.cat((cls_pos_embed, pos_embed), dim=1)
  844. return pos_embed
  845. def forward_features(self, x):
  846. x = x.permute(0, 2, 1, 3, 4)
  847. x, bcthw = self.patch_embed(x)
  848. T, H, W = bcthw[-3], bcthw[-2], bcthw[-1]
  849. B, N, C = x.shape
  850. if self.cls_embed_on:
  851. cls_tokens = self.cls_token.expand(
  852. B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
  853. x = torch.cat((cls_tokens, x), dim=1)
  854. if self.use_abs_pos:
  855. if self.sep_pos_embed:
  856. pos_embed = self.pos_embed_spatial.repeat(
  857. 1, self.patch_dims[0], 1) + torch.repeat_interleave(
  858. self.pos_embed_temporal,
  859. self.patch_dims[1] * self.patch_dims[2],
  860. dim=1)
  861. if self.cls_embed_on:
  862. pos_embed = torch.cat([self.pos_embed_class, pos_embed], 1)
  863. pos_embed = self._get_pos_embed(pos_embed, bcthw)
  864. x = x + pos_embed
  865. else:
  866. pos_embed = self._get_pos_embed(self.pos_embed, bcthw)
  867. x = x + pos_embed
  868. if self.drop_rate:
  869. x = self.pos_drop(x)
  870. if self.norm_stem:
  871. x = self.norm_stem(x)
  872. thw = [T, H, W]
  873. for blk in self.blocks:
  874. x, thw = blk(x, thw)
  875. x = self.norm(x)
  876. return x
  877. def forward(self, x):
  878. x = self.forward_features(x)
  879. x = self.head(x)
  880. return x