crossvit.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653
  1. """ CrossViT Model
  2. @inproceedings{
  3. chen2021crossvit,
  4. title={{CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification}},
  5. author={Chun-Fu (Richard) Chen and Quanfu Fan and Rameswar Panda},
  6. booktitle={International Conference on Computer Vision (ICCV)},
  7. year={2021}
  8. }
  9. Paper link: https://arxiv.org/abs/2103.14899
  10. Original code: https://github.com/IBM/CrossViT/blob/main/models/crossvit.py
  11. NOTE: model names have been renamed from originals to represent actual input res all *_224 -> *_240 and *_384 -> *_408
  12. Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
  13. Modified from Timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
  14. """
  15. # Copyright IBM All Rights Reserved.
  16. # SPDX-License-Identifier: Apache-2.0
  17. from functools import partial
  18. from typing import List, Optional, Tuple, Type, Union
  19. import torch
  20. import torch.nn as nn
  21. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  22. from timm.layers import DropPath, calculate_drop_path_rates, to_2tuple, trunc_normal_, _assert
  23. from ._builder import build_model_with_cfg
  24. from ._features_fx import register_notrace_function
  25. from ._registry import register_model, generate_default_cfgs
  26. from .vision_transformer import Block
# Explicit public API of this module.
__all__ = ['CrossVit']  # model_registry will add each entrypoint fn to this
  28. class PatchEmbed(nn.Module):
  29. """ Image to Patch Embedding
  30. """
  31. def __init__(
  32. self,
  33. img_size: Union[int, Tuple[int, int]] = 224,
  34. patch_size: int = 16,
  35. in_chans: int = 3,
  36. embed_dim: int = 768,
  37. multi_conv: bool = False,
  38. device=None,
  39. dtype=None,
  40. ):
  41. dd = {'device': device, 'dtype': dtype}
  42. super().__init__()
  43. img_size = to_2tuple(img_size)
  44. patch_size = to_2tuple(patch_size)
  45. num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
  46. self.img_size = img_size
  47. self.patch_size = patch_size
  48. self.num_patches = num_patches
  49. if multi_conv:
  50. if patch_size[0] == 12:
  51. self.proj = nn.Sequential(
  52. nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3, **dd),
  53. nn.ReLU(inplace=True),
  54. nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=3, padding=0, **dd),
  55. nn.ReLU(inplace=True),
  56. nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=1, padding=1, **dd),
  57. )
  58. elif patch_size[0] == 16:
  59. self.proj = nn.Sequential(
  60. nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3, **dd),
  61. nn.ReLU(inplace=True),
  62. nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=2, padding=1, **dd),
  63. nn.ReLU(inplace=True),
  64. nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1, **dd),
  65. )
  66. else:
  67. self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, **dd)
  68. def forward(self, x):
  69. B, C, H, W = x.shape
  70. # FIXME look at relaxing size constraints
  71. _assert(H == self.img_size[0],
  72. f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
  73. _assert(W == self.img_size[1],
  74. f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
  75. x = self.proj(x).flatten(2).transpose(1, 2)
  76. return x
  77. class CrossAttention(nn.Module):
  78. def __init__(
  79. self,
  80. dim: int,
  81. num_heads: int = 8,
  82. qkv_bias: bool = False,
  83. attn_drop: float = 0.,
  84. proj_drop: float = 0.,
  85. device=None,
  86. dtype=None,
  87. ):
  88. dd = {'device': device, 'dtype': dtype}
  89. super().__init__()
  90. self.num_heads = num_heads
  91. head_dim = dim // num_heads
  92. # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
  93. self.scale = head_dim ** -0.5
  94. self.wq = nn.Linear(dim, dim, bias=qkv_bias, **dd)
  95. self.wk = nn.Linear(dim, dim, bias=qkv_bias, **dd)
  96. self.wv = nn.Linear(dim, dim, bias=qkv_bias, **dd)
  97. self.attn_drop = nn.Dropout(attn_drop)
  98. self.proj = nn.Linear(dim, dim, **dd)
  99. self.proj_drop = nn.Dropout(proj_drop)
  100. def forward(self, x):
  101. B, N, C = x.shape
  102. # B1C -> B1H(C/H) -> BH1(C/H)
  103. q = self.wq(x[:, 0:1, ...]).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
  104. # BNC -> BNH(C/H) -> BHN(C/H)
  105. k = self.wk(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
  106. # BNC -> BNH(C/H) -> BHN(C/H)
  107. v = self.wv(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
  108. attn = (q @ k.transpose(-2, -1)) * self.scale # BH1(C/H) @ BH(C/H)N -> BH1N
  109. attn = attn.softmax(dim=-1)
  110. attn = self.attn_drop(attn)
  111. x = (attn @ v).transpose(1, 2).reshape(B, 1, C) # (BH1N @ BHN(C/H)) -> BH1(C/H) -> B1H(C/H) -> B1C
  112. x = self.proj(x)
  113. x = self.proj_drop(x)
  114. return x
  115. class CrossAttentionBlock(nn.Module):
  116. def __init__(
  117. self,
  118. dim: int,
  119. num_heads: int,
  120. mlp_ratio: float = 4.,
  121. qkv_bias: bool = False,
  122. proj_drop: float = 0.,
  123. attn_drop: float = 0.,
  124. drop_path: float = 0.,
  125. act_layer: Type[nn.Module] = nn.GELU,
  126. norm_layer: Type[nn.Module] = nn.LayerNorm,
  127. device=None,
  128. dtype=None,
  129. ):
  130. dd = {'device': device, 'dtype': dtype}
  131. super().__init__()
  132. self.norm1 = norm_layer(dim, **dd)
  133. self.attn = CrossAttention(
  134. dim,
  135. num_heads=num_heads,
  136. qkv_bias=qkv_bias,
  137. attn_drop=attn_drop,
  138. proj_drop=proj_drop,
  139. **dd,
  140. )
  141. # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
  142. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  143. def forward(self, x):
  144. x = x[:, 0:1, ...] + self.drop_path(self.attn(self.norm1(x)))
  145. return x
  146. class MultiScaleBlock(nn.Module):
  147. def __init__(
  148. self,
  149. dim: Tuple[int, ...],
  150. patches: Tuple[int, ...],
  151. depth: Tuple[int, ...],
  152. num_heads: Tuple[int, ...],
  153. mlp_ratio: Tuple[float, ...],
  154. qkv_bias: bool = False,
  155. proj_drop: float = 0.,
  156. attn_drop: float = 0.,
  157. drop_path: Union[List[float], float] = 0.,
  158. act_layer: Type[nn.Module] = nn.GELU,
  159. norm_layer: Type[nn.Module] = nn.LayerNorm,
  160. device=None,
  161. dtype=None,
  162. ):
  163. dd = {'device': device, 'dtype': dtype}
  164. super().__init__()
  165. num_branches = len(dim)
  166. self.num_branches = num_branches
  167. # different branch could have different embedding size, the first one is the base
  168. self.blocks = nn.ModuleList()
  169. for d in range(num_branches):
  170. tmp = []
  171. for i in range(depth[d]):
  172. tmp.append(Block(
  173. dim=dim[d],
  174. num_heads=num_heads[d],
  175. mlp_ratio=mlp_ratio[d],
  176. qkv_bias=qkv_bias,
  177. proj_drop=proj_drop,
  178. attn_drop=attn_drop,
  179. drop_path=drop_path[i],
  180. norm_layer=norm_layer,
  181. **dd,
  182. ))
  183. if len(tmp) != 0:
  184. self.blocks.append(nn.Sequential(*tmp))
  185. if len(self.blocks) == 0:
  186. self.blocks = None
  187. self.projs = nn.ModuleList()
  188. for d in range(num_branches):
  189. if dim[d] == dim[(d + 1) % num_branches] and False:
  190. tmp = [nn.Identity()]
  191. else:
  192. tmp = [norm_layer(dim[d], **dd), act_layer(), nn.Linear(dim[d], dim[(d + 1) % num_branches], **dd)]
  193. self.projs.append(nn.Sequential(*tmp))
  194. self.fusion = nn.ModuleList()
  195. for d in range(num_branches):
  196. d_ = (d + 1) % num_branches
  197. nh = num_heads[d_]
  198. if depth[-1] == 0: # backward capability:
  199. self.fusion.append(
  200. CrossAttentionBlock(
  201. dim=dim[d_],
  202. num_heads=nh,
  203. mlp_ratio=mlp_ratio[d],
  204. qkv_bias=qkv_bias,
  205. proj_drop=proj_drop,
  206. attn_drop=attn_drop,
  207. drop_path=drop_path[-1],
  208. norm_layer=norm_layer,
  209. **dd,
  210. ))
  211. else:
  212. tmp = []
  213. for _ in range(depth[-1]):
  214. tmp.append(CrossAttentionBlock(
  215. dim=dim[d_],
  216. num_heads=nh,
  217. mlp_ratio=mlp_ratio[d],
  218. qkv_bias=qkv_bias,
  219. proj_drop=proj_drop,
  220. attn_drop=attn_drop,
  221. drop_path=drop_path[-1],
  222. norm_layer=norm_layer,
  223. **dd,
  224. ))
  225. self.fusion.append(nn.Sequential(*tmp))
  226. self.revert_projs = nn.ModuleList()
  227. for d in range(num_branches):
  228. if dim[(d + 1) % num_branches] == dim[d] and False:
  229. tmp = [nn.Identity()]
  230. else:
  231. tmp = [norm_layer(dim[(d + 1) % num_branches], **dd), act_layer(),
  232. nn.Linear(dim[(d + 1) % num_branches], dim[d], **dd)]
  233. self.revert_projs.append(nn.Sequential(*tmp))
  234. def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
  235. outs_b = []
  236. for i, block in enumerate(self.blocks):
  237. outs_b.append(block(x[i]))
  238. # only take the cls token out
  239. proj_cls_token = torch.jit.annotate(List[torch.Tensor], [])
  240. for i, proj in enumerate(self.projs):
  241. proj_cls_token.append(proj(outs_b[i][:, 0:1, ...]))
  242. # cross attention
  243. outs = []
  244. for i, (fusion, revert_proj) in enumerate(zip(self.fusion, self.revert_projs)):
  245. tmp = torch.cat((proj_cls_token[i], outs_b[(i + 1) % self.num_branches][:, 1:, ...]), dim=1)
  246. tmp = fusion(tmp)
  247. reverted_proj_cls_token = revert_proj(tmp[:, 0:1, ...])
  248. tmp = torch.cat((reverted_proj_cls_token, outs_b[i][:, 1:, ...]), dim=1)
  249. outs.append(tmp)
  250. return outs
  251. def _compute_num_patches(img_size, patches):
  252. return [i[0] // p * i[1] // p for i, p in zip(img_size, patches)]
  253. @register_notrace_function
  254. def scale_image(x, ss: Tuple[int, int], crop_scale: bool = False): # annotations for torchscript
  255. """
  256. Pulled out of CrossViT.forward_features to bury conditional logic in a leaf node for FX tracing.
  257. Args:
  258. x (Tensor): input image
  259. ss (tuple[int, int]): height and width to scale to
  260. crop_scale (bool): whether to crop instead of interpolate to achieve the desired scale. Defaults to False
  261. Returns:
  262. Tensor: the "scaled" image batch tensor
  263. """
  264. H, W = x.shape[-2:]
  265. if H != ss[0] or W != ss[1]:
  266. if crop_scale and ss[0] <= H and ss[1] <= W:
  267. cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.))
  268. x = x[:, :, cu:cu + ss[0], cl:cl + ss[1]]
  269. else:
  270. x = torch.nn.functional.interpolate(x, size=ss, mode='bicubic', align_corners=False)
  271. return x
  272. class CrossVit(nn.Module):
  273. """ Vision Transformer with support for patch or hybrid CNN input stage
  274. """
  275. def __init__(
  276. self,
  277. img_size: int = 224,
  278. img_scale: Tuple[float, ...] = (1.0, 1.0),
  279. patch_size: Tuple[int, ...] = (8, 16),
  280. in_chans: int = 3,
  281. num_classes: int = 1000,
  282. embed_dim: Tuple[int, ...] = (192, 384),
  283. depth: Tuple[Tuple[int, ...], ...] = ((1, 3, 1), (1, 3, 1), (1, 3, 1)),
  284. num_heads: Tuple[int, ...] = (6, 12),
  285. mlp_ratio: Tuple[float, ...] = (2., 2., 4.),
  286. multi_conv: bool = False,
  287. crop_scale: bool = False,
  288. qkv_bias: bool = True,
  289. drop_rate: float = 0.,
  290. pos_drop_rate: float = 0.,
  291. proj_drop_rate: float = 0.,
  292. attn_drop_rate: float = 0.,
  293. drop_path_rate: float = 0.,
  294. norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6),
  295. global_pool: str = 'token',
  296. device=None,
  297. dtype=None,
  298. ):
  299. super().__init__()
  300. dd = {'device': device, 'dtype': dtype}
  301. assert global_pool in ('token', 'avg')
  302. self.num_classes = num_classes
  303. self.global_pool = global_pool
  304. self.img_size = to_2tuple(img_size)
  305. img_scale = to_2tuple(img_scale)
  306. self.img_size_scaled = [tuple([int(sj * si) for sj in self.img_size]) for si in img_scale]
  307. self.crop_scale = crop_scale # crop instead of interpolate for scale
  308. num_patches = _compute_num_patches(self.img_size_scaled, patch_size)
  309. self.num_branches = len(patch_size)
  310. self.embed_dim = embed_dim
  311. self.num_features = self.head_hidden_size = sum(embed_dim)
  312. self.patch_embed = nn.ModuleList()
  313. # hard-coded for torch jit script
  314. for i in range(self.num_branches):
  315. setattr(self, f'pos_embed_{i}', nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i], **dd)))
  316. setattr(self, f'cls_token_{i}', nn.Parameter(torch.zeros(1, 1, embed_dim[i], **dd)))
  317. for im_s, p, d in zip(self.img_size_scaled, patch_size, embed_dim):
  318. self.patch_embed.append(
  319. PatchEmbed(
  320. img_size=im_s,
  321. patch_size=p,
  322. in_chans=in_chans,
  323. embed_dim=d,
  324. multi_conv=multi_conv,
  325. **dd,
  326. ))
  327. self.pos_drop = nn.Dropout(p=pos_drop_rate)
  328. total_depth = sum([sum(x[-2:]) for x in depth])
  329. dpr = calculate_drop_path_rates(drop_path_rate, total_depth) # stochastic depth decay rule
  330. dpr_ptr = 0
  331. self.blocks = nn.ModuleList()
  332. for idx, block_cfg in enumerate(depth):
  333. curr_depth = max(block_cfg[:-1]) + block_cfg[-1]
  334. dpr_ = dpr[dpr_ptr:dpr_ptr + curr_depth]
  335. blk = MultiScaleBlock(
  336. embed_dim,
  337. num_patches,
  338. block_cfg,
  339. num_heads=num_heads,
  340. mlp_ratio=mlp_ratio,
  341. qkv_bias=qkv_bias,
  342. proj_drop=proj_drop_rate,
  343. attn_drop=attn_drop_rate,
  344. drop_path=dpr_,
  345. norm_layer=norm_layer,
  346. **dd,
  347. )
  348. dpr_ptr += curr_depth
  349. self.blocks.append(blk)
  350. self.norm = nn.ModuleList([norm_layer(embed_dim[i], **dd) for i in range(self.num_branches)])
  351. self.head_drop = nn.Dropout(drop_rate)
  352. self.head = nn.ModuleList([
  353. nn.Linear(embed_dim[i], num_classes, **dd) if num_classes > 0 else nn.Identity()
  354. for i in range(self.num_branches)])
  355. for i in range(self.num_branches):
  356. trunc_normal_(getattr(self, f'pos_embed_{i}'), std=.02)
  357. trunc_normal_(getattr(self, f'cls_token_{i}'), std=.02)
  358. self.apply(self._init_weights)
  359. def _init_weights(self, m):
  360. if isinstance(m, nn.Linear):
  361. trunc_normal_(m.weight, std=.02)
  362. if isinstance(m, nn.Linear) and m.bias is not None:
  363. nn.init.constant_(m.bias, 0)
  364. elif isinstance(m, nn.LayerNorm):
  365. nn.init.constant_(m.bias, 0)
  366. nn.init.constant_(m.weight, 1.0)
  367. @torch.jit.ignore
  368. def no_weight_decay(self):
  369. out = set()
  370. for i in range(self.num_branches):
  371. out.add(f'cls_token_{i}')
  372. pe = getattr(self, f'pos_embed_{i}', None)
  373. if pe is not None and pe.requires_grad:
  374. out.add(f'pos_embed_{i}')
  375. return out
  376. @torch.jit.ignore
  377. def group_matcher(self, coarse=False):
  378. return dict(
  379. stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
  380. blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
  381. )
  382. @torch.jit.ignore
  383. def set_grad_checkpointing(self, enable=True):
  384. assert not enable, 'gradient checkpointing not supported'
  385. @torch.jit.ignore
  386. def get_classifier(self) -> nn.Module:
  387. return self.head
  388. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  389. self.num_classes = num_classes
  390. if global_pool is not None:
  391. assert global_pool in ('token', 'avg')
  392. self.global_pool = global_pool
  393. device = self.head[0].weight.device if hasattr(self.head[0], 'weight') else None
  394. dtype = self.head[0].weight.dtype if hasattr(self.head[0], 'weight') else None
  395. dd = {'device': device, 'dtype': dtype}
  396. self.head = nn.ModuleList([
  397. nn.Linear(self.embed_dim[i], num_classes, **dd) if num_classes > 0 else nn.Identity()
  398. for i in range(self.num_branches)
  399. ])
  400. def forward_features(self, x) -> List[torch.Tensor]:
  401. B = x.shape[0]
  402. xs = []
  403. for i, patch_embed in enumerate(self.patch_embed):
  404. x_ = x
  405. ss = self.img_size_scaled[i]
  406. x_ = scale_image(x_, ss, self.crop_scale)
  407. x_ = patch_embed(x_)
  408. cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1 # hard-coded for torch jit script
  409. cls_tokens = cls_tokens.expand(B, -1, -1)
  410. x_ = torch.cat((cls_tokens, x_), dim=1)
  411. pos_embed = self.pos_embed_0 if i == 0 else self.pos_embed_1 # hard-coded for torch jit script
  412. x_ = x_ + pos_embed
  413. x_ = self.pos_drop(x_)
  414. xs.append(x_)
  415. for i, blk in enumerate(self.blocks):
  416. xs = blk(xs)
  417. # NOTE: was before branch token section, move to here to assure all branch token are before layer norm
  418. xs = [norm(xs[i]) for i, norm in enumerate(self.norm)]
  419. return xs
  420. def forward_head(self, xs: List[torch.Tensor], pre_logits: bool = False) -> torch.Tensor:
  421. xs = [x[:, 1:].mean(dim=1) for x in xs] if self.global_pool == 'avg' else [x[:, 0] for x in xs]
  422. xs = [self.head_drop(x) for x in xs]
  423. if pre_logits or isinstance(self.head[0], nn.Identity):
  424. return torch.cat([x for x in xs], dim=1)
  425. return torch.mean(torch.stack([head(xs[i]) for i, head in enumerate(self.head)], dim=0), dim=0)
  426. def forward(self, x):
  427. xs = self.forward_features(x)
  428. x = self.forward_head(xs)
  429. return x
  430. def _create_crossvit(variant, pretrained=False, **kwargs):
  431. if kwargs.get('features_only', None):
  432. raise RuntimeError('features_only not implemented for Vision Transformer models.')
  433. def pretrained_filter_fn(state_dict):
  434. new_state_dict = {}
  435. for key in state_dict.keys():
  436. if 'pos_embed' in key or 'cls_token' in key:
  437. new_key = key.replace(".", "_")
  438. else:
  439. new_key = key
  440. new_state_dict[new_key] = state_dict[key]
  441. return new_state_dict
  442. return build_model_with_cfg(
  443. CrossVit,
  444. variant,
  445. pretrained,
  446. pretrained_filter_fn=pretrained_filter_fn,
  447. **kwargs,
  448. )
  449. def _cfg(url='', **kwargs):
  450. return {
  451. 'url': url,
  452. 'num_classes': 1000, 'input_size': (3, 240, 240), 'pool_size': None, 'crop_pct': 0.875,
  453. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True,
  454. 'first_conv': ('patch_embed.0.proj', 'patch_embed.1.proj'),
  455. 'classifier': ('head.0', 'head.1'),
  456. 'license': 'apache-2.0',
  457. **kwargs
  458. }
# Pretrained configs keyed by '<model name>.<tag>'. Every entry starts from _cfg() (240x240
# input, crop_pct 0.875). The dagger variants (built with multi_conv=True, i.e. a conv
# Sequential patch embedding) point first_conv at 'patch_embed.*.proj.0'; the *_408
# variants use a 408x408 input with crop_pct 1.0.
default_cfgs = generate_default_cfgs({
    'crossvit_15_240.in1k': _cfg(hf_hub_id='timm/'),
    'crossvit_15_dagger_240.in1k': _cfg(
        hf_hub_id='timm/',
        first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
    ),
    'crossvit_15_dagger_408.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0,
    ),
    'crossvit_18_240.in1k': _cfg(hf_hub_id='timm/'),
    'crossvit_18_dagger_240.in1k': _cfg(
        hf_hub_id='timm/',
        first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
    ),
    'crossvit_18_dagger_408.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0,
    ),
    'crossvit_9_240.in1k': _cfg(hf_hub_id='timm/'),
    'crossvit_9_dagger_240.in1k': _cfg(
        hf_hub_id='timm/',
        first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
    ),
    'crossvit_base_240.in1k': _cfg(hf_hub_id='timm/'),
    'crossvit_small_240.in1k': _cfg(hf_hub_id='timm/'),
    'crossvit_tiny_240.in1k': _cfg(hf_hub_id='timm/'),
})
  487. @register_model
  488. def crossvit_tiny_240(pretrained=False, **kwargs) -> CrossVit:
  489. model_args = dict(
  490. img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[96, 192], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
  491. num_heads=[3, 3], mlp_ratio=[4, 4, 1])
  492. model = _create_crossvit(variant='crossvit_tiny_240', pretrained=pretrained, **dict(model_args, **kwargs))
  493. return model
  494. @register_model
  495. def crossvit_small_240(pretrained=False, **kwargs) -> CrossVit:
  496. model_args = dict(
  497. img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
  498. num_heads=[6, 6], mlp_ratio=[4, 4, 1])
  499. model = _create_crossvit(variant='crossvit_small_240', pretrained=pretrained, **dict(model_args, **kwargs))
  500. return model
  501. @register_model
  502. def crossvit_base_240(pretrained=False, **kwargs) -> CrossVit:
  503. model_args = dict(
  504. img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[384, 768], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
  505. num_heads=[12, 12], mlp_ratio=[4, 4, 1])
  506. model = _create_crossvit(variant='crossvit_base_240', pretrained=pretrained, **dict(model_args, **kwargs))
  507. return model
  508. @register_model
  509. def crossvit_9_240(pretrained=False, **kwargs) -> CrossVit:
  510. model_args = dict(
  511. img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]],
  512. num_heads=[4, 4], mlp_ratio=[3, 3, 1])
  513. model = _create_crossvit(variant='crossvit_9_240', pretrained=pretrained, **dict(model_args, **kwargs))
  514. return model
  515. @register_model
  516. def crossvit_15_240(pretrained=False, **kwargs) -> CrossVit:
  517. model_args = dict(
  518. img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
  519. num_heads=[6, 6], mlp_ratio=[3, 3, 1])
  520. model = _create_crossvit(variant='crossvit_15_240', pretrained=pretrained, **dict(model_args, **kwargs))
  521. return model
  522. @register_model
  523. def crossvit_18_240(pretrained=False, **kwargs) -> CrossVit:
  524. model_args = dict(
  525. img_scale=(1.0, 224 / 240), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
  526. num_heads=[7, 7], mlp_ratio=[3, 3, 1], **kwargs)
  527. model = _create_crossvit(variant='crossvit_18_240', pretrained=pretrained, **dict(model_args, **kwargs))
  528. return model
  529. @register_model
  530. def crossvit_9_dagger_240(pretrained=False, **kwargs) -> CrossVit:
  531. model_args = dict(
  532. img_scale=(1.0, 224 / 240), patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]],
  533. num_heads=[4, 4], mlp_ratio=[3, 3, 1], multi_conv=True)
  534. model = _create_crossvit(variant='crossvit_9_dagger_240', pretrained=pretrained, **dict(model_args, **kwargs))
  535. return model
  536. @register_model
  537. def crossvit_15_dagger_240(pretrained=False, **kwargs) -> CrossVit:
  538. model_args = dict(
  539. img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
  540. num_heads=[6, 6], mlp_ratio=[3, 3, 1], multi_conv=True)
  541. model = _create_crossvit(variant='crossvit_15_dagger_240', pretrained=pretrained, **dict(model_args, **kwargs))
  542. return model
  543. @register_model
  544. def crossvit_15_dagger_408(pretrained=False, **kwargs) -> CrossVit:
  545. model_args = dict(
  546. img_scale=(1.0, 384/408), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
  547. num_heads=[6, 6], mlp_ratio=[3, 3, 1], multi_conv=True)
  548. model = _create_crossvit(variant='crossvit_15_dagger_408', pretrained=pretrained, **dict(model_args, **kwargs))
  549. return model
  550. @register_model
  551. def crossvit_18_dagger_240(pretrained=False, **kwargs) -> CrossVit:
  552. model_args = dict(
  553. img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
  554. num_heads=[7, 7], mlp_ratio=[3, 3, 1], multi_conv=True)
  555. model = _create_crossvit(variant='crossvit_18_dagger_240', pretrained=pretrained, **dict(model_args, **kwargs))
  556. return model
  557. @register_model
  558. def crossvit_18_dagger_408(pretrained=False, **kwargs) -> CrossVit:
  559. model_args = dict(
  560. img_scale=(1.0, 384/408), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
  561. num_heads=[7, 7], mlp_ratio=[3, 3, 1], multi_conv=True)
  562. model = _create_crossvit(variant='crossvit_18_dagger_408', pretrained=pretrained, **dict(model_args, **kwargs))
  563. return model