xcit.py 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088
  1. """ Cross-Covariance Image Transformer (XCiT) in PyTorch
  2. Paper:
  3. - https://arxiv.org/abs/2106.09681
  4. Same as the official implementation, with some minor adaptations, original copyright below
  5. - https://github.com/facebookresearch/xcit/blob/master/xcit.py
  6. Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
  7. """
  8. # Copyright (c) 2015-present, Facebook, Inc.
  9. # All rights reserved.
  10. import math
  11. from functools import partial
  12. from typing import List, Optional, Tuple, Union, Type, Any
  13. import torch
  14. import torch.nn as nn
  15. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  16. from timm.layers import DropPath, trunc_normal_, to_2tuple, use_fused_attn, Mlp
  17. from ._builder import build_model_with_cfg
  18. from ._features import feature_take_indices
  19. from ._features_fx import register_notrace_module
  20. from ._manipulate import checkpoint
  21. from ._registry import register_model, generate_default_cfgs, register_model_deprecations
  22. from .cait import ClassAttn
__all__ = ['Xcit']  # public API; model_registry will add each entrypoint fn to this
  24. @register_notrace_module # reason: FX can't symbolically trace torch.arange in forward method
  25. class PositionalEncodingFourier(nn.Module):
  26. """
  27. Positional encoding relying on a fourier kernel matching the one used in the "Attention is all you Need" paper.
  28. Based on the official XCiT code
  29. - https://github.com/facebookresearch/xcit/blob/master/xcit.py
  30. """
  31. def __init__(
  32. self,
  33. hidden_dim: int = 32,
  34. dim: int = 768,
  35. temperature: float = 10000,
  36. device=None,
  37. dtype=None,
  38. ):
  39. dd = {'device': device, 'dtype': dtype}
  40. super().__init__()
  41. self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1, **dd)
  42. self.scale = 2 * math.pi
  43. self.temperature = temperature
  44. self.hidden_dim = hidden_dim
  45. self.dim = dim
  46. self.eps = 1e-6
  47. def forward(self, B: int, H: int, W: int):
  48. device = self.token_projection.weight.device
  49. dtype = self.token_projection.weight.dtype
  50. y_embed = torch.arange(1, H + 1, device=device).to(torch.float32).unsqueeze(1).repeat(1, 1, W)
  51. x_embed = torch.arange(1, W + 1, device=device).to(torch.float32).repeat(1, H, 1)
  52. y_embed = y_embed / (y_embed[:, -1:, :] + self.eps) * self.scale
  53. x_embed = x_embed / (x_embed[:, :, -1:] + self.eps) * self.scale
  54. dim_t = torch.arange(self.hidden_dim, device=device).to(torch.float32)
  55. dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.hidden_dim)
  56. pos_x = x_embed[:, :, :, None] / dim_t
  57. pos_y = y_embed[:, :, :, None] / dim_t
  58. pos_x = torch.stack([pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()], dim=4).flatten(3)
  59. pos_y = torch.stack([pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()], dim=4).flatten(3)
  60. pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
  61. pos = self.token_projection(pos.to(dtype))
  62. return pos.repeat(B, 1, 1, 1) # (B, C, H, W)
  63. def conv3x3(in_planes, out_planes, stride=1, device=None, dtype=None):
  64. """3x3 convolution + batch norm"""
  65. dd = {'device': device, 'dtype': dtype}
  66. return torch.nn.Sequential(
  67. nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False, **dd),
  68. nn.BatchNorm2d(out_planes, **dd)
  69. )
  70. class ConvPatchEmbed(nn.Module):
  71. """Image to Patch Embedding using multiple convolutional layers"""
  72. def __init__(
  73. self,
  74. img_size: int = 224,
  75. patch_size: int = 16,
  76. in_chans: int = 3,
  77. embed_dim: int = 768,
  78. act_layer: Type[nn.Module] = nn.GELU,
  79. device=None,
  80. dtype=None,
  81. ):
  82. dd = {'device': device, 'dtype': dtype}
  83. super().__init__()
  84. img_size = to_2tuple(img_size)
  85. num_patches = (img_size[1] // patch_size) * (img_size[0] // patch_size)
  86. self.img_size = img_size
  87. self.patch_size = patch_size
  88. self.num_patches = num_patches
  89. if patch_size == 16:
  90. self.proj = torch.nn.Sequential(
  91. conv3x3(in_chans, embed_dim // 8, 2, **dd),
  92. act_layer(),
  93. conv3x3(embed_dim // 8, embed_dim // 4, 2, **dd),
  94. act_layer(),
  95. conv3x3(embed_dim // 4, embed_dim // 2, 2, **dd),
  96. act_layer(),
  97. conv3x3(embed_dim // 2, embed_dim, 2, **dd),
  98. )
  99. elif patch_size == 8:
  100. self.proj = torch.nn.Sequential(
  101. conv3x3(in_chans, embed_dim // 4, 2, **dd),
  102. act_layer(),
  103. conv3x3(embed_dim // 4, embed_dim // 2, 2, **dd),
  104. act_layer(),
  105. conv3x3(embed_dim // 2, embed_dim, 2, **dd),
  106. )
  107. else:
  108. raise('For convolutional projection, patch size has to be in [8, 16]')
  109. def forward(self, x):
  110. x = self.proj(x)
  111. Hp, Wp = x.shape[2], x.shape[3]
  112. x = x.flatten(2).transpose(1, 2) # (B, N, C)
  113. return x, (Hp, Wp)
  114. class LPI(nn.Module):
  115. """
  116. Local Patch Interaction module that allows explicit communication between tokens in 3x3 windows to augment the
  117. implicit communication performed by the block diagonal scatter attention. Implemented using 2 layers of separable
  118. 3x3 convolutions with GeLU and BatchNorm2d
  119. """
  120. def __init__(
  121. self,
  122. in_features: int,
  123. out_features: Optional[int] = None,
  124. act_layer: Type[nn.Module] = nn.GELU,
  125. kernel_size: int = 3,
  126. device=None,
  127. dtype=None,
  128. ):
  129. super().__init__()
  130. dd = {'device': device, 'dtype': dtype}
  131. out_features = out_features or in_features
  132. padding = kernel_size // 2
  133. self.conv1 = torch.nn.Conv2d(
  134. in_features, in_features, kernel_size=kernel_size, padding=padding, groups=in_features, **dd)
  135. self.act = act_layer()
  136. self.bn = nn.BatchNorm2d(in_features, **dd)
  137. self.conv2 = torch.nn.Conv2d(
  138. in_features, out_features, kernel_size=kernel_size, padding=padding, groups=out_features, **dd)
  139. def forward(self, x, H: int, W: int):
  140. B, N, C = x.shape
  141. x = x.permute(0, 2, 1).reshape(B, C, H, W)
  142. x = self.conv1(x)
  143. x = self.act(x)
  144. x = self.bn(x)
  145. x = self.conv2(x)
  146. x = x.reshape(B, C, N).permute(0, 2, 1)
  147. return x
  148. class ClassAttentionBlock(nn.Module):
  149. """Class Attention Layer as in CaiT https://arxiv.org/abs/2103.17239"""
  150. def __init__(
  151. self,
  152. dim: int,
  153. num_heads: int,
  154. mlp_ratio: float = 4.,
  155. qkv_bias: bool = False,
  156. proj_drop: float = 0.,
  157. attn_drop: float = 0.,
  158. drop_path: float = 0.,
  159. act_layer: Type[nn.Module] = nn.GELU,
  160. norm_layer: Type[nn.Module] = nn.LayerNorm,
  161. eta: Optional[float] = 1.,
  162. tokens_norm: bool = False,
  163. device=None,
  164. dtype=None,
  165. ):
  166. dd = {'device': device, 'dtype': dtype}
  167. super().__init__()
  168. self.norm1 = norm_layer(dim, **dd)
  169. self.attn = ClassAttn(
  170. dim,
  171. num_heads=num_heads,
  172. qkv_bias=qkv_bias,
  173. attn_drop=attn_drop,
  174. proj_drop=proj_drop,
  175. **dd,
  176. )
  177. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  178. self.norm2 = norm_layer(dim, **dd)
  179. self.mlp = Mlp(
  180. in_features=dim,
  181. hidden_features=int(dim * mlp_ratio),
  182. act_layer=act_layer,
  183. drop=proj_drop,
  184. **dd,
  185. )
  186. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  187. if eta is not None: # LayerScale Initialization (no layerscale when None)
  188. self.gamma1 = nn.Parameter(eta * torch.ones(dim, **dd))
  189. self.gamma2 = nn.Parameter(eta * torch.ones(dim, **dd))
  190. else:
  191. self.gamma1, self.gamma2 = 1.0, 1.0
  192. # See https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721
  193. self.tokens_norm = tokens_norm
  194. def forward(self, x):
  195. x_norm1 = self.norm1(x)
  196. x_attn = torch.cat([self.attn(x_norm1), x_norm1[:, 1:]], dim=1)
  197. x = x + self.drop_path1(self.gamma1 * x_attn)
  198. if self.tokens_norm:
  199. x = self.norm2(x)
  200. else:
  201. x = torch.cat([self.norm2(x[:, 0:1]), x[:, 1:]], dim=1)
  202. x_res = x
  203. cls_token = x[:, 0:1]
  204. cls_token = self.gamma2 * self.mlp(cls_token)
  205. x = torch.cat([cls_token, x[:, 1:]], dim=1)
  206. x = x_res + self.drop_path2(x)
  207. return x
  208. class XCA(nn.Module):
  209. fused_attn: torch.jit.Final[bool]
  210. """ Cross-Covariance Attention (XCA)
  211. Operation where the channels are updated using a weighted sum. The weights are obtained from the (softmax
  212. normalized) Cross-covariance matrix (Q^T \\cdot K \\in d_h \\times d_h)
  213. """
  214. def __init__(
  215. self,
  216. dim: int,
  217. num_heads: int = 8,
  218. qkv_bias: bool = False,
  219. attn_drop: float = 0.,
  220. proj_drop: float = 0.,
  221. device=None,
  222. dtype=None,
  223. ):
  224. dd = {'device': device, 'dtype': dtype}
  225. super().__init__()
  226. self.num_heads = num_heads
  227. self.fused_attn = use_fused_attn(experimental=True)
  228. self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1, **dd))
  229. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd)
  230. self.attn_drop = nn.Dropout(attn_drop)
  231. self.proj = nn.Linear(dim, dim, **dd)
  232. self.proj_drop = nn.Dropout(proj_drop)
  233. def forward(self, x):
  234. B, N, C = x.shape
  235. # Result of next line is (qkv, B, num (H)eads, (C')hannels per head, N)
  236. qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 4, 1)
  237. q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
  238. if self.fused_attn:
  239. q = torch.nn.functional.normalize(q, dim=-1) * self.temperature
  240. k = torch.nn.functional.normalize(k, dim=-1)
  241. x = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=1.0)
  242. else:
  243. # Paper section 3.2 l2-Normalization and temperature scaling
  244. q = torch.nn.functional.normalize(q, dim=-1)
  245. k = torch.nn.functional.normalize(k, dim=-1)
  246. attn = (q @ k.transpose(-2, -1)) * self.temperature
  247. attn = attn.softmax(dim=-1)
  248. attn = self.attn_drop(attn)
  249. x = attn @ v
  250. x = x.permute(0, 3, 1, 2).reshape(B, N, C)
  251. x = self.proj(x)
  252. x = self.proj_drop(x)
  253. return x
  254. @torch.jit.ignore
  255. def no_weight_decay(self):
  256. return {'temperature'}
  257. class XCABlock(nn.Module):
  258. def __init__(
  259. self,
  260. dim: int,
  261. num_heads: int,
  262. mlp_ratio: float = 4.,
  263. qkv_bias: bool = False,
  264. proj_drop: float = 0.,
  265. attn_drop: float = 0.,
  266. drop_path: float = 0.,
  267. act_layer: Type[nn.Module] = nn.GELU,
  268. norm_layer: Type[nn.Module] = nn.LayerNorm,
  269. eta: float = 1.,
  270. device=None,
  271. dtype=None,
  272. ):
  273. dd = {'device': device, 'dtype': dtype}
  274. super().__init__()
  275. self.norm1 = norm_layer(dim, **dd)
  276. self.attn = XCA(
  277. dim,
  278. num_heads=num_heads,
  279. qkv_bias=qkv_bias,
  280. attn_drop=attn_drop,
  281. proj_drop=proj_drop,
  282. **dd,
  283. )
  284. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  285. self.norm3 = norm_layer(dim, **dd)
  286. self.local_mp = LPI(in_features=dim, act_layer=act_layer, **dd)
  287. self.drop_path3 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  288. self.norm2 = norm_layer(dim, **dd)
  289. self.mlp = Mlp(
  290. in_features=dim,
  291. hidden_features=int(dim * mlp_ratio),
  292. act_layer=act_layer,
  293. drop=proj_drop,
  294. **dd,
  295. )
  296. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  297. self.gamma1 = nn.Parameter(eta * torch.ones(dim, **dd))
  298. self.gamma3 = nn.Parameter(eta * torch.ones(dim, **dd))
  299. self.gamma2 = nn.Parameter(eta * torch.ones(dim, **dd))
  300. def forward(self, x, H: int, W: int):
  301. x = x + self.drop_path1(self.gamma1 * self.attn(self.norm1(x)))
  302. # NOTE official code has 3 then 2, so keeping it the same to be consistent with loaded weights
  303. # See https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721
  304. x = x + self.drop_path3(self.gamma3 * self.local_mp(self.norm3(x), H, W))
  305. x = x + self.drop_path2(self.gamma2 * self.mlp(self.norm2(x)))
  306. return x
  307. class Xcit(nn.Module):
  308. """
  309. Based on timm and DeiT code bases
  310. https://github.com/rwightman/pytorch-image-models/tree/master/timm
  311. https://github.com/facebookresearch/deit/
  312. """
  313. def __init__(
  314. self,
  315. img_size: Union[int, Tuple[int, int]] = 224,
  316. patch_size: int = 16,
  317. in_chans: int = 3,
  318. num_classes: int = 1000,
  319. global_pool: str = 'token',
  320. embed_dim: int = 768,
  321. depth: int = 12,
  322. num_heads: int = 12,
  323. mlp_ratio: float = 4.,
  324. qkv_bias: bool = True,
  325. drop_rate: float = 0.,
  326. pos_drop_rate: float = 0.,
  327. proj_drop_rate: float = 0.,
  328. attn_drop_rate: float = 0.,
  329. drop_path_rate: float = 0.,
  330. act_layer: Optional[Type[nn.Module]] = None,
  331. norm_layer: Optional[Type[nn.Module]] = None,
  332. cls_attn_layers: int = 2,
  333. use_pos_embed: bool = True,
  334. eta: float = 1.,
  335. tokens_norm: bool = False,
  336. device=None,
  337. dtype=None,
  338. ):
  339. """
  340. Args:
  341. img_size (int, tuple): input image size
  342. patch_size (int): patch size
  343. in_chans (int): number of input channels
  344. num_classes (int): number of classes for classification head
  345. embed_dim (int): embedding dimension
  346. depth (int): depth of transformer
  347. num_heads (int): number of attention heads
  348. mlp_ratio (int): ratio of mlp hidden dim to embedding dim
  349. qkv_bias (bool): enable bias for qkv if True
  350. drop_rate (float): dropout rate after positional embedding, and in XCA/CA projection + MLP
  351. pos_drop_rate: position embedding dropout rate
  352. proj_drop_rate (float): projection dropout rate
  353. attn_drop_rate (float): attention dropout rate
  354. drop_path_rate (float): stochastic depth rate (constant across all layers)
  355. norm_layer: (nn.Module): normalization layer
  356. cls_attn_layers: (int) Depth of Class attention layers
  357. use_pos_embed: (bool) whether to use positional encoding
  358. eta: (float) layerscale initialization value
  359. tokens_norm: (bool) Whether to normalize all tokens or just the cls_token in the CA
  360. Notes:
  361. - Although `layer_norm` is user specifiable, there are hard-coded `BatchNorm2d`s in the local patch
  362. interaction (class LPI) and the patch embedding (class ConvPatchEmbed)
  363. """
  364. super().__init__()
  365. dd = {'device': device, 'dtype': dtype}
  366. assert global_pool in ('', 'avg', 'token')
  367. img_size = to_2tuple(img_size)
  368. assert (img_size[0] % patch_size == 0) and (img_size[0] % patch_size == 0), \
  369. '`patch_size` should divide image dimensions evenly'
  370. norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
  371. act_layer = act_layer or nn.GELU
  372. self.num_classes = num_classes
  373. self.num_features = self.head_hidden_size = self.embed_dim = embed_dim
  374. self.global_pool = global_pool
  375. self.grad_checkpointing = False
  376. self.patch_embed = ConvPatchEmbed(
  377. img_size=img_size,
  378. patch_size=patch_size,
  379. in_chans=in_chans,
  380. embed_dim=embed_dim,
  381. act_layer=act_layer,
  382. **dd,
  383. )
  384. r = patch_size
  385. self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, **dd))
  386. if use_pos_embed:
  387. self.pos_embed = PositionalEncodingFourier(dim=embed_dim, **dd)
  388. else:
  389. self.pos_embed = None
  390. self.pos_drop = nn.Dropout(p=pos_drop_rate)
  391. self.blocks = nn.ModuleList([
  392. XCABlock(
  393. dim=embed_dim,
  394. num_heads=num_heads,
  395. mlp_ratio=mlp_ratio,
  396. qkv_bias=qkv_bias,
  397. proj_drop=proj_drop_rate,
  398. attn_drop=attn_drop_rate,
  399. drop_path=drop_path_rate,
  400. act_layer=act_layer,
  401. norm_layer=norm_layer,
  402. eta=eta,
  403. **dd,
  404. )
  405. for _ in range(depth)])
  406. self.feature_info = [dict(num_chs=embed_dim, reduction=r, module=f'blocks.{i}') for i in range(depth)]
  407. self.cls_attn_blocks = nn.ModuleList([
  408. ClassAttentionBlock(
  409. dim=embed_dim,
  410. num_heads=num_heads,
  411. mlp_ratio=mlp_ratio,
  412. qkv_bias=qkv_bias,
  413. proj_drop=drop_rate,
  414. attn_drop=attn_drop_rate,
  415. act_layer=act_layer,
  416. norm_layer=norm_layer,
  417. eta=eta,
  418. tokens_norm=tokens_norm,
  419. **dd,
  420. )
  421. for _ in range(cls_attn_layers)])
  422. # Classifier head
  423. self.norm = norm_layer(embed_dim, **dd)
  424. self.head_drop = nn.Dropout(drop_rate)
  425. self.head = nn.Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity()
  426. # Init weights
  427. trunc_normal_(self.cls_token, std=.02)
  428. self.apply(self._init_weights)
  429. def _init_weights(self, m):
  430. if isinstance(m, nn.Linear):
  431. trunc_normal_(m.weight, std=.02)
  432. if isinstance(m, nn.Linear) and m.bias is not None:
  433. nn.init.constant_(m.bias, 0)
  434. @torch.jit.ignore
  435. def no_weight_decay(self):
  436. return {'pos_embed', 'cls_token'}
  437. @torch.jit.ignore
  438. def group_matcher(self, coarse=False):
  439. return dict(
  440. stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
  441. blocks=r'^blocks\.(\d+)',
  442. cls_attn_blocks=[(r'^cls_attn_blocks\.(\d+)', None), (r'^norm', (99999,))]
  443. )
  444. @torch.jit.ignore
  445. def set_grad_checkpointing(self, enable=True):
  446. self.grad_checkpointing = enable
  447. @torch.jit.ignore
  448. def get_classifier(self) -> nn.Module:
  449. return self.head
  450. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  451. self.num_classes = num_classes
  452. if global_pool is not None:
  453. assert global_pool in ('', 'avg', 'token')
  454. self.global_pool = global_pool
  455. device = self.head.weight.device if hasattr(self.head, 'weight') else None
  456. dtype = self.head.weight.dtype if hasattr(self.head, 'weight') else None
  457. self.head = nn.Linear(self.num_features, num_classes, device=device, dtype=dtype) if num_classes > 0 else nn.Identity()
  458. def forward_intermediates(
  459. self,
  460. x: torch.Tensor,
  461. indices: Optional[Union[int, List[int]]] = None,
  462. norm: bool = False,
  463. stop_early: bool = False,
  464. output_fmt: str = 'NCHW',
  465. intermediates_only: bool = False,
  466. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  467. """ Forward features that returns intermediates.
  468. Args:
  469. x: Input image tensor
  470. indices: Take last n blocks if int, all if None, select matching indices if sequence
  471. norm: Apply norm layer to all intermediates
  472. stop_early: Stop iterating over blocks when last desired intermediate hit
  473. output_fmt: Shape of intermediate feature outputs
  474. intermediates_only: Only return intermediate features
  475. Returns:
  476. """
  477. assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
  478. reshape = output_fmt == 'NCHW'
  479. intermediates = []
  480. take_indices, max_index = feature_take_indices(len(self.blocks), indices)
  481. # forward pass
  482. B, _, height, width = x.shape
  483. x, (Hp, Wp) = self.patch_embed(x)
  484. if self.pos_embed is not None:
  485. # `pos_embed` (B, C, Hp, Wp), reshape -> (B, C, N), permute -> (B, N, C)
  486. pos_encoding = self.pos_embed(B, Hp, Wp).reshape(B, -1, x.shape[1]).permute(0, 2, 1)
  487. x = x + pos_encoding
  488. x = self.pos_drop(x)
  489. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  490. blocks = self.blocks
  491. else:
  492. blocks = self.blocks[:max_index + 1]
  493. for i, blk in enumerate(blocks):
  494. if self.grad_checkpointing and not torch.jit.is_scripting():
  495. x = checkpoint(blk, x, Hp, Wp)
  496. else:
  497. x = blk(x, Hp, Wp)
  498. if i in take_indices:
  499. # normalize intermediates with final norm layer if enabled
  500. intermediates.append(self.norm(x) if norm else x)
  501. # process intermediates
  502. if reshape:
  503. # reshape to BCHW output format
  504. intermediates = [y.reshape(B, Hp, Wp, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
  505. if intermediates_only:
  506. return intermediates
  507. # NOTE not supporting return of class tokens
  508. x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)
  509. for blk in self.cls_attn_blocks:
  510. if self.grad_checkpointing and not torch.jit.is_scripting():
  511. x = checkpoint(blk, x)
  512. else:
  513. x = blk(x)
  514. x = self.norm(x)
  515. return x, intermediates
  516. def prune_intermediate_layers(
  517. self,
  518. indices: Union[int, List[int]] = 1,
  519. prune_norm: bool = False,
  520. prune_head: bool = True,
  521. ):
  522. """ Prune layers not required for specified intermediates.
  523. """
  524. take_indices, max_index = feature_take_indices(len(self.blocks), indices)
  525. self.blocks = self.blocks[:max_index + 1] # truncate blocks
  526. if prune_norm:
  527. self.norm = nn.Identity()
  528. if prune_head:
  529. self.cls_attn_blocks = nn.ModuleList() # prune token blocks with head
  530. self.reset_classifier(0, '')
  531. return take_indices
  532. def forward_features(self, x):
  533. B = x.shape[0]
  534. # x is (B, N, C). (Hp, Hw) is (height in units of patches, width in units of patches)
  535. x, (Hp, Wp) = self.patch_embed(x)
  536. if self.pos_embed is not None:
  537. # `pos_embed` (B, C, Hp, Wp), reshape -> (B, C, N), permute -> (B, N, C)
  538. pos_encoding = self.pos_embed(B, Hp, Wp).reshape(B, -1, x.shape[1]).permute(0, 2, 1)
  539. x = x + pos_encoding
  540. x = self.pos_drop(x)
  541. for blk in self.blocks:
  542. if self.grad_checkpointing and not torch.jit.is_scripting():
  543. x = checkpoint(blk, x, Hp, Wp)
  544. else:
  545. x = blk(x, Hp, Wp)
  546. x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)
  547. for blk in self.cls_attn_blocks:
  548. if self.grad_checkpointing and not torch.jit.is_scripting():
  549. x = checkpoint(blk, x)
  550. else:
  551. x = blk(x)
  552. x = self.norm(x)
  553. return x
  554. def forward_head(self, x, pre_logits: bool = False):
  555. if self.global_pool:
  556. x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
  557. x = self.head_drop(x)
  558. return x if pre_logits else self.head(x)
  559. def forward(self, x):
  560. x = self.forward_features(x)
  561. x = self.forward_head(x)
  562. return x
  563. def checkpoint_filter_fn(state_dict, model):
  564. if 'model' in state_dict:
  565. state_dict = state_dict['model']
  566. # For consistency with timm's transformer models while being compatible with official weights source we rename
  567. # pos_embeder to pos_embed. Also account for use_pos_embed == False
  568. use_pos_embed = getattr(model, 'pos_embed', None) is not None
  569. pos_embed_keys = [k for k in state_dict if k.startswith('pos_embed')]
  570. for k in pos_embed_keys:
  571. if use_pos_embed:
  572. state_dict[k.replace('pos_embeder.', 'pos_embed.')] = state_dict.pop(k)
  573. else:
  574. del state_dict[k]
  575. # timm's implementation of class attention in CaiT is slightly more efficient as it does not compute query vectors
  576. # for all tokens, just the class token. To use official weights source we must split qkv into q, k, v
  577. if 'cls_attn_blocks.0.attn.qkv.weight' in state_dict and 'cls_attn_blocks.0.attn.q.weight' in model.state_dict():
  578. num_ca_blocks = len(model.cls_attn_blocks)
  579. for i in range(num_ca_blocks):
  580. qkv_weight = state_dict.pop(f'cls_attn_blocks.{i}.attn.qkv.weight')
  581. qkv_weight = qkv_weight.reshape(3, -1, qkv_weight.shape[-1])
  582. for j, subscript in enumerate('qkv'):
  583. state_dict[f'cls_attn_blocks.{i}.attn.{subscript}.weight'] = qkv_weight[j]
  584. qkv_bias = state_dict.pop(f'cls_attn_blocks.{i}.attn.qkv.bias', None)
  585. if qkv_bias is not None:
  586. qkv_bias = qkv_bias.reshape(3, -1)
  587. for j, subscript in enumerate('qkv'):
  588. state_dict[f'cls_attn_blocks.{i}.attn.{subscript}.bias'] = qkv_bias[j]
  589. return state_dict
  590. def _create_xcit(variant, pretrained=False, default_cfg=None, **kwargs):
  591. out_indices = kwargs.pop('out_indices', 3)
  592. model = build_model_with_cfg(
  593. Xcit,
  594. variant,
  595. pretrained,
  596. pretrained_filter_fn=checkpoint_filter_fn,
  597. feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
  598. **kwargs,
  599. )
  600. return model
  601. def _cfg(url='', **kwargs):
  602. return {
  603. 'url': url,
  604. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
  605. 'crop_pct': 1.0, 'interpolation': 'bicubic', 'fixed_input_size': True,
  606. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  607. 'first_conv': 'patch_embed.proj.0.0', 'classifier': 'head',
  608. 'license': 'apache-2.0', **kwargs
  609. }
  610. default_cfgs = generate_default_cfgs({
  611. # Patch size 16
  612. 'xcit_nano_12_p16_224.fb_in1k': _cfg(
  613. hf_hub_id='timm/',
  614. url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p16_224.pth'),
  615. 'xcit_nano_12_p16_224.fb_dist_in1k': _cfg(
  616. hf_hub_id='timm/',
  617. url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p16_224_dist.pth'),
  618. 'xcit_nano_12_p16_384.fb_dist_in1k': _cfg(
  619. hf_hub_id='timm/',
  620. url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p16_384_dist.pth', input_size=(3, 384, 384)),
  621. 'xcit_tiny_12_p16_224.fb_in1k': _cfg(
  622. hf_hub_id='timm/',
  623. url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p16_224.pth'),
  624. 'xcit_tiny_12_p16_224.fb_dist_in1k': _cfg(
  625. hf_hub_id='timm/',
  626. url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p16_224_dist.pth'),
  627. 'xcit_tiny_12_p16_384.fb_dist_in1k': _cfg(
  628. hf_hub_id='timm/',
  629. url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p16_384_dist.pth', input_size=(3, 384, 384)),
  630. 'xcit_tiny_24_p16_224.fb_in1k': _cfg(
  631. hf_hub_id='timm/',
  632. url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p16_224.pth'),
  633. 'xcit_tiny_24_p16_224.fb_dist_in1k': _cfg(
  634. hf_hub_id='timm/',
  635. url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p16_224_dist.pth'),
  636. 'xcit_tiny_24_p16_384.fb_dist_in1k': _cfg(
  637. hf_hub_id='timm/',
  638. url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p16_384_dist.pth', input_size=(3, 384, 384)),
  639. 'xcit_small_12_p16_224.fb_in1k': _cfg(
  640. hf_hub_id='timm/',
  641. url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_224.pth'),
  642. 'xcit_small_12_p16_224.fb_dist_in1k': _cfg(
  643. hf_hub_id='timm/',
  644. url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_224_dist.pth'),
  645. 'xcit_small_12_p16_384.fb_dist_in1k': _cfg(
  646. hf_hub_id='timm/',
  647. url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_384_dist.pth', input_size=(3, 384, 384)),
  648. 'xcit_small_24_p16_224.fb_in1k': _cfg(
  649. hf_hub_id='timm/',
  650. url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p16_224.pth'),
  651. 'xcit_small_24_p16_224.fb_dist_in1k': _cfg(
  652. hf_hub_id='timm/',
  653. url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p16_224_dist.pth'),
  654. 'xcit_small_24_p16_384.fb_dist_in1k': _cfg(
  655. hf_hub_id='timm/',
  656. url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p16_384_dist.pth', input_size=(3, 384, 384)),
  657. 'xcit_medium_24_p16_224.fb_in1k': _cfg(
  658. hf_hub_id='timm/',
  659. url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p16_224.pth'),
  660. 'xcit_medium_24_p16_224.fb_dist_in1k': _cfg(
  661. hf_hub_id='timm/',
  662. url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p16_224_dist.pth'),
  663. 'xcit_medium_24_p16_384.fb_dist_in1k': _cfg(
  664. hf_hub_id='timm/',
  665. url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p16_384_dist.pth', input_size=(3, 384, 384)),
  666. 'xcit_large_24_p16_224.fb_in1k': _cfg(
  667. hf_hub_id='timm/',
  668. url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p16_224.pth'),
  669. 'xcit_large_24_p16_224.fb_dist_in1k': _cfg(
  670. hf_hub_id='timm/',
  671. url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p16_224_dist.pth'),
  672. 'xcit_large_24_p16_384.fb_dist_in1k': _cfg(
  673. hf_hub_id='timm/',
  674. url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p16_384_dist.pth', input_size=(3, 384, 384)),
  675. # Patch size 8
  676. 'xcit_nano_12_p8_224.fb_in1k': _cfg(
  677. hf_hub_id='timm/',
  678. url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p8_224.pth'),
  679. 'xcit_nano_12_p8_224.fb_dist_in1k': _cfg(
  680. hf_hub_id='timm/',
  681. url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p8_224_dist.pth'),
  682. 'xcit_nano_12_p8_384.fb_dist_in1k': _cfg(
  683. hf_hub_id='timm/',
  684. url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p8_384_dist.pth', input_size=(3, 384, 384)),
  685. 'xcit_tiny_12_p8_224.fb_in1k': _cfg(
  686. hf_hub_id='timm/',
  687. url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p8_224.pth'),
  688. 'xcit_tiny_12_p8_224.fb_dist_in1k': _cfg(
  689. hf_hub_id='timm/',
  690. url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p8_224_dist.pth'),
  691. 'xcit_tiny_12_p8_384.fb_dist_in1k': _cfg(
  692. hf_hub_id='timm/',
  693. url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p8_384_dist.pth', input_size=(3, 384, 384)),
  694. 'xcit_tiny_24_p8_224.fb_in1k': _cfg(
  695. hf_hub_id='timm/',
  696. url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p8_224.pth'),
  697. 'xcit_tiny_24_p8_224.fb_dist_in1k': _cfg(
  698. hf_hub_id='timm/',
  699. url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p8_224_dist.pth'),
  700. 'xcit_tiny_24_p8_384.fb_dist_in1k': _cfg(
  701. hf_hub_id='timm/',
  702. url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p8_384_dist.pth', input_size=(3, 384, 384)),
  703. 'xcit_small_12_p8_224.fb_in1k': _cfg(
  704. hf_hub_id='timm/',
  705. url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p8_224.pth'),
  706. 'xcit_small_12_p8_224.fb_dist_in1k': _cfg(
  707. hf_hub_id='timm/',
  708. url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p8_224_dist.pth'),
  709. 'xcit_small_12_p8_384.fb_dist_in1k': _cfg(
  710. hf_hub_id='timm/',
  711. url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p8_384_dist.pth', input_size=(3, 384, 384)),
  712. 'xcit_small_24_p8_224.fb_in1k': _cfg(
  713. hf_hub_id='timm/',
  714. url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p8_224.pth'),
  715. 'xcit_small_24_p8_224.fb_dist_in1k': _cfg(
  716. hf_hub_id='timm/',
  717. url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p8_224_dist.pth'),
  718. 'xcit_small_24_p8_384.fb_dist_in1k': _cfg(
  719. hf_hub_id='timm/',
  720. url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p8_384_dist.pth', input_size=(3, 384, 384)),
  721. 'xcit_medium_24_p8_224.fb_in1k': _cfg(
  722. hf_hub_id='timm/',
  723. url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p8_224.pth'),
  724. 'xcit_medium_24_p8_224.fb_dist_in1k': _cfg(
  725. hf_hub_id='timm/',
  726. url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p8_224_dist.pth'),
  727. 'xcit_medium_24_p8_384.fb_dist_in1k': _cfg(
  728. hf_hub_id='timm/',
  729. url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p8_384_dist.pth', input_size=(3, 384, 384)),
  730. 'xcit_large_24_p8_224.fb_in1k': _cfg(
  731. hf_hub_id='timm/',
  732. url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p8_224.pth'),
  733. 'xcit_large_24_p8_224.fb_dist_in1k': _cfg(
  734. hf_hub_id='timm/',
  735. url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p8_224_dist.pth'),
  736. 'xcit_large_24_p8_384.fb_dist_in1k': _cfg(
  737. hf_hub_id='timm/',
  738. url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p8_384_dist.pth', input_size=(3, 384, 384)),
  739. })
  @register_model
  def xcit_nano_12_p16_224(pretrained=False, **kwargs) -> Xcit:
      """XCiT-Nano-12 with 16x16 patches (224 variant).

      Keyword arguments in ``kwargs`` override the architecture defaults
      below; the merged settings go to ``_create_xcit`` along with
      ``pretrained``.
      """
      arch = {
          'patch_size': 16, 'embed_dim': 128, 'depth': 12,
          'num_heads': 4, 'eta': 1.0, 'tokens_norm': False,
      }
      arch.update(kwargs)
      return _create_xcit('xcit_nano_12_p16_224', pretrained=pretrained, **arch)
  @register_model
  def xcit_nano_12_p16_384(pretrained=False, **kwargs) -> Xcit:
      """XCiT-Nano-12 with 16x16 patches at 384x384.

      ``img_size=384`` is set explicitly. Keyword arguments in ``kwargs``
      (``img_size`` included) override these defaults; the merged settings
      go to ``_create_xcit`` along with ``pretrained``.
      """
      arch = {
          'patch_size': 16, 'embed_dim': 128, 'depth': 12,
          'num_heads': 4, 'eta': 1.0, 'tokens_norm': False,
          'img_size': 384,
      }
      arch.update(kwargs)
      return _create_xcit('xcit_nano_12_p16_384', pretrained=pretrained, **arch)
  @register_model
  def xcit_tiny_12_p16_224(pretrained=False, **kwargs) -> Xcit:
      """XCiT-Tiny-12 with 16x16 patches (224 variant).

      Keyword arguments in ``kwargs`` override the architecture defaults
      below; the merged settings go to ``_create_xcit`` along with
      ``pretrained``.
      """
      arch = {
          'patch_size': 16, 'embed_dim': 192, 'depth': 12,
          'num_heads': 4, 'eta': 1.0, 'tokens_norm': True,
      }
      arch.update(kwargs)
      return _create_xcit('xcit_tiny_12_p16_224', pretrained=pretrained, **arch)
  @register_model
  def xcit_tiny_12_p16_384(pretrained=False, **kwargs) -> Xcit:
      """XCiT-Tiny-12 with 16x16 patches at 384x384.

      ``img_size=384`` makes the constructed model agree with the
      ``input_size=(3, 384, 384)`` declared by the ``*_384`` pretrained
      configs. Keyword arguments in ``kwargs`` (``img_size`` included)
      override these defaults; the merged settings go to ``_create_xcit``
      along with ``pretrained``.
      """
      arch = {
          'patch_size': 16, 'embed_dim': 192, 'depth': 12,
          'num_heads': 4, 'eta': 1.0, 'tokens_norm': True,
          'img_size': 384,
      }
      arch.update(kwargs)
      return _create_xcit('xcit_tiny_12_p16_384', pretrained=pretrained, **arch)
  @register_model
  def xcit_small_12_p16_224(pretrained=False, **kwargs) -> Xcit:
      """XCiT-Small-12 with 16x16 patches (224 variant).

      Keyword arguments in ``kwargs`` override the architecture defaults
      below; the merged settings go to ``_create_xcit`` along with
      ``pretrained``.
      """
      arch = {
          'patch_size': 16, 'embed_dim': 384, 'depth': 12,
          'num_heads': 8, 'eta': 1.0, 'tokens_norm': True,
      }
      arch.update(kwargs)
      return _create_xcit('xcit_small_12_p16_224', pretrained=pretrained, **arch)
  @register_model
  def xcit_small_12_p16_384(pretrained=False, **kwargs) -> Xcit:
      """XCiT-Small-12 with 16x16 patches at 384x384.

      ``img_size=384`` makes the constructed model agree with the
      ``input_size=(3, 384, 384)`` declared by this variant's pretrained
      config. Keyword arguments in ``kwargs`` (``img_size`` included)
      override these defaults; the merged settings go to ``_create_xcit``
      along with ``pretrained``.
      """
      arch = {
          'patch_size': 16, 'embed_dim': 384, 'depth': 12,
          'num_heads': 8, 'eta': 1.0, 'tokens_norm': True,
          'img_size': 384,
      }
      arch.update(kwargs)
      return _create_xcit('xcit_small_12_p16_384', pretrained=pretrained, **arch)
  @register_model
  def xcit_tiny_24_p16_224(pretrained=False, **kwargs) -> Xcit:
      """XCiT-Tiny-24 with 16x16 patches (224 variant).

      Keyword arguments in ``kwargs`` override the architecture defaults
      below; the merged settings go to ``_create_xcit`` along with
      ``pretrained``.
      """
      arch = {
          'patch_size': 16, 'embed_dim': 192, 'depth': 24,
          'num_heads': 4, 'eta': 1e-5, 'tokens_norm': True,
      }
      arch.update(kwargs)
      return _create_xcit('xcit_tiny_24_p16_224', pretrained=pretrained, **arch)
  @register_model
  def xcit_tiny_24_p16_384(pretrained=False, **kwargs) -> Xcit:
      """XCiT-Tiny-24 with 16x16 patches at 384x384.

      ``img_size=384`` makes the constructed model agree with the
      ``input_size=(3, 384, 384)`` declared by this variant's pretrained
      config. Keyword arguments in ``kwargs`` (``img_size`` included)
      override these defaults; the merged settings go to ``_create_xcit``
      along with ``pretrained``.
      """
      arch = {
          'patch_size': 16, 'embed_dim': 192, 'depth': 24,
          'num_heads': 4, 'eta': 1e-5, 'tokens_norm': True,
          'img_size': 384,
      }
      arch.update(kwargs)
      return _create_xcit('xcit_tiny_24_p16_384', pretrained=pretrained, **arch)
  @register_model
  def xcit_small_24_p16_224(pretrained=False, **kwargs) -> Xcit:
      """XCiT-Small-24 with 16x16 patches (224 variant).

      Keyword arguments in ``kwargs`` override the architecture defaults
      below; the merged settings go to ``_create_xcit`` along with
      ``pretrained``.
      """
      arch = {
          'patch_size': 16, 'embed_dim': 384, 'depth': 24,
          'num_heads': 8, 'eta': 1e-5, 'tokens_norm': True,
      }
      arch.update(kwargs)
      return _create_xcit('xcit_small_24_p16_224', pretrained=pretrained, **arch)
  @register_model
  def xcit_small_24_p16_384(pretrained=False, **kwargs) -> Xcit:
      """XCiT-Small-24 with 16x16 patches at 384x384.

      ``img_size=384`` makes the constructed model agree with the
      ``input_size=(3, 384, 384)`` declared by this variant's pretrained
      config. Keyword arguments in ``kwargs`` (``img_size`` included)
      override these defaults; the merged settings go to ``_create_xcit``
      along with ``pretrained``.
      """
      arch = {
          'patch_size': 16, 'embed_dim': 384, 'depth': 24,
          'num_heads': 8, 'eta': 1e-5, 'tokens_norm': True,
          'img_size': 384,
      }
      arch.update(kwargs)
      return _create_xcit('xcit_small_24_p16_384', pretrained=pretrained, **arch)
  @register_model
  def xcit_medium_24_p16_224(pretrained=False, **kwargs) -> Xcit:
      """XCiT-Medium-24 with 16x16 patches (224 variant).

      Keyword arguments in ``kwargs`` override the architecture defaults
      below; the merged settings go to ``_create_xcit`` along with
      ``pretrained``.
      """
      arch = {
          'patch_size': 16, 'embed_dim': 512, 'depth': 24,
          'num_heads': 8, 'eta': 1e-5, 'tokens_norm': True,
      }
      arch.update(kwargs)
      return _create_xcit('xcit_medium_24_p16_224', pretrained=pretrained, **arch)
  @register_model
  def xcit_medium_24_p16_384(pretrained=False, **kwargs) -> Xcit:
      """XCiT-Medium-24 with 16x16 patches at 384x384.

      ``img_size=384`` makes the constructed model agree with the
      ``input_size=(3, 384, 384)`` declared by this variant's pretrained
      config. Keyword arguments in ``kwargs`` (``img_size`` included)
      override these defaults; the merged settings go to ``_create_xcit``
      along with ``pretrained``.
      """
      arch = {
          'patch_size': 16, 'embed_dim': 512, 'depth': 24,
          'num_heads': 8, 'eta': 1e-5, 'tokens_norm': True,
          'img_size': 384,
      }
      arch.update(kwargs)
      return _create_xcit('xcit_medium_24_p16_384', pretrained=pretrained, **arch)
  @register_model
  def xcit_large_24_p16_224(pretrained=False, **kwargs) -> Xcit:
      """XCiT-Large-24 with 16x16 patches (224 variant).

      Keyword arguments in ``kwargs`` override the architecture defaults
      below; the merged settings go to ``_create_xcit`` along with
      ``pretrained``.
      """
      arch = {
          'patch_size': 16, 'embed_dim': 768, 'depth': 24,
          'num_heads': 16, 'eta': 1e-5, 'tokens_norm': True,
      }
      arch.update(kwargs)
      return _create_xcit('xcit_large_24_p16_224', pretrained=pretrained, **arch)
  @register_model
  def xcit_large_24_p16_384(pretrained=False, **kwargs) -> Xcit:
      """XCiT-Large-24 with 16x16 patches at 384x384.

      ``img_size=384`` makes the constructed model agree with the
      ``input_size=(3, 384, 384)`` declared by this variant's pretrained
      config. Keyword arguments in ``kwargs`` (``img_size`` included)
      override these defaults; the merged settings go to ``_create_xcit``
      along with ``pretrained``.
      """
      arch = {
          'patch_size': 16, 'embed_dim': 768, 'depth': 24,
          'num_heads': 16, 'eta': 1e-5, 'tokens_norm': True,
          'img_size': 384,
      }
      arch.update(kwargs)
      return _create_xcit('xcit_large_24_p16_384', pretrained=pretrained, **arch)
  824. # Patch size 8x8 models
  @register_model
  def xcit_nano_12_p8_224(pretrained=False, **kwargs) -> Xcit:
      """XCiT-Nano-12 with 8x8 patches (224 variant).

      Keyword arguments in ``kwargs`` override the architecture defaults
      below; the merged settings go to ``_create_xcit`` along with
      ``pretrained``.
      """
      arch = {
          'patch_size': 8, 'embed_dim': 128, 'depth': 12,
          'num_heads': 4, 'eta': 1.0, 'tokens_norm': False,
      }
      arch.update(kwargs)
      return _create_xcit('xcit_nano_12_p8_224', pretrained=pretrained, **arch)
  @register_model
  def xcit_nano_12_p8_384(pretrained=False, **kwargs) -> Xcit:
      """XCiT-Nano-12 with 8x8 patches at 384x384.

      ``img_size=384`` makes the constructed model agree with the
      ``input_size=(3, 384, 384)`` declared by this variant's pretrained
      config. Keyword arguments in ``kwargs`` (``img_size`` included)
      override these defaults; the merged settings go to ``_create_xcit``
      along with ``pretrained``.
      """
      arch = {
          'patch_size': 8, 'embed_dim': 128, 'depth': 12,
          'num_heads': 4, 'eta': 1.0, 'tokens_norm': False,
          'img_size': 384,
      }
      arch.update(kwargs)
      return _create_xcit('xcit_nano_12_p8_384', pretrained=pretrained, **arch)
  @register_model
  def xcit_tiny_12_p8_224(pretrained=False, **kwargs) -> Xcit:
      """XCiT-Tiny-12 with 8x8 patches (224 variant).

      Keyword arguments in ``kwargs`` override the architecture defaults
      below; the merged settings go to ``_create_xcit`` along with
      ``pretrained``.
      """
      arch = {
          'patch_size': 8, 'embed_dim': 192, 'depth': 12,
          'num_heads': 4, 'eta': 1.0, 'tokens_norm': True,
      }
      arch.update(kwargs)
      return _create_xcit('xcit_tiny_12_p8_224', pretrained=pretrained, **arch)
  @register_model
  def xcit_tiny_12_p8_384(pretrained=False, **kwargs) -> Xcit:
      """XCiT-Tiny-12 with 8x8 patches at 384x384.

      ``img_size=384`` makes the constructed model agree with the
      ``input_size=(3, 384, 384)`` declared by this variant's pretrained
      config. Keyword arguments in ``kwargs`` (``img_size`` included)
      override these defaults; the merged settings go to ``_create_xcit``
      along with ``pretrained``.
      """
      arch = {
          'patch_size': 8, 'embed_dim': 192, 'depth': 12,
          'num_heads': 4, 'eta': 1.0, 'tokens_norm': True,
          'img_size': 384,
      }
      arch.update(kwargs)
      return _create_xcit('xcit_tiny_12_p8_384', pretrained=pretrained, **arch)
  @register_model
  def xcit_small_12_p8_224(pretrained=False, **kwargs) -> Xcit:
      """XCiT-Small-12 with 8x8 patches (224 variant).

      Keyword arguments in ``kwargs`` override the architecture defaults
      below; the merged settings go to ``_create_xcit`` along with
      ``pretrained``.
      """
      arch = {
          'patch_size': 8, 'embed_dim': 384, 'depth': 12,
          'num_heads': 8, 'eta': 1.0, 'tokens_norm': True,
      }
      arch.update(kwargs)
      return _create_xcit('xcit_small_12_p8_224', pretrained=pretrained, **arch)
  @register_model
  def xcit_small_12_p8_384(pretrained=False, **kwargs) -> Xcit:
      """XCiT-Small-12 with 8x8 patches at 384x384.

      ``img_size=384`` makes the constructed model agree with the
      ``input_size=(3, 384, 384)`` declared by this variant's pretrained
      config. Keyword arguments in ``kwargs`` (``img_size`` included)
      override these defaults; the merged settings go to ``_create_xcit``
      along with ``pretrained``.
      """
      arch = {
          'patch_size': 8, 'embed_dim': 384, 'depth': 12,
          'num_heads': 8, 'eta': 1.0, 'tokens_norm': True,
          'img_size': 384,
      }
      arch.update(kwargs)
      return _create_xcit('xcit_small_12_p8_384', pretrained=pretrained, **arch)
  @register_model
  def xcit_tiny_24_p8_224(pretrained=False, **kwargs) -> Xcit:
      """XCiT-Tiny-24 with 8x8 patches (224 variant).

      Keyword arguments in ``kwargs`` override the architecture defaults
      below; the merged settings go to ``_create_xcit`` along with
      ``pretrained``.
      """
      arch = {
          'patch_size': 8, 'embed_dim': 192, 'depth': 24,
          'num_heads': 4, 'eta': 1e-5, 'tokens_norm': True,
      }
      arch.update(kwargs)
      return _create_xcit('xcit_tiny_24_p8_224', pretrained=pretrained, **arch)
  @register_model
  def xcit_tiny_24_p8_384(pretrained=False, **kwargs) -> Xcit:
      """XCiT-Tiny-24 with 8x8 patches at 384x384.

      ``img_size=384`` makes the constructed model agree with the
      ``input_size=(3, 384, 384)`` declared by this variant's pretrained
      config. Keyword arguments in ``kwargs`` (``img_size`` included)
      override these defaults; the merged settings go to ``_create_xcit``
      along with ``pretrained``.
      """
      arch = {
          'patch_size': 8, 'embed_dim': 192, 'depth': 24,
          'num_heads': 4, 'eta': 1e-5, 'tokens_norm': True,
          'img_size': 384,
      }
      arch.update(kwargs)
      return _create_xcit('xcit_tiny_24_p8_384', pretrained=pretrained, **arch)
  @register_model
  def xcit_small_24_p8_224(pretrained=False, **kwargs) -> Xcit:
      """XCiT-Small-24 with 8x8 patches (224 variant).

      Keyword arguments in ``kwargs`` override the architecture defaults
      below; the merged settings go to ``_create_xcit`` along with
      ``pretrained``.
      """
      arch = {
          'patch_size': 8, 'embed_dim': 384, 'depth': 24,
          'num_heads': 8, 'eta': 1e-5, 'tokens_norm': True,
      }
      arch.update(kwargs)
      return _create_xcit('xcit_small_24_p8_224', pretrained=pretrained, **arch)
  @register_model
  def xcit_small_24_p8_384(pretrained=False, **kwargs) -> Xcit:
      """XCiT-Small-24 with 8x8 patches at 384x384.

      ``img_size=384`` makes the constructed model agree with the
      ``input_size=(3, 384, 384)`` declared by this variant's pretrained
      config. Keyword arguments in ``kwargs`` (``img_size`` included)
      override these defaults; the merged settings go to ``_create_xcit``
      along with ``pretrained``.
      """
      arch = {
          'patch_size': 8, 'embed_dim': 384, 'depth': 24,
          'num_heads': 8, 'eta': 1e-5, 'tokens_norm': True,
          'img_size': 384,
      }
      arch.update(kwargs)
      return _create_xcit('xcit_small_24_p8_384', pretrained=pretrained, **arch)
  @register_model
  def xcit_medium_24_p8_224(pretrained=False, **kwargs) -> Xcit:
      """XCiT-Medium-24 with 8x8 patches (224 variant).

      Keyword arguments in ``kwargs`` override the architecture defaults
      below; the merged settings go to ``_create_xcit`` along with
      ``pretrained``.
      """
      arch = {
          'patch_size': 8, 'embed_dim': 512, 'depth': 24,
          'num_heads': 8, 'eta': 1e-5, 'tokens_norm': True,
      }
      arch.update(kwargs)
      return _create_xcit('xcit_medium_24_p8_224', pretrained=pretrained, **arch)
  @register_model
  def xcit_medium_24_p8_384(pretrained=False, **kwargs) -> Xcit:
      """XCiT-Medium-24 with 8x8 patches at 384x384.

      ``img_size=384`` makes the constructed model agree with the
      ``input_size=(3, 384, 384)`` declared by this variant's pretrained
      config. Keyword arguments in ``kwargs`` (``img_size`` included)
      override these defaults; the merged settings go to ``_create_xcit``
      along with ``pretrained``.
      """
      arch = {
          'patch_size': 8, 'embed_dim': 512, 'depth': 24,
          'num_heads': 8, 'eta': 1e-5, 'tokens_norm': True,
          'img_size': 384,
      }
      arch.update(kwargs)
      return _create_xcit('xcit_medium_24_p8_384', pretrained=pretrained, **arch)
  @register_model
  def xcit_large_24_p8_224(pretrained=False, **kwargs) -> Xcit:
      """XCiT-Large-24 with 8x8 patches (224 variant).

      Keyword arguments in ``kwargs`` override the architecture defaults
      below; the merged settings go to ``_create_xcit`` along with
      ``pretrained``.
      """
      arch = {
          'patch_size': 8, 'embed_dim': 768, 'depth': 24,
          'num_heads': 16, 'eta': 1e-5, 'tokens_norm': True,
      }
      arch.update(kwargs)
      return _create_xcit('xcit_large_24_p8_224', pretrained=pretrained, **arch)
  @register_model
  def xcit_large_24_p8_384(pretrained=False, **kwargs) -> Xcit:
      """XCiT-Large-24 with 8x8 patches at 384x384.

      ``img_size=384`` makes the constructed model agree with the
      ``input_size=(3, 384, 384)`` declared by this variant's pretrained
      config. Keyword arguments in ``kwargs`` (``img_size`` included)
      override these defaults; the merged settings go to ``_create_xcit``
      along with ``pretrained``.
      """
      arch = {
          'patch_size': 8, 'embed_dim': 768, 'depth': 24,
          'num_heads': 16, 'eta': 1e-5, 'tokens_norm': True,
          'img_size': 384,
      }
      arch.update(kwargs)
      return _create_xcit('xcit_large_24_p8_384', pretrained=pretrained, **arch)
  # Each legacy '<base>_dist' model name is redirected to '<base>.fb_dist_in1k'.
  # Generating both strings from one base name keeps each key/target pair in sync.
  register_model_deprecations(__name__, {
      f'{base}_dist': f'{base}.fb_dist_in1k'
      for base in (
          # Patch size 16
          'xcit_nano_12_p16_224', 'xcit_nano_12_p16_384',
          'xcit_tiny_12_p16_224', 'xcit_tiny_12_p16_384',
          'xcit_tiny_24_p16_224', 'xcit_tiny_24_p16_384',
          'xcit_small_12_p16_224', 'xcit_small_12_p16_384',
          'xcit_small_24_p16_224', 'xcit_small_24_p16_384',
          'xcit_medium_24_p16_224', 'xcit_medium_24_p16_384',
          'xcit_large_24_p16_224', 'xcit_large_24_p16_384',
          # Patch size 8
          'xcit_nano_12_p8_224', 'xcit_nano_12_p8_384',
          'xcit_tiny_12_p8_224', 'xcit_tiny_12_p8_384',
          'xcit_tiny_24_p8_224', 'xcit_tiny_24_p8_384',
          'xcit_small_12_p8_224', 'xcit_small_12_p8_384',
          'xcit_small_24_p8_224', 'xcit_small_24_p8_384',
          'xcit_medium_24_p8_224', 'xcit_medium_24_p8_384',
          'xcit_large_24_p8_224', 'xcit_large_24_p8_384',
      )
  })