# cait.py
""" Class-Attention in Image Transformers (CaiT)

Paper: 'Going deeper with Image Transformers' - https://arxiv.org/abs/2103.17239

Original code and weights from https://github.com/facebookresearch/deit, copyright below

Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
"""
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
from functools import partial
from typing import List, Optional, Tuple, Union, Type, Any

import torch
import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, use_fused_attn

from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint, checkpoint_seq
from ._registry import register_model, generate_default_cfgs

__all__ = ['Cait', 'ClassAttn', 'LayerScaleBlockClassAttn', 'LayerScaleBlock', 'TalkingHeadAttn']
  19. class ClassAttn(nn.Module):
  20. # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
  21. # with slight modifications to do CA
  22. fused_attn: torch.jit.Final[bool]
  23. def __init__(
  24. self,
  25. dim: int,
  26. num_heads: int = 8,
  27. qkv_bias: bool = False,
  28. attn_drop: float = 0.,
  29. proj_drop: float = 0.,
  30. device=None,
  31. dtype=None,
  32. ):
  33. super().__init__()
  34. dd = {'device': device, 'dtype': dtype}
  35. self.num_heads = num_heads
  36. head_dim = dim // num_heads
  37. self.scale = head_dim ** -0.5
  38. self.fused_attn = use_fused_attn()
  39. self.q = nn.Linear(dim, dim, bias=qkv_bias, **dd)
  40. self.k = nn.Linear(dim, dim, bias=qkv_bias, **dd)
  41. self.v = nn.Linear(dim, dim, bias=qkv_bias, **dd)
  42. self.attn_drop = nn.Dropout(attn_drop)
  43. self.proj = nn.Linear(dim, dim, **dd)
  44. self.proj_drop = nn.Dropout(proj_drop)
  45. def forward(self, x):
  46. B, N, C = x.shape
  47. q = self.q(x[:, 0]).unsqueeze(1).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
  48. k = self.k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
  49. v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
  50. if self.fused_attn:
  51. x_cls = torch.nn.functional.scaled_dot_product_attention(
  52. q, k, v,
  53. dropout_p=self.attn_drop.p if self.training else 0.,
  54. )
  55. else:
  56. q = q * self.scale
  57. attn = q @ k.transpose(-2, -1)
  58. attn = attn.softmax(dim=-1)
  59. attn = self.attn_drop(attn)
  60. x_cls = attn @ v
  61. x_cls = x_cls.transpose(1, 2).reshape(B, 1, C)
  62. x_cls = self.proj(x_cls)
  63. x_cls = self.proj_drop(x_cls)
  64. return x_cls
  65. class LayerScaleBlockClassAttn(nn.Module):
  66. # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
  67. # with slight modifications to add CA and LayerScale
  68. def __init__(
  69. self,
  70. dim: int,
  71. num_heads: int,
  72. mlp_ratio: float = 4.,
  73. qkv_bias: bool = False,
  74. proj_drop: float = 0.,
  75. attn_drop: float = 0.,
  76. drop_path: float = 0.,
  77. act_layer: Type[nn.Module] = nn.GELU,
  78. norm_layer: Type[nn.Module] = nn.LayerNorm,
  79. attn_block: Type[nn.Module] = ClassAttn,
  80. mlp_block: Type[nn.Module] = Mlp,
  81. init_values: float = 1e-4,
  82. device=None,
  83. dtype=None,
  84. ):
  85. super().__init__()
  86. dd = {'device': device, 'dtype': dtype}
  87. self.norm1 = norm_layer(dim, **dd)
  88. self.attn = attn_block(
  89. dim,
  90. num_heads=num_heads,
  91. qkv_bias=qkv_bias,
  92. attn_drop=attn_drop,
  93. proj_drop=proj_drop,
  94. **dd,
  95. )
  96. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  97. self.norm2 = norm_layer(dim, **dd)
  98. mlp_hidden_dim = int(dim * mlp_ratio)
  99. self.mlp = mlp_block(
  100. in_features=dim,
  101. hidden_features=mlp_hidden_dim,
  102. act_layer=act_layer,
  103. drop=proj_drop,
  104. **dd,
  105. )
  106. self.gamma_1 = nn.Parameter(init_values * torch.ones(dim, **dd))
  107. self.gamma_2 = nn.Parameter(init_values * torch.ones(dim, **dd))
  108. def forward(self, x, x_cls):
  109. u = torch.cat((x_cls, x), dim=1)
  110. x_cls = x_cls + self.drop_path(self.gamma_1 * self.attn(self.norm1(u)))
  111. x_cls = x_cls + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x_cls)))
  112. return x_cls
  113. class TalkingHeadAttn(nn.Module):
  114. # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
  115. # with slight modifications to add Talking Heads Attention (https://arxiv.org/pdf/2003.02436v1.pdf)
  116. def __init__(
  117. self,
  118. dim: int,
  119. num_heads: int = 8,
  120. qkv_bias: bool = False,
  121. attn_drop: float = 0.,
  122. proj_drop: float = 0.,
  123. device=None,
  124. dtype=None,
  125. ):
  126. super().__init__()
  127. dd = {'device': device, 'dtype': dtype}
  128. self.num_heads = num_heads
  129. head_dim = dim // num_heads
  130. self.scale = head_dim ** -0.5
  131. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd)
  132. self.attn_drop = nn.Dropout(attn_drop)
  133. self.proj = nn.Linear(dim, dim, **dd)
  134. self.proj_l = nn.Linear(num_heads, num_heads, **dd)
  135. self.proj_w = nn.Linear(num_heads, num_heads, **dd)
  136. self.proj_drop = nn.Dropout(proj_drop)
  137. def forward(self, x):
  138. B, N, C = x.shape
  139. qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
  140. q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
  141. attn = q @ k.transpose(-2, -1)
  142. attn = self.proj_l(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
  143. attn = attn.softmax(dim=-1)
  144. attn = self.proj_w(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
  145. attn = self.attn_drop(attn)
  146. x = (attn @ v).transpose(1, 2).reshape(B, N, C)
  147. x = self.proj(x)
  148. x = self.proj_drop(x)
  149. return x
  150. class LayerScaleBlock(nn.Module):
  151. # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
  152. # with slight modifications to add layerScale
  153. def __init__(
  154. self,
  155. dim: int,
  156. num_heads: int,
  157. mlp_ratio: float = 4.,
  158. qkv_bias: bool = False,
  159. proj_drop: float = 0.,
  160. attn_drop: float = 0.,
  161. drop_path: float = 0.,
  162. act_layer: Type[nn.Module] = nn.GELU,
  163. norm_layer: Type[nn.Module] = nn.LayerNorm,
  164. attn_block: Type[nn.Module] = TalkingHeadAttn,
  165. mlp_block: Type[nn.Module] = Mlp,
  166. init_values: float = 1e-4,
  167. device=None,
  168. dtype=None,
  169. ):
  170. super().__init__()
  171. dd = {'device': device, 'dtype': dtype}
  172. self.norm1 = norm_layer(dim, **dd)
  173. self.attn = attn_block(
  174. dim,
  175. num_heads=num_heads,
  176. qkv_bias=qkv_bias,
  177. attn_drop=attn_drop,
  178. proj_drop=proj_drop,
  179. **dd,
  180. )
  181. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  182. self.norm2 = norm_layer(dim, **dd)
  183. mlp_hidden_dim = int(dim * mlp_ratio)
  184. self.mlp = mlp_block(
  185. in_features=dim,
  186. hidden_features=mlp_hidden_dim,
  187. act_layer=act_layer,
  188. drop=proj_drop,
  189. **dd,
  190. )
  191. self.gamma_1 = nn.Parameter(init_values * torch.ones(dim, **dd))
  192. self.gamma_2 = nn.Parameter(init_values * torch.ones(dim, **dd))
  193. def forward(self, x):
  194. x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
  195. x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
  196. return x
  197. class Cait(nn.Module):
  198. # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
  199. # with slight modifications to adapt to our cait models
  200. def __init__(
  201. self,
  202. img_size: int = 224,
  203. patch_size: int = 16,
  204. in_chans: int = 3,
  205. num_classes: int = 1000,
  206. global_pool: str = 'token',
  207. embed_dim: int = 768,
  208. depth: int = 12,
  209. num_heads: int = 12,
  210. mlp_ratio: float = 4.,
  211. qkv_bias: bool = True,
  212. drop_rate: float = 0.,
  213. pos_drop_rate: float = 0.,
  214. proj_drop_rate: float = 0.,
  215. attn_drop_rate: float = 0.,
  216. drop_path_rate: float = 0.,
  217. block_layers: Type[nn.Module] = LayerScaleBlock,
  218. block_layers_token: Type[nn.Module] = LayerScaleBlockClassAttn,
  219. patch_layer: Type[nn.Module] = PatchEmbed,
  220. norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6),
  221. act_layer: Type[nn.Module] = nn.GELU,
  222. attn_block: Type[nn.Module] = TalkingHeadAttn,
  223. mlp_block: Type[nn.Module] = Mlp,
  224. init_values: float = 1e-4,
  225. attn_block_token_only: Type[nn.Module] = ClassAttn,
  226. mlp_block_token_only: Type[nn.Module] = Mlp,
  227. depth_token_only: int = 2,
  228. mlp_ratio_token_only: float = 4.0,
  229. device=None,
  230. dtype=None,
  231. ):
  232. super().__init__()
  233. dd = {'device': device, 'dtype': dtype}
  234. assert global_pool in ('', 'token', 'avg')
  235. self.num_classes = num_classes
  236. self.global_pool = global_pool
  237. self.num_features = self.head_hidden_size = self.embed_dim = embed_dim
  238. self.grad_checkpointing = False
  239. self.patch_embed = patch_layer(
  240. img_size=img_size,
  241. patch_size=patch_size,
  242. in_chans=in_chans,
  243. embed_dim=embed_dim,
  244. **dd,
  245. )
  246. num_patches = self.patch_embed.num_patches
  247. r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size
  248. self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, **dd))
  249. self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim, **dd))
  250. self.pos_drop = nn.Dropout(p=pos_drop_rate)
  251. dpr = [drop_path_rate for i in range(depth)]
  252. self.blocks = nn.Sequential(*[block_layers(
  253. dim=embed_dim,
  254. num_heads=num_heads,
  255. mlp_ratio=mlp_ratio,
  256. qkv_bias=qkv_bias,
  257. proj_drop=proj_drop_rate,
  258. attn_drop=attn_drop_rate,
  259. drop_path=dpr[i],
  260. norm_layer=norm_layer,
  261. act_layer=act_layer,
  262. attn_block=attn_block,
  263. mlp_block=mlp_block,
  264. init_values=init_values,
  265. **dd,
  266. ) for i in range(depth)])
  267. self.feature_info = [dict(num_chs=embed_dim, reduction=r, module=f'blocks.{i}') for i in range(depth)]
  268. self.blocks_token_only = nn.ModuleList([block_layers_token(
  269. dim=embed_dim,
  270. num_heads=num_heads,
  271. mlp_ratio=mlp_ratio_token_only,
  272. qkv_bias=qkv_bias,
  273. norm_layer=norm_layer,
  274. act_layer=act_layer,
  275. attn_block=attn_block_token_only,
  276. mlp_block=mlp_block_token_only,
  277. init_values=init_values,
  278. **dd,
  279. ) for _ in range(depth_token_only)])
  280. self.norm = norm_layer(embed_dim, **dd)
  281. self.head_drop = nn.Dropout(drop_rate)
  282. self.head = nn.Linear(embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity()
  283. trunc_normal_(self.pos_embed, std=.02)
  284. trunc_normal_(self.cls_token, std=.02)
  285. self.apply(self._init_weights)
  286. def _init_weights(self, m):
  287. if isinstance(m, nn.Linear):
  288. trunc_normal_(m.weight, std=.02)
  289. if isinstance(m, nn.Linear) and m.bias is not None:
  290. nn.init.constant_(m.bias, 0)
  291. elif isinstance(m, nn.LayerNorm):
  292. nn.init.constant_(m.bias, 0)
  293. nn.init.constant_(m.weight, 1.0)
  294. @torch.jit.ignore
  295. def no_weight_decay(self):
  296. return {'pos_embed', 'cls_token'}
  297. @torch.jit.ignore
  298. def set_grad_checkpointing(self, enable=True):
  299. self.grad_checkpointing = enable
  300. @torch.jit.ignore
  301. def group_matcher(self, coarse=False):
  302. def _matcher(name):
  303. if any([name.startswith(n) for n in ('cls_token', 'pos_embed', 'patch_embed')]):
  304. return 0
  305. elif name.startswith('blocks.'):
  306. return int(name.split('.')[1]) + 1
  307. elif name.startswith('blocks_token_only.'):
  308. # overlap token only blocks with last blocks
  309. to_offset = len(self.blocks) - len(self.blocks_token_only) + 1
  310. return int(name.split('.')[1]) + to_offset
  311. elif name.startswith('norm.'):
  312. return len(self.blocks)
  313. else:
  314. return float('inf')
  315. return _matcher
  316. @torch.jit.ignore
  317. def get_classifier(self) -> nn.Module:
  318. return self.head
  319. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  320. self.num_classes = num_classes
  321. if global_pool is not None:
  322. assert global_pool in ('', 'token', 'avg')
  323. self.global_pool = global_pool
  324. self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
  325. def forward_intermediates(
  326. self,
  327. x: torch.Tensor,
  328. indices: Optional[Union[int, List[int]]] = None,
  329. norm: bool = False,
  330. stop_early: bool = False,
  331. output_fmt: str = 'NCHW',
  332. intermediates_only: bool = False,
  333. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  334. """ Forward features that returns intermediates.
  335. Args:
  336. x: Input image tensor
  337. indices: Take last n blocks if int, all if None, select matching indices if sequence
  338. norm: Apply norm layer to all intermediates
  339. stop_early: Stop iterating over blocks when last desired intermediate hit
  340. output_fmt: Shape of intermediate feature outputs
  341. intermediates_only: Only return intermediate features
  342. """
  343. assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
  344. reshape = output_fmt == 'NCHW'
  345. intermediates = []
  346. take_indices, max_index = feature_take_indices(len(self.blocks), indices)
  347. # forward pass
  348. B, _, height, width = x.shape
  349. x = self.patch_embed(x)
  350. x = x + self.pos_embed
  351. x = self.pos_drop(x)
  352. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  353. blocks = self.blocks
  354. else:
  355. blocks = self.blocks[:max_index + 1]
  356. for i, blk in enumerate(blocks):
  357. if self.grad_checkpointing and not torch.jit.is_scripting():
  358. x = checkpoint(blk, x)
  359. else:
  360. x = blk(x)
  361. if i in take_indices:
  362. # normalize intermediates with final norm layer if enabled
  363. intermediates.append(self.norm(x) if norm else x)
  364. # process intermediates
  365. if reshape:
  366. # reshape to BCHW output format
  367. H, W = self.patch_embed.dynamic_feat_size((height, width))
  368. intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
  369. if intermediates_only:
  370. return intermediates
  371. # NOTE not supporting return of class tokens
  372. cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
  373. for i, blk in enumerate(self.blocks_token_only):
  374. cls_tokens = blk(x, cls_tokens)
  375. x = torch.cat((cls_tokens, x), dim=1)
  376. x = self.norm(x)
  377. return x, intermediates
  378. def prune_intermediate_layers(
  379. self,
  380. indices: Union[int, List[int]] = 1,
  381. prune_norm: bool = False,
  382. prune_head: bool = True,
  383. ):
  384. """ Prune layers not required for specified intermediates.
  385. """
  386. take_indices, max_index = feature_take_indices(len(self.blocks), indices)
  387. self.blocks = self.blocks[:max_index + 1] # truncate blocks
  388. if prune_norm:
  389. self.norm = nn.Identity()
  390. if prune_head:
  391. self.blocks_token_only = nn.ModuleList() # prune token blocks with head
  392. self.reset_classifier(0, '')
  393. return take_indices
  394. def forward_features(self, x):
  395. x = self.patch_embed(x)
  396. x = x + self.pos_embed
  397. x = self.pos_drop(x)
  398. if self.grad_checkpointing and not torch.jit.is_scripting():
  399. x = checkpoint_seq(self.blocks, x)
  400. else:
  401. x = self.blocks(x)
  402. cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
  403. for i, blk in enumerate(self.blocks_token_only):
  404. cls_tokens = blk(x, cls_tokens)
  405. x = torch.cat((cls_tokens, x), dim=1)
  406. x = self.norm(x)
  407. return x
  408. def forward_head(self, x, pre_logits: bool = False):
  409. if self.global_pool:
  410. x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
  411. x = self.head_drop(x)
  412. return x if pre_logits else self.head(x)
  413. def forward(self, x):
  414. x = self.forward_features(x)
  415. x = self.forward_head(x)
  416. return x
  417. def checkpoint_filter_fn(state_dict, model=None):
  418. if 'model' in state_dict:
  419. state_dict = state_dict['model']
  420. checkpoint_no_module = {}
  421. for k, v in state_dict.items():
  422. checkpoint_no_module[k.replace('module.', '')] = v
  423. return checkpoint_no_module
  424. def _create_cait(variant, pretrained=False, **kwargs):
  425. out_indices = kwargs.pop('out_indices', 3)
  426. model = build_model_with_cfg(
  427. Cait,
  428. variant,
  429. pretrained,
  430. pretrained_filter_fn=checkpoint_filter_fn,
  431. feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
  432. **kwargs,
  433. )
  434. return model
  435. def _cfg(url='', **kwargs):
  436. return {
  437. 'url': url,
  438. 'num_classes': 1000, 'input_size': (3, 384, 384), 'pool_size': None,
  439. 'crop_pct': 1.0, 'interpolation': 'bicubic', 'fixed_input_size': True,
  440. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  441. 'first_conv': 'patch_embed.proj', 'classifier': 'head',
  442. 'license': 'apache-2.0',
  443. **kwargs
  444. }
  445. default_cfgs = generate_default_cfgs({
  446. 'cait_xxs24_224.fb_dist_in1k': _cfg(
  447. hf_hub_id='timm/',
  448. url='https://dl.fbaipublicfiles.com/deit/XXS24_224.pth',
  449. input_size=(3, 224, 224),
  450. ),
  451. 'cait_xxs24_384.fb_dist_in1k': _cfg(
  452. hf_hub_id='timm/',
  453. url='https://dl.fbaipublicfiles.com/deit/XXS24_384.pth',
  454. ),
  455. 'cait_xxs36_224.fb_dist_in1k': _cfg(
  456. hf_hub_id='timm/',
  457. url='https://dl.fbaipublicfiles.com/deit/XXS36_224.pth',
  458. input_size=(3, 224, 224),
  459. ),
  460. 'cait_xxs36_384.fb_dist_in1k': _cfg(
  461. hf_hub_id='timm/',
  462. url='https://dl.fbaipublicfiles.com/deit/XXS36_384.pth',
  463. ),
  464. 'cait_xs24_384.fb_dist_in1k': _cfg(
  465. hf_hub_id='timm/',
  466. url='https://dl.fbaipublicfiles.com/deit/XS24_384.pth',
  467. ),
  468. 'cait_s24_224.fb_dist_in1k': _cfg(
  469. hf_hub_id='timm/',
  470. url='https://dl.fbaipublicfiles.com/deit/S24_224.pth',
  471. input_size=(3, 224, 224),
  472. ),
  473. 'cait_s24_384.fb_dist_in1k': _cfg(
  474. hf_hub_id='timm/',
  475. url='https://dl.fbaipublicfiles.com/deit/S24_384.pth',
  476. ),
  477. 'cait_s36_384.fb_dist_in1k': _cfg(
  478. hf_hub_id='timm/',
  479. url='https://dl.fbaipublicfiles.com/deit/S36_384.pth',
  480. ),
  481. 'cait_m36_384.fb_dist_in1k': _cfg(
  482. hf_hub_id='timm/',
  483. url='https://dl.fbaipublicfiles.com/deit/M36_384.pth',
  484. ),
  485. 'cait_m48_448.fb_dist_in1k': _cfg(
  486. hf_hub_id='timm/',
  487. url='https://dl.fbaipublicfiles.com/deit/M48_448.pth',
  488. input_size=(3, 448, 448),
  489. ),
  490. })
  491. @register_model
  492. def cait_xxs24_224(pretrained=False, **kwargs) -> Cait:
  493. model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_values=1e-5)
  494. model = _create_cait('cait_xxs24_224', pretrained=pretrained, **dict(model_args, **kwargs))
  495. return model
  496. @register_model
  497. def cait_xxs24_384(pretrained=False, **kwargs) -> Cait:
  498. model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_values=1e-5)
  499. model = _create_cait('cait_xxs24_384', pretrained=pretrained, **dict(model_args, **kwargs))
  500. return model
  501. @register_model
  502. def cait_xxs36_224(pretrained=False, **kwargs) -> Cait:
  503. model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_values=1e-5)
  504. model = _create_cait('cait_xxs36_224', pretrained=pretrained, **dict(model_args, **kwargs))
  505. return model
  506. @register_model
  507. def cait_xxs36_384(pretrained=False, **kwargs) -> Cait:
  508. model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_values=1e-5)
  509. model = _create_cait('cait_xxs36_384', pretrained=pretrained, **dict(model_args, **kwargs))
  510. return model
  511. @register_model
  512. def cait_xs24_384(pretrained=False, **kwargs) -> Cait:
  513. model_args = dict(patch_size=16, embed_dim=288, depth=24, num_heads=6, init_values=1e-5)
  514. model = _create_cait('cait_xs24_384', pretrained=pretrained, **dict(model_args, **kwargs))
  515. return model
  516. @register_model
  517. def cait_s24_224(pretrained=False, **kwargs) -> Cait:
  518. model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_values=1e-5)
  519. model = _create_cait('cait_s24_224', pretrained=pretrained, **dict(model_args, **kwargs))
  520. return model
  521. @register_model
  522. def cait_s24_384(pretrained=False, **kwargs) -> Cait:
  523. model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_values=1e-5)
  524. model = _create_cait('cait_s24_384', pretrained=pretrained, **dict(model_args, **kwargs))
  525. return model
  526. @register_model
  527. def cait_s36_384(pretrained=False, **kwargs) -> Cait:
  528. model_args = dict(patch_size=16, embed_dim=384, depth=36, num_heads=8, init_values=1e-6)
  529. model = _create_cait('cait_s36_384', pretrained=pretrained, **dict(model_args, **kwargs))
  530. return model
  531. @register_model
  532. def cait_m36_384(pretrained=False, **kwargs) -> Cait:
  533. model_args = dict(patch_size=16, embed_dim=768, depth=36, num_heads=16, init_values=1e-6)
  534. model = _create_cait('cait_m36_384', pretrained=pretrained, **dict(model_args, **kwargs))
  535. return model
  536. @register_model
  537. def cait_m48_448(pretrained=False, **kwargs) -> Cait:
  538. model_args = dict(patch_size=16, embed_dim=768, depth=48, num_heads=16, init_values=1e-6)
  539. model = _create_cait('cait_m48_448', pretrained=pretrained, **dict(model_args, **kwargs))
  540. return model