coat.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843
  1. """
  2. CoaT architecture.
  3. Paper: Co-Scale Conv-Attentional Image Transformers - https://arxiv.org/abs/2104.06399
  4. Official CoaT code at: https://github.com/mlpc-ucsd/CoaT
  5. Modified from timm/models/vision_transformer.py
  6. """
  7. from typing import List, Optional, Tuple, Union, Type, Any
  8. import torch
  9. import torch.nn as nn
  10. import torch.nn.functional as F
  11. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  12. from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, LayerNorm
  13. from ._builder import build_model_with_cfg
  14. from ._registry import register_model, generate_default_cfgs
  15. __all__ = ['CoaT']
  16. class ConvRelPosEnc(nn.Module):
  17. """ Convolutional relative position encoding. """
  18. def __init__(
  19. self,
  20. head_chs: int,
  21. num_heads: int,
  22. window: Union[int, dict],
  23. device=None,
  24. dtype=None,
  25. ):
  26. """
  27. Initialization.
  28. Ch: Channels per head.
  29. h: Number of heads.
  30. window: Window size(s) in convolutional relative positional encoding. It can have two forms:
  31. 1. An integer of window size, which assigns all attention heads with the same window s
  32. size in ConvRelPosEnc.
  33. 2. A dict mapping window size to #attention head splits (
  34. e.g. {window size 1: #attention head split 1, window size 2: #attention head split 2})
  35. It will apply different window size to the attention head splits.
  36. """
  37. dd = {'device': device, 'dtype': dtype}
  38. super().__init__()
  39. if isinstance(window, int):
  40. # Set the same window size for all attention heads.
  41. window = {window: num_heads}
  42. self.window = window
  43. elif isinstance(window, dict):
  44. self.window = window
  45. else:
  46. raise ValueError()
  47. self.conv_list = nn.ModuleList()
  48. self.head_splits = []
  49. for cur_window, cur_head_split in window.items():
  50. dilation = 1
  51. # Determine padding size.
  52. # Ref: https://discuss.pytorch.org/t/how-to-keep-the-shape-of-input-and-output-same-when-dilation-conv/14338
  53. padding_size = (cur_window + (cur_window - 1) * (dilation - 1)) // 2
  54. cur_conv = nn.Conv2d(
  55. cur_head_split * head_chs,
  56. cur_head_split * head_chs,
  57. kernel_size=(cur_window, cur_window),
  58. padding=(padding_size, padding_size),
  59. dilation=(dilation, dilation),
  60. groups=cur_head_split * head_chs,
  61. **dd,
  62. )
  63. self.conv_list.append(cur_conv)
  64. self.head_splits.append(cur_head_split)
  65. self.channel_splits = [x * head_chs for x in self.head_splits]
  66. def forward(self, q, v, size: Tuple[int, int]):
  67. B, num_heads, N, C = q.shape
  68. H, W = size
  69. _assert(N == 1 + H * W, '')
  70. # Convolutional relative position encoding.
  71. q_img = q[:, :, 1:, :] # [B, h, H*W, Ch]
  72. v_img = v[:, :, 1:, :] # [B, h, H*W, Ch]
  73. v_img = v_img.transpose(-1, -2).reshape(B, num_heads * C, H, W)
  74. v_img_list = torch.split(v_img, self.channel_splits, dim=1) # Split according to channels
  75. conv_v_img_list = []
  76. for i, conv in enumerate(self.conv_list):
  77. conv_v_img_list.append(conv(v_img_list[i]))
  78. conv_v_img = torch.cat(conv_v_img_list, dim=1)
  79. conv_v_img = conv_v_img.reshape(B, num_heads, C, H * W).transpose(-1, -2)
  80. EV_hat = q_img * conv_v_img
  81. EV_hat = F.pad(EV_hat, (0, 0, 1, 0, 0, 0)) # [B, h, N, Ch].
  82. return EV_hat
  83. class FactorAttnConvRelPosEnc(nn.Module):
  84. """ Factorized attention with convolutional relative position encoding class. """
  85. def __init__(
  86. self,
  87. dim: int,
  88. num_heads: int = 8,
  89. qkv_bias: bool = False,
  90. attn_drop: float = 0.,
  91. proj_drop: float = 0.,
  92. shared_crpe: Optional[Any] = None,
  93. device=None,
  94. dtype=None,
  95. ):
  96. dd = {'device': device, 'dtype': dtype}
  97. super().__init__()
  98. self.num_heads = num_heads
  99. head_dim = dim // num_heads
  100. self.scale = head_dim ** -0.5
  101. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd)
  102. self.attn_drop = nn.Dropout(attn_drop) # Note: attn_drop is actually not used.
  103. self.proj = nn.Linear(dim, dim, **dd)
  104. self.proj_drop = nn.Dropout(proj_drop)
  105. # Shared convolutional relative position encoding.
  106. self.crpe = shared_crpe
  107. def forward(self, x, size: Tuple[int, int]):
  108. B, N, C = x.shape
  109. # Generate Q, K, V.
  110. qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
  111. q, k, v = qkv.unbind(0) # [B, h, N, Ch]
  112. # Factorized attention.
  113. k_softmax = k.softmax(dim=2)
  114. factor_att = k_softmax.transpose(-1, -2) @ v
  115. factor_att = q @ factor_att
  116. # Convolutional relative position encoding.
  117. crpe = self.crpe(q, v, size=size) # [B, h, N, Ch]
  118. # Merge and reshape.
  119. x = self.scale * factor_att + crpe
  120. x = x.transpose(1, 2).reshape(B, N, C) # [B, h, N, Ch] -> [B, N, h, Ch] -> [B, N, C]
  121. # Output projection.
  122. x = self.proj(x)
  123. x = self.proj_drop(x)
  124. return x
  125. class ConvPosEnc(nn.Module):
  126. """ Convolutional Position Encoding.
  127. Note: This module is similar to the conditional position encoding in CPVT.
  128. """
  129. def __init__(
  130. self,
  131. dim: int,
  132. k: int = 3,
  133. device=None,
  134. dtype=None,
  135. ):
  136. dd = {'device': device, 'dtype': dtype}
  137. super().__init__()
  138. self.proj = nn.Conv2d(dim, dim, k, 1, k//2, groups=dim, **dd)
  139. def forward(self, x, size: Tuple[int, int]):
  140. B, N, C = x.shape
  141. H, W = size
  142. _assert(N == 1 + H * W, '')
  143. # Extract CLS token and image tokens.
  144. cls_token, img_tokens = x[:, :1], x[:, 1:] # [B, 1, C], [B, H*W, C]
  145. # Depthwise convolution.
  146. feat = img_tokens.transpose(1, 2).view(B, C, H, W)
  147. x = self.proj(feat) + feat
  148. x = x.flatten(2).transpose(1, 2)
  149. # Combine with CLS token.
  150. x = torch.cat((cls_token, x), dim=1)
  151. return x
  152. class SerialBlock(nn.Module):
  153. """ Serial block class.
  154. Note: In this implementation, each serial block only contains a conv-attention and a FFN (MLP) module. """
  155. def __init__(
  156. self,
  157. dim: int,
  158. num_heads: int,
  159. mlp_ratio: float = 4.,
  160. qkv_bias: bool = False,
  161. proj_drop: float = 0.,
  162. attn_drop: float = 0.,
  163. drop_path: float = 0.,
  164. act_layer: Type[nn.Module] = nn.GELU,
  165. norm_layer: Type[nn.Module] = nn.LayerNorm,
  166. shared_cpe: Optional[Any] = None,
  167. shared_crpe: Optional[Any] = None,
  168. device=None,
  169. dtype=None,
  170. ):
  171. dd = {'device': device, 'dtype': dtype}
  172. super().__init__()
  173. # Conv-Attention.
  174. self.cpe = shared_cpe
  175. self.norm1 = norm_layer(dim, **dd)
  176. self.factoratt_crpe = FactorAttnConvRelPosEnc(
  177. dim,
  178. num_heads=num_heads,
  179. qkv_bias=qkv_bias,
  180. attn_drop=attn_drop,
  181. proj_drop=proj_drop,
  182. shared_crpe=shared_crpe,
  183. **dd,
  184. )
  185. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  186. # MLP.
  187. self.norm2 = norm_layer(dim, **dd)
  188. mlp_hidden_dim = int(dim * mlp_ratio)
  189. self.mlp = Mlp(
  190. in_features=dim,
  191. hidden_features=mlp_hidden_dim,
  192. act_layer=act_layer,
  193. drop=proj_drop,
  194. **dd,
  195. )
  196. def forward(self, x, size: Tuple[int, int]):
  197. # Conv-Attention.
  198. x = self.cpe(x, size)
  199. cur = self.norm1(x)
  200. cur = self.factoratt_crpe(cur, size)
  201. x = x + self.drop_path(cur)
  202. # MLP.
  203. cur = self.norm2(x)
  204. cur = self.mlp(cur)
  205. x = x + self.drop_path(cur)
  206. return x
  207. class ParallelBlock(nn.Module):
  208. """ Parallel block class. """
  209. def __init__(
  210. self,
  211. dims: List[int],
  212. num_heads: int,
  213. mlp_ratios: List[float] = None,
  214. qkv_bias: bool = False,
  215. proj_drop: float = 0.,
  216. attn_drop: float = 0.,
  217. drop_path: float = 0.,
  218. act_layer: Type[nn.Module] = nn.GELU,
  219. norm_layer: Type[nn.Module] = nn.LayerNorm,
  220. shared_crpes: Optional[List[Any]] = None,
  221. device=None,
  222. dtype=None,
  223. ):
  224. dd = {'device': device, 'dtype': dtype}
  225. super().__init__()
  226. if mlp_ratios is None:
  227. mlp_ratios = []
  228. # Conv-Attention.
  229. self.norm12 = norm_layer(dims[1], **dd)
  230. self.norm13 = norm_layer(dims[2], **dd)
  231. self.norm14 = norm_layer(dims[3], **dd)
  232. self.factoratt_crpe2 = FactorAttnConvRelPosEnc(
  233. dims[1],
  234. num_heads=num_heads,
  235. qkv_bias=qkv_bias,
  236. attn_drop=attn_drop,
  237. proj_drop=proj_drop,
  238. shared_crpe=shared_crpes[1],
  239. **dd,
  240. )
  241. self.factoratt_crpe3 = FactorAttnConvRelPosEnc(
  242. dims[2],
  243. num_heads=num_heads,
  244. qkv_bias=qkv_bias,
  245. attn_drop=attn_drop,
  246. proj_drop=proj_drop,
  247. shared_crpe=shared_crpes[2],
  248. **dd,
  249. )
  250. self.factoratt_crpe4 = FactorAttnConvRelPosEnc(
  251. dims[3],
  252. num_heads=num_heads,
  253. qkv_bias=qkv_bias,
  254. attn_drop=attn_drop,
  255. proj_drop=proj_drop,
  256. shared_crpe=shared_crpes[3],
  257. **dd,
  258. )
  259. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  260. # MLP.
  261. self.norm22 = norm_layer(dims[1], **dd)
  262. self.norm23 = norm_layer(dims[2], **dd)
  263. self.norm24 = norm_layer(dims[3], **dd)
  264. # In parallel block, we assume dimensions are the same and share the linear transformation.
  265. assert dims[1] == dims[2] == dims[3]
  266. assert mlp_ratios[1] == mlp_ratios[2] == mlp_ratios[3]
  267. mlp_hidden_dim = int(dims[1] * mlp_ratios[1])
  268. self.mlp2 = self.mlp3 = self.mlp4 = Mlp(
  269. in_features=dims[1],
  270. hidden_features=mlp_hidden_dim,
  271. act_layer=act_layer,
  272. drop=proj_drop,
  273. **dd,
  274. )
  275. def upsample(self, x, factor: float, size: Tuple[int, int]):
  276. """ Feature map up-sampling. """
  277. return self.interpolate(x, scale_factor=factor, size=size)
  278. def downsample(self, x, factor: float, size: Tuple[int, int]):
  279. """ Feature map down-sampling. """
  280. return self.interpolate(x, scale_factor=1.0/factor, size=size)
  281. def interpolate(self, x, scale_factor: float, size: Tuple[int, int]):
  282. """ Feature map interpolation. """
  283. B, N, C = x.shape
  284. H, W = size
  285. _assert(N == 1 + H * W, '')
  286. cls_token = x[:, :1, :]
  287. img_tokens = x[:, 1:, :]
  288. img_tokens = img_tokens.transpose(1, 2).reshape(B, C, H, W)
  289. img_tokens = F.interpolate(
  290. img_tokens,
  291. scale_factor=scale_factor,
  292. recompute_scale_factor=False,
  293. mode='bilinear',
  294. align_corners=False,
  295. )
  296. img_tokens = img_tokens.reshape(B, C, -1).transpose(1, 2)
  297. out = torch.cat((cls_token, img_tokens), dim=1)
  298. return out
  299. def forward(self, x1, x2, x3, x4, sizes: List[Tuple[int, int]]):
  300. _, S2, S3, S4 = sizes
  301. cur2 = self.norm12(x2)
  302. cur3 = self.norm13(x3)
  303. cur4 = self.norm14(x4)
  304. cur2 = self.factoratt_crpe2(cur2, size=S2)
  305. cur3 = self.factoratt_crpe3(cur3, size=S3)
  306. cur4 = self.factoratt_crpe4(cur4, size=S4)
  307. upsample3_2 = self.upsample(cur3, factor=2., size=S3)
  308. upsample4_3 = self.upsample(cur4, factor=2., size=S4)
  309. upsample4_2 = self.upsample(cur4, factor=4., size=S4)
  310. downsample2_3 = self.downsample(cur2, factor=2., size=S2)
  311. downsample3_4 = self.downsample(cur3, factor=2., size=S3)
  312. downsample2_4 = self.downsample(cur2, factor=4., size=S2)
  313. cur2 = cur2 + upsample3_2 + upsample4_2
  314. cur3 = cur3 + upsample4_3 + downsample2_3
  315. cur4 = cur4 + downsample3_4 + downsample2_4
  316. x2 = x2 + self.drop_path(cur2)
  317. x3 = x3 + self.drop_path(cur3)
  318. x4 = x4 + self.drop_path(cur4)
  319. # MLP.
  320. cur2 = self.norm22(x2)
  321. cur3 = self.norm23(x3)
  322. cur4 = self.norm24(x4)
  323. cur2 = self.mlp2(cur2)
  324. cur3 = self.mlp3(cur3)
  325. cur4 = self.mlp4(cur4)
  326. x2 = x2 + self.drop_path(cur2)
  327. x3 = x3 + self.drop_path(cur3)
  328. x4 = x4 + self.drop_path(cur4)
  329. return x1, x2, x3, x4
  330. class CoaT(nn.Module):
  331. """ CoaT class. """
  332. def __init__(
  333. self,
  334. img_size: int = 224,
  335. patch_size: int = 16,
  336. in_chans: int = 3,
  337. num_classes: int = 1000,
  338. embed_dims: Tuple[int, int, int, int] = (64, 128, 320, 512),
  339. serial_depths: Tuple[int, int, int, int] = (3, 4, 6, 3),
  340. parallel_depth: int = 0,
  341. num_heads: int = 8,
  342. mlp_ratios: Tuple[float, float, float, float] = (4, 4, 4, 4),
  343. qkv_bias: bool = True,
  344. drop_rate: float = 0.,
  345. proj_drop_rate: float = 0.,
  346. attn_drop_rate: float = 0.,
  347. drop_path_rate: float = 0.,
  348. norm_layer: Type[nn.Module] = LayerNorm,
  349. return_interm_layers: bool = False,
  350. out_features: Optional[List[str]] = None,
  351. crpe_window: Optional[dict] = None,
  352. global_pool: str = 'token',
  353. device=None,
  354. dtype=None,
  355. ):
  356. super().__init__()
  357. dd = {'device': device, 'dtype': dtype}
  358. assert global_pool in ('token', 'avg')
  359. crpe_window = crpe_window or {3: 2, 5: 3, 7: 3}
  360. self.return_interm_layers = return_interm_layers
  361. self.out_features = out_features
  362. self.embed_dims = embed_dims
  363. self.num_features = self.head_hidden_size = embed_dims[-1]
  364. self.num_classes = num_classes
  365. self.global_pool = global_pool
  366. # Patch embeddings.
  367. img_size = to_2tuple(img_size)
  368. self.patch_embed1 = PatchEmbed(
  369. img_size=img_size, patch_size=patch_size, in_chans=in_chans,
  370. embed_dim=embed_dims[0], norm_layer=nn.LayerNorm, **dd)
  371. self.patch_embed2 = PatchEmbed(
  372. img_size=[x // 4 for x in img_size], patch_size=2, in_chans=embed_dims[0],
  373. embed_dim=embed_dims[1], norm_layer=nn.LayerNorm, **dd)
  374. self.patch_embed3 = PatchEmbed(
  375. img_size=[x // 8 for x in img_size], patch_size=2, in_chans=embed_dims[1],
  376. embed_dim=embed_dims[2], norm_layer=nn.LayerNorm, **dd)
  377. self.patch_embed4 = PatchEmbed(
  378. img_size=[x // 16 for x in img_size], patch_size=2, in_chans=embed_dims[2],
  379. embed_dim=embed_dims[3], norm_layer=nn.LayerNorm, **dd)
  380. # Class tokens.
  381. self.cls_token1 = nn.Parameter(torch.zeros(1, 1, embed_dims[0], **dd))
  382. self.cls_token2 = nn.Parameter(torch.zeros(1, 1, embed_dims[1], **dd))
  383. self.cls_token3 = nn.Parameter(torch.zeros(1, 1, embed_dims[2], **dd))
  384. self.cls_token4 = nn.Parameter(torch.zeros(1, 1, embed_dims[3], **dd))
  385. # Convolutional position encodings.
  386. self.cpe1 = ConvPosEnc(dim=embed_dims[0], k=3, **dd)
  387. self.cpe2 = ConvPosEnc(dim=embed_dims[1], k=3, **dd)
  388. self.cpe3 = ConvPosEnc(dim=embed_dims[2], k=3, **dd)
  389. self.cpe4 = ConvPosEnc(dim=embed_dims[3], k=3, **dd)
  390. # Convolutional relative position encodings.
  391. self.crpe1 = ConvRelPosEnc(head_chs=embed_dims[0] // num_heads, num_heads=num_heads, window=crpe_window, **dd)
  392. self.crpe2 = ConvRelPosEnc(head_chs=embed_dims[1] // num_heads, num_heads=num_heads, window=crpe_window, **dd)
  393. self.crpe3 = ConvRelPosEnc(head_chs=embed_dims[2] // num_heads, num_heads=num_heads, window=crpe_window, **dd)
  394. self.crpe4 = ConvRelPosEnc(head_chs=embed_dims[3] // num_heads, num_heads=num_heads, window=crpe_window, **dd)
  395. dpr = drop_path_rate
  396. skwargs = dict(
  397. num_heads=num_heads,
  398. qkv_bias=qkv_bias,
  399. proj_drop=proj_drop_rate,
  400. attn_drop=attn_drop_rate,
  401. drop_path=dpr,
  402. norm_layer=norm_layer,
  403. )
  404. # Serial blocks 1.
  405. self.serial_blocks1 = nn.ModuleList([
  406. SerialBlock(
  407. dim=embed_dims[0],
  408. mlp_ratio=mlp_ratios[0],
  409. shared_cpe=self.cpe1,
  410. shared_crpe=self.crpe1,
  411. **skwargs,
  412. **dd,
  413. )
  414. for _ in range(serial_depths[0])]
  415. )
  416. # Serial blocks 2.
  417. self.serial_blocks2 = nn.ModuleList([
  418. SerialBlock(
  419. dim=embed_dims[1],
  420. mlp_ratio=mlp_ratios[1],
  421. shared_cpe=self.cpe2,
  422. shared_crpe=self.crpe2,
  423. **skwargs,
  424. **dd,
  425. )
  426. for _ in range(serial_depths[1])]
  427. )
  428. # Serial blocks 3.
  429. self.serial_blocks3 = nn.ModuleList([
  430. SerialBlock(
  431. dim=embed_dims[2],
  432. mlp_ratio=mlp_ratios[2],
  433. shared_cpe=self.cpe3,
  434. shared_crpe=self.crpe3,
  435. **skwargs,
  436. **dd,
  437. )
  438. for _ in range(serial_depths[2])]
  439. )
  440. # Serial blocks 4.
  441. self.serial_blocks4 = nn.ModuleList([
  442. SerialBlock(
  443. dim=embed_dims[3],
  444. mlp_ratio=mlp_ratios[3],
  445. shared_cpe=self.cpe4,
  446. shared_crpe=self.crpe4,
  447. **skwargs,
  448. **dd,
  449. )
  450. for _ in range(serial_depths[3])]
  451. )
  452. # Parallel blocks.
  453. self.parallel_depth = parallel_depth
  454. if self.parallel_depth > 0:
  455. self.parallel_blocks = nn.ModuleList([
  456. ParallelBlock(
  457. dims=embed_dims,
  458. mlp_ratios=mlp_ratios,
  459. shared_crpes=(self.crpe1, self.crpe2, self.crpe3, self.crpe4),
  460. **skwargs,
  461. **dd,
  462. )
  463. for _ in range(parallel_depth)]
  464. )
  465. else:
  466. self.parallel_blocks = None
  467. # Classification head(s).
  468. if not self.return_interm_layers:
  469. if self.parallel_blocks is not None:
  470. self.norm2 = norm_layer(embed_dims[1], **dd)
  471. self.norm3 = norm_layer(embed_dims[2], **dd)
  472. else:
  473. self.norm2 = self.norm3 = None
  474. self.norm4 = norm_layer(embed_dims[3], **dd)
  475. if self.parallel_depth > 0:
  476. # CoaT series: Aggregate features of last three scales for classification.
  477. assert embed_dims[1] == embed_dims[2] == embed_dims[3]
  478. self.aggregate = torch.nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1, **dd)
  479. self.head_drop = nn.Dropout(drop_rate)
  480. self.head = nn.Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity()
  481. else:
  482. # CoaT-Lite series: Use feature of last scale for classification.
  483. self.aggregate = None
  484. self.head_drop = nn.Dropout(drop_rate)
  485. self.head = nn.Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity()
  486. # Initialize weights.
  487. trunc_normal_(self.cls_token1, std=.02)
  488. trunc_normal_(self.cls_token2, std=.02)
  489. trunc_normal_(self.cls_token3, std=.02)
  490. trunc_normal_(self.cls_token4, std=.02)
  491. self.apply(self._init_weights)
  492. def _init_weights(self, m):
  493. if isinstance(m, nn.Linear):
  494. trunc_normal_(m.weight, std=.02)
  495. if isinstance(m, nn.Linear) and m.bias is not None:
  496. nn.init.constant_(m.bias, 0)
  497. elif isinstance(m, nn.LayerNorm):
  498. nn.init.constant_(m.bias, 0)
  499. nn.init.constant_(m.weight, 1.0)
  500. @torch.jit.ignore
  501. def no_weight_decay(self):
  502. return {'cls_token1', 'cls_token2', 'cls_token3', 'cls_token4'}
  503. @torch.jit.ignore
  504. def set_grad_checkpointing(self, enable=True):
  505. assert not enable, 'gradient checkpointing not supported'
  506. @torch.jit.ignore
  507. def group_matcher(self, coarse=False):
  508. matcher = dict(
  509. stem1=r'^cls_token1|patch_embed1|crpe1|cpe1',
  510. serial_blocks1=r'^serial_blocks1\.(\d+)',
  511. stem2=r'^cls_token2|patch_embed2|crpe2|cpe2',
  512. serial_blocks2=r'^serial_blocks2\.(\d+)',
  513. stem3=r'^cls_token3|patch_embed3|crpe3|cpe3',
  514. serial_blocks3=r'^serial_blocks3\.(\d+)',
  515. stem4=r'^cls_token4|patch_embed4|crpe4|cpe4',
  516. serial_blocks4=r'^serial_blocks4\.(\d+)',
  517. parallel_blocks=[ # FIXME (partially?) overlap parallel w/ serial blocks??
  518. (r'^parallel_blocks\.(\d+)', None),
  519. (r'^norm|aggregate', (99999,)),
  520. ]
  521. )
  522. return matcher
  523. @torch.jit.ignore
  524. def get_classifier(self) -> nn.Module:
  525. return self.head
  526. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  527. self.num_classes = num_classes
  528. if global_pool is not None:
  529. assert global_pool in ('token', 'avg')
  530. self.global_pool = global_pool
  531. self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
  532. def forward_features(self, x0):
  533. B = x0.shape[0]
  534. # Serial blocks 1.
  535. x1 = self.patch_embed1(x0)
  536. H1, W1 = self.patch_embed1.grid_size
  537. x1 = insert_cls(x1, self.cls_token1)
  538. for blk in self.serial_blocks1:
  539. x1 = blk(x1, size=(H1, W1))
  540. x1_nocls = remove_cls(x1).reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
  541. # Serial blocks 2.
  542. x2 = self.patch_embed2(x1_nocls)
  543. H2, W2 = self.patch_embed2.grid_size
  544. x2 = insert_cls(x2, self.cls_token2)
  545. for blk in self.serial_blocks2:
  546. x2 = blk(x2, size=(H2, W2))
  547. x2_nocls = remove_cls(x2).reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous()
  548. # Serial blocks 3.
  549. x3 = self.patch_embed3(x2_nocls)
  550. H3, W3 = self.patch_embed3.grid_size
  551. x3 = insert_cls(x3, self.cls_token3)
  552. for blk in self.serial_blocks3:
  553. x3 = blk(x3, size=(H3, W3))
  554. x3_nocls = remove_cls(x3).reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous()
  555. # Serial blocks 4.
  556. x4 = self.patch_embed4(x3_nocls)
  557. H4, W4 = self.patch_embed4.grid_size
  558. x4 = insert_cls(x4, self.cls_token4)
  559. for blk in self.serial_blocks4:
  560. x4 = blk(x4, size=(H4, W4))
  561. x4_nocls = remove_cls(x4).reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous()
  562. # Only serial blocks: Early return.
  563. if self.parallel_blocks is None:
  564. if not torch.jit.is_scripting() and self.return_interm_layers:
  565. # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2).
  566. feat_out = {}
  567. if 'x1_nocls' in self.out_features:
  568. feat_out['x1_nocls'] = x1_nocls
  569. if 'x2_nocls' in self.out_features:
  570. feat_out['x2_nocls'] = x2_nocls
  571. if 'x3_nocls' in self.out_features:
  572. feat_out['x3_nocls'] = x3_nocls
  573. if 'x4_nocls' in self.out_features:
  574. feat_out['x4_nocls'] = x4_nocls
  575. return feat_out
  576. else:
  577. # Return features for classification.
  578. x4 = self.norm4(x4)
  579. return x4
  580. # Parallel blocks.
  581. for blk in self.parallel_blocks:
  582. x2, x3, x4 = self.cpe2(x2, (H2, W2)), self.cpe3(x3, (H3, W3)), self.cpe4(x4, (H4, W4))
  583. x1, x2, x3, x4 = blk(x1, x2, x3, x4, sizes=[(H1, W1), (H2, W2), (H3, W3), (H4, W4)])
  584. if not torch.jit.is_scripting() and self.return_interm_layers:
  585. # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2).
  586. feat_out = {}
  587. if 'x1_nocls' in self.out_features:
  588. x1_nocls = remove_cls(x1).reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
  589. feat_out['x1_nocls'] = x1_nocls
  590. if 'x2_nocls' in self.out_features:
  591. x2_nocls = remove_cls(x2).reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous()
  592. feat_out['x2_nocls'] = x2_nocls
  593. if 'x3_nocls' in self.out_features:
  594. x3_nocls = remove_cls(x3).reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous()
  595. feat_out['x3_nocls'] = x3_nocls
  596. if 'x4_nocls' in self.out_features:
  597. x4_nocls = remove_cls(x4).reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous()
  598. feat_out['x4_nocls'] = x4_nocls
  599. return feat_out
  600. else:
  601. x2 = self.norm2(x2)
  602. x3 = self.norm3(x3)
  603. x4 = self.norm4(x4)
  604. return [x2, x3, x4]
  605. def forward_head(self, x_feat: Union[torch.Tensor, List[torch.Tensor]], pre_logits: bool = False):
  606. if isinstance(x_feat, list):
  607. assert self.aggregate is not None
  608. if self.global_pool == 'avg':
  609. x = torch.cat([xl[:, 1:].mean(dim=1, keepdim=True) for xl in x_feat], dim=1) # [B, 3, C]
  610. else:
  611. x = torch.stack([xl[:, 0] for xl in x_feat], dim=1) # [B, 3, C]
  612. x = self.aggregate(x).squeeze(dim=1) # Shape: [B, C]
  613. else:
  614. x = x_feat[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x_feat[:, 0]
  615. x = self.head_drop(x)
  616. return x if pre_logits else self.head(x)
  617. def forward(self, x) -> torch.Tensor:
  618. if not torch.jit.is_scripting() and self.return_interm_layers:
  619. # Return intermediate features (for down-stream tasks).
  620. return self.forward_features(x)
  621. else:
  622. # Return features for classification.
  623. x_feat = self.forward_features(x)
  624. x = self.forward_head(x_feat)
  625. return x
  626. def insert_cls(x, cls_token):
  627. """ Insert CLS token. """
  628. cls_tokens = cls_token.expand(x.shape[0], -1, -1)
  629. x = torch.cat((cls_tokens, x), dim=1)
  630. return x
  631. def remove_cls(x):
  632. """ Remove CLS token. """
  633. return x[:, 1:, :]
  634. def checkpoint_filter_fn(state_dict, model):
  635. out_dict = {}
  636. state_dict = state_dict.get('model', state_dict)
  637. for k, v in state_dict.items():
  638. # original model had unused norm layers, removing them requires filtering pretrained checkpoints
  639. if k.startswith('norm1') or \
  640. (k.startswith('norm2') and getattr(model, 'norm2', None) is None) or \
  641. (k.startswith('norm3') and getattr(model, 'norm3', None) is None) or \
  642. (k.startswith('norm4') and getattr(model, 'norm4', None) is None) or \
  643. (k.startswith('aggregate') and getattr(model, 'aggregate', None) is None) or \
  644. (k.startswith('head') and getattr(model, 'head', None) is None):
  645. continue
  646. out_dict[k] = v
  647. return out_dict
  648. def _create_coat(variant, pretrained=False, default_cfg=None, **kwargs):
  649. if kwargs.get('features_only', None):
  650. raise RuntimeError('features_only not implemented for Vision Transformer models.')
  651. model = build_model_with_cfg(
  652. CoaT,
  653. variant,
  654. pretrained,
  655. pretrained_filter_fn=checkpoint_filter_fn,
  656. **kwargs,
  657. )
  658. return model
  659. def _cfg_coat(url='', **kwargs):
  660. return {
  661. 'url': url,
  662. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
  663. 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
  664. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  665. 'first_conv': 'patch_embed1.proj', 'classifier': 'head',
  666. 'license': 'apache-2.0',
  667. **kwargs
  668. }
  669. default_cfgs = generate_default_cfgs({
  670. 'coat_tiny.in1k': _cfg_coat(hf_hub_id='timm/'),
  671. 'coat_mini.in1k': _cfg_coat(hf_hub_id='timm/'),
  672. 'coat_small.in1k': _cfg_coat(hf_hub_id='timm/'),
  673. 'coat_lite_tiny.in1k': _cfg_coat(hf_hub_id='timm/'),
  674. 'coat_lite_mini.in1k': _cfg_coat(hf_hub_id='timm/'),
  675. 'coat_lite_small.in1k': _cfg_coat(hf_hub_id='timm/'),
  676. 'coat_lite_medium.in1k': _cfg_coat(hf_hub_id='timm/'),
  677. 'coat_lite_medium_384.in1k': _cfg_coat(
  678. hf_hub_id='timm/',
  679. input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash',
  680. ),
  681. })
  682. @register_model
  683. def coat_tiny(pretrained=False, **kwargs) -> CoaT:
  684. model_cfg = dict(
  685. patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6)
  686. model = _create_coat('coat_tiny', pretrained=pretrained, **dict(model_cfg, **kwargs))
  687. return model
  688. @register_model
  689. def coat_mini(pretrained=False, **kwargs) -> CoaT:
  690. model_cfg = dict(
  691. patch_size=4, embed_dims=[152, 216, 216, 216], serial_depths=[2, 2, 2, 2], parallel_depth=6)
  692. model = _create_coat('coat_mini', pretrained=pretrained, **dict(model_cfg, **kwargs))
  693. return model
  694. @register_model
  695. def coat_small(pretrained=False, **kwargs) -> CoaT:
  696. model_cfg = dict(
  697. patch_size=4, embed_dims=[152, 320, 320, 320], serial_depths=[2, 2, 2, 2], parallel_depth=6, **kwargs)
  698. model = _create_coat('coat_small', pretrained=pretrained, **dict(model_cfg, **kwargs))
  699. return model
  700. @register_model
  701. def coat_lite_tiny(pretrained=False, **kwargs) -> CoaT:
  702. model_cfg = dict(
  703. patch_size=4, embed_dims=[64, 128, 256, 320], serial_depths=[2, 2, 2, 2], mlp_ratios=[8, 8, 4, 4])
  704. model = _create_coat('coat_lite_tiny', pretrained=pretrained, **dict(model_cfg, **kwargs))
  705. return model
  706. @register_model
  707. def coat_lite_mini(pretrained=False, **kwargs) -> CoaT:
  708. model_cfg = dict(
  709. patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[2, 2, 2, 2], mlp_ratios=[8, 8, 4, 4])
  710. model = _create_coat('coat_lite_mini', pretrained=pretrained, **dict(model_cfg, **kwargs))
  711. return model
  712. @register_model
  713. def coat_lite_small(pretrained=False, **kwargs) -> CoaT:
  714. model_cfg = dict(
  715. patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[3, 4, 6, 3], mlp_ratios=[8, 8, 4, 4])
  716. model = _create_coat('coat_lite_small', pretrained=pretrained, **dict(model_cfg, **kwargs))
  717. return model
  718. @register_model
  719. def coat_lite_medium(pretrained=False, **kwargs) -> CoaT:
  720. model_cfg = dict(
  721. patch_size=4, embed_dims=[128, 256, 320, 512], serial_depths=[3, 6, 10, 8])
  722. model = _create_coat('coat_lite_medium', pretrained=pretrained, **dict(model_cfg, **kwargs))
  723. return model
  724. @register_model
  725. def coat_lite_medium_384(pretrained=False, **kwargs) -> CoaT:
  726. model_cfg = dict(
  727. img_size=384, patch_size=4, embed_dims=[128, 256, 320, 512], serial_depths=[3, 6, 10, 8])
  728. model = _create_coat('coat_lite_medium_384', pretrained=pretrained, **dict(model_cfg, **kwargs))
  729. return model