metaformer.py 40 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182
  1. """
  2. Poolformer from MetaFormer is Actually What You Need for Vision https://arxiv.org/abs/2111.11418
  3. IdentityFormer, RandFormer, PoolFormerV2, ConvFormer, and CAFormer
  4. from MetaFormer Baselines for Vision https://arxiv.org/abs/2210.13452
  5. All implemented models support feature extraction and variable input resolution.
  6. Original implementation by Weihao Yu et al.,
  7. adapted for timm by Fredo Guan and Ross Wightman.
  8. Adapted from https://github.com/sail-sg/metaformer, original copyright below
  9. """
  10. # Copyright 2022 Garena Online Private Limited
  11. #
  12. # Licensed under the Apache License, Version 2.0 (the "License");
  13. # you may not use this file except in compliance with the License.
  14. # You may obtain a copy of the License at
  15. #
  16. # http://www.apache.org/licenses/LICENSE-2.0
  17. #
  18. # Unless required by applicable law or agreed to in writing, software
  19. # distributed under the License is distributed on an "AS IS" BASIS,
  20. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  21. # See the License for the specific language governing permissions and
  22. # limitations under the License.
  23. from collections import OrderedDict
  24. from functools import partial
  25. from typing import List, Optional, Tuple, Union, Type
  26. import torch
  27. import torch.nn as nn
  28. import torch.nn.functional as F
  29. from torch import Tensor
  30. from torch.jit import Final
  31. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  32. from timm.layers import (
  33. trunc_normal_,
  34. DropPath,
  35. calculate_drop_path_rates,
  36. SelectAdaptivePool2d,
  37. GroupNorm1,
  38. LayerNorm,
  39. LayerNorm2d,
  40. Mlp,
  41. use_fused_attn,
  42. )
  43. from ._builder import build_model_with_cfg
  44. from ._features import feature_take_indices
  45. from ._manipulate import checkpoint, checkpoint_seq
  46. from ._registry import generate_default_cfgs, register_model
  47. __all__ = ['MetaFormer']
  48. class Stem(nn.Module):
  49. """
  50. Stem implemented by a layer of convolution.
  51. Conv2d params constant across all models.
  52. """
  53. def __init__(
  54. self,
  55. in_channels: int,
  56. out_channels: int,
  57. norm_layer: Optional[Type[nn.Module]] = None,
  58. device=None,
  59. dtype=None,
  60. ):
  61. dd = {'device': device, 'dtype': dtype}
  62. super().__init__()
  63. self.conv = nn.Conv2d(
  64. in_channels,
  65. out_channels,
  66. kernel_size=7,
  67. stride=4,
  68. padding=2,
  69. **dd,
  70. )
  71. self.norm = norm_layer(out_channels, **dd) if norm_layer else nn.Identity()
  72. def forward(self, x):
  73. x = self.conv(x)
  74. x = self.norm(x)
  75. return x
  76. class Downsampling(nn.Module):
  77. """
  78. Downsampling implemented by a layer of convolution.
  79. """
  80. def __init__(
  81. self,
  82. in_channels: int,
  83. out_channels: int,
  84. kernel_size: int,
  85. stride: int = 1,
  86. padding: int = 0,
  87. norm_layer: Optional[Type[nn.Module]] = None,
  88. device=None,
  89. dtype=None,
  90. ):
  91. dd = {'device': device, 'dtype': dtype}
  92. super().__init__()
  93. self.norm = norm_layer(in_channels, **dd) if norm_layer else nn.Identity()
  94. self.conv = nn.Conv2d(
  95. in_channels,
  96. out_channels,
  97. kernel_size=kernel_size,
  98. stride=stride,
  99. padding=padding,
  100. **dd
  101. )
  102. def forward(self, x):
  103. x = self.norm(x)
  104. x = self.conv(x)
  105. return x
  106. class Scale(nn.Module):
  107. """
  108. Scale vector by element multiplications.
  109. """
  110. def __init__(
  111. self,
  112. dim: int,
  113. init_value: float = 1.0,
  114. trainable: bool = True,
  115. use_nchw: bool = True,
  116. device=None,
  117. dtype=None,
  118. ):
  119. dd = {'device': device, 'dtype': dtype}
  120. super().__init__()
  121. self.shape = (dim, 1, 1) if use_nchw else (dim,)
  122. self.scale = nn.Parameter(init_value * torch.ones(dim, **dd), requires_grad=trainable)
  123. def forward(self, x):
  124. return x * self.scale.view(self.shape)
  125. class SquaredReLU(nn.Module):
  126. """
  127. Squared ReLU: https://arxiv.org/abs/2109.08668
  128. """
  129. def __init__(self, inplace: bool = False):
  130. super().__init__()
  131. self.relu = nn.ReLU(inplace=inplace)
  132. def forward(self, x):
  133. return torch.square(self.relu(x))
  134. class StarReLU(nn.Module):
  135. """
  136. StarReLU: s * relu(x) ** 2 + b
  137. """
  138. def __init__(
  139. self,
  140. scale_value: float = 1.0,
  141. bias_value: float = 0.0,
  142. scale_learnable: bool = True,
  143. bias_learnable: bool = True,
  144. mode: Optional[str] = None,
  145. inplace: bool = False,
  146. device=None,
  147. dtype=None,
  148. ):
  149. dd = {'device': device, 'dtype': dtype}
  150. super().__init__()
  151. self.inplace = inplace
  152. self.relu = nn.ReLU(inplace=inplace)
  153. self.scale = nn.Parameter(scale_value * torch.ones(1, **dd), requires_grad=scale_learnable)
  154. self.bias = nn.Parameter(bias_value * torch.ones(1, **dd), requires_grad=bias_learnable)
  155. def forward(self, x):
  156. return self.scale * self.relu(x) ** 2 + self.bias
  157. class Attention(nn.Module):
  158. """
  159. Vanilla self-attention from Transformer: https://arxiv.org/abs/1706.03762.
  160. Modified from timm.
  161. """
  162. fused_attn: Final[bool]
  163. def __init__(
  164. self,
  165. dim: int,
  166. head_dim: int = 32,
  167. num_heads: Optional[int] = None,
  168. qkv_bias: bool = False,
  169. attn_drop: float = 0.,
  170. proj_drop: float = 0.,
  171. proj_bias: bool = False,
  172. device=None,
  173. dtype=None,
  174. **kwargs
  175. ):
  176. dd = {'device': device, 'dtype': dtype}
  177. super().__init__()
  178. self.head_dim = head_dim
  179. self.scale = head_dim ** -0.5
  180. self.fused_attn = use_fused_attn()
  181. self.num_heads = num_heads if num_heads else dim // head_dim
  182. if self.num_heads == 0:
  183. self.num_heads = 1
  184. self.attention_dim = self.num_heads * self.head_dim
  185. self.qkv = nn.Linear(dim, self.attention_dim * 3, bias=qkv_bias, **dd)
  186. self.attn_drop = nn.Dropout(attn_drop)
  187. self.proj = nn.Linear(self.attention_dim, dim, bias=proj_bias, **dd)
  188. self.proj_drop = nn.Dropout(proj_drop)
  189. def forward(self, x):
  190. B, N, C = x.shape
  191. qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
  192. q, k, v = qkv.unbind(0)
  193. if self.fused_attn:
  194. x = F.scaled_dot_product_attention(
  195. q, k, v,
  196. dropout_p=self.attn_drop.p if self.training else 0.,
  197. )
  198. else:
  199. attn = (q @ k.transpose(-2, -1)) * self.scale
  200. attn = attn.softmax(dim=-1)
  201. attn = self.attn_drop(attn)
  202. x = attn @ v
  203. x = x.transpose(1, 2).reshape(B, N, C)
  204. x = self.proj(x)
  205. x = self.proj_drop(x)
  206. return x
  207. # custom norm modules that disable the bias term, since the original models defs
  208. # used a custom norm with a weight term but no bias term.
  209. class GroupNorm1NoBias(GroupNorm1):
  210. def __init__(self, num_channels: int, **kwargs):
  211. super().__init__(num_channels, **kwargs)
  212. self.eps = kwargs.get('eps', 1e-6)
  213. self.bias = None
  214. class LayerNorm2dNoBias(LayerNorm2d):
  215. def __init__(self, num_channels: int, **kwargs):
  216. super().__init__(num_channels, **kwargs)
  217. self.eps = kwargs.get('eps', 1e-6)
  218. self.bias = None
  219. class LayerNormNoBias(nn.LayerNorm):
  220. def __init__(self, num_channels: int, **kwargs):
  221. super().__init__(num_channels, **kwargs)
  222. self.eps = kwargs.get('eps', 1e-6)
  223. self.bias = None
  224. class SepConv(nn.Module):
  225. r"""
  226. Inverted separable convolution from MobileNetV2: https://arxiv.org/abs/1801.04381.
  227. """
  228. def __init__(
  229. self,
  230. dim: int,
  231. expansion_ratio: float = 2,
  232. act1_layer: Type[nn.Module] = StarReLU,
  233. act2_layer: Type[nn.Module] = nn.Identity,
  234. bias: bool = False,
  235. kernel_size: int = 7,
  236. padding: int = 3,
  237. device=None,
  238. dtype=None,
  239. **kwargs
  240. ):
  241. dd = {'device': device, 'dtype': dtype}
  242. super().__init__()
  243. mid_channels = int(expansion_ratio * dim)
  244. self.pwconv1 = nn.Conv2d(dim, mid_channels, kernel_size=1, bias=bias, **dd)
  245. self.act1 = act1_layer(**dd) if issubclass(act1_layer, StarReLU) else act1_layer()
  246. self.dwconv = nn.Conv2d(
  247. mid_channels,
  248. mid_channels,
  249. kernel_size=kernel_size,
  250. padding=padding,
  251. groups=mid_channels,
  252. bias=bias,
  253. **dd,
  254. ) # depthwise conv
  255. self.act2 = act2_layer(**dd) if issubclass(act2_layer, StarReLU) else act2_layer()
  256. self.pwconv2 = nn.Conv2d(mid_channels, dim, kernel_size=1, bias=bias, **dd)
  257. def forward(self, x):
  258. x = self.pwconv1(x)
  259. x = self.act1(x)
  260. x = self.dwconv(x)
  261. x = self.act2(x)
  262. x = self.pwconv2(x)
  263. return x
  264. class Pooling(nn.Module):
  265. """
  266. Implementation of pooling for PoolFormer: https://arxiv.org/abs/2111.11418
  267. """
  268. def __init__(self, pool_size: int = 3, **kwargs):
  269. super().__init__()
  270. self.pool = nn.AvgPool2d(pool_size, stride=1, padding=pool_size // 2, count_include_pad=False)
  271. def forward(self, x):
  272. y = self.pool(x)
  273. return y - x
  274. class MlpHead(nn.Module):
  275. """ MLP classification head
  276. """
  277. def __init__(
  278. self,
  279. dim: int,
  280. num_classes: int = 1000,
  281. mlp_ratio: float = 4,
  282. act_layer: Type[nn.Module] = SquaredReLU,
  283. norm_layer: Type[nn.Module] = LayerNorm,
  284. drop_rate: float = 0.,
  285. bias: bool = True,
  286. device=None,
  287. dtype=None,
  288. ):
  289. dd = {'device': device, 'dtype': dtype}
  290. super().__init__()
  291. hidden_features = int(mlp_ratio * dim)
  292. self.fc1 = nn.Linear(dim, hidden_features, bias=bias, **dd)
  293. self.act = act_layer()
  294. self.norm = norm_layer(hidden_features, **dd)
  295. self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias, **dd)
  296. self.head_drop = nn.Dropout(drop_rate)
  297. def forward(self, x):
  298. x = self.fc1(x)
  299. x = self.act(x)
  300. x = self.norm(x)
  301. x = self.head_drop(x)
  302. x = self.fc2(x)
  303. return x
  304. class MetaFormerBlock(nn.Module):
  305. """
  306. Implementation of one MetaFormer block.
  307. """
  308. def __init__(
  309. self,
  310. dim: int,
  311. token_mixer: Type[nn.Module] = Pooling,
  312. mlp_act: Type[nn.Module] = StarReLU,
  313. mlp_bias: bool = False,
  314. norm_layer: Type[nn.Module] = LayerNorm2d,
  315. proj_drop: float = 0.,
  316. drop_path: float = 0.,
  317. use_nchw: bool = True,
  318. layer_scale_init_value: Optional[float] = None,
  319. res_scale_init_value: Optional[float] = None,
  320. device=None,
  321. dtype=None,
  322. **kwargs
  323. ):
  324. dd = {'device': device, 'dtype': dtype}
  325. super().__init__()
  326. ls_layer = partial(Scale, dim=dim, init_value=layer_scale_init_value, use_nchw=use_nchw, **dd)
  327. rs_layer = partial(Scale, dim=dim, init_value=res_scale_init_value, use_nchw=use_nchw, **dd)
  328. self.norm1 = norm_layer(dim, **dd)
  329. self.token_mixer = token_mixer(dim=dim, proj_drop=proj_drop, **dd, **kwargs)
  330. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  331. self.layer_scale1 = ls_layer() if layer_scale_init_value is not None else nn.Identity()
  332. self.res_scale1 = rs_layer() if res_scale_init_value is not None else nn.Identity()
  333. self.norm2 = norm_layer(dim, **dd)
  334. self.mlp = Mlp(
  335. dim,
  336. int(4 * dim),
  337. act_layer=mlp_act,
  338. bias=mlp_bias,
  339. drop=proj_drop,
  340. use_conv=use_nchw,
  341. **dd
  342. )
  343. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  344. self.layer_scale2 = ls_layer() if layer_scale_init_value is not None else nn.Identity()
  345. self.res_scale2 = rs_layer() if res_scale_init_value is not None else nn.Identity()
  346. def forward(self, x):
  347. x = self.res_scale1(x) + \
  348. self.layer_scale1(
  349. self.drop_path1(
  350. self.token_mixer(self.norm1(x))
  351. )
  352. )
  353. x = self.res_scale2(x) + \
  354. self.layer_scale2(
  355. self.drop_path2(
  356. self.mlp(self.norm2(x))
  357. )
  358. )
  359. return x
  360. class MetaFormerStage(nn.Module):
  361. def __init__(
  362. self,
  363. in_chs: int,
  364. out_chs: int,
  365. depth: int = 2,
  366. token_mixer: Type[nn.Module] = nn.Identity,
  367. mlp_act: Type[nn.Module] = StarReLU,
  368. mlp_bias: bool = False,
  369. downsample_norm: Optional[Type[nn.Module]] = LayerNorm2d,
  370. norm_layer: Type[nn.Module] = LayerNorm2d,
  371. proj_drop: float = 0.,
  372. dp_rates: List[float] = [0.] * 2,
  373. layer_scale_init_value: Optional[float] = None,
  374. res_scale_init_value: Optional[float] = None,
  375. device=None,
  376. dtype=None,
  377. **kwargs,
  378. ):
  379. dd = {'device': device, 'dtype': dtype}
  380. super().__init__()
  381. self.grad_checkpointing = False
  382. self.use_nchw = not issubclass(token_mixer, Attention)
  383. # don't downsample if in_chs and out_chs are the same
  384. self.downsample = nn.Identity() if in_chs == out_chs else Downsampling(
  385. in_chs,
  386. out_chs,
  387. kernel_size=3,
  388. stride=2,
  389. padding=1,
  390. norm_layer=downsample_norm,
  391. **dd,
  392. )
  393. self.blocks = nn.Sequential(*[MetaFormerBlock(
  394. dim=out_chs,
  395. token_mixer=token_mixer,
  396. mlp_act=mlp_act,
  397. mlp_bias=mlp_bias,
  398. norm_layer=norm_layer,
  399. proj_drop=proj_drop,
  400. drop_path=dp_rates[i],
  401. layer_scale_init_value=layer_scale_init_value,
  402. res_scale_init_value=res_scale_init_value,
  403. use_nchw=self.use_nchw,
  404. **dd,
  405. **kwargs,
  406. ) for i in range(depth)])
  407. @torch.jit.ignore
  408. def set_grad_checkpointing(self, enable=True):
  409. self.grad_checkpointing = enable
  410. def forward(self, x: Tensor):
  411. x = self.downsample(x)
  412. B, C, H, W = x.shape
  413. if not self.use_nchw:
  414. x = x.reshape(B, C, -1).transpose(1, 2)
  415. if self.grad_checkpointing and not torch.jit.is_scripting():
  416. x = checkpoint_seq(self.blocks, x)
  417. else:
  418. x = self.blocks(x)
  419. if not self.use_nchw:
  420. x = x.transpose(1, 2).reshape(B, C, H, W)
  421. return x
  422. class MetaFormer(nn.Module):
  423. r""" MetaFormer
  424. A PyTorch impl of : `MetaFormer Baselines for Vision` -
  425. https://arxiv.org/abs/2210.13452
  426. Args:
  427. in_chans (int): Number of input image channels.
  428. num_classes (int): Number of classes for classification head.
  429. global_pool: Pooling for classifier head.
  430. depths (list or tuple): Number of blocks at each stage.
  431. dims (list or tuple): Feature dimension at each stage.
  432. token_mixers (list, tuple or token_fcn): Token mixer for each stage.
  433. mlp_act: Activation layer for MLP.
  434. mlp_bias (boolean): Enable or disable mlp bias term.
  435. drop_path_rate (float): Stochastic depth rate.
  436. drop_rate (float): Dropout rate.
  437. layer_scale_init_values (list, tuple, float or None): Init value for Layer Scale.
  438. None means not use the layer scale. Form: https://arxiv.org/abs/2103.17239.
  439. res_scale_init_values (list, tuple, float or None): Init value for res Scale on residual connections.
  440. None means not use the res scale. From: https://arxiv.org/abs/2110.09456.
  441. downsample_norm (nn.Module): Norm layer used in stem and downsampling layers.
  442. norm_layers (list, tuple or norm_fcn): Norm layers for each stage.
  443. output_norm: Norm layer before classifier head.
  444. use_mlp_head: Use MLP classification head.
  445. """
  446. def __init__(
  447. self,
  448. in_chans: int = 3,
  449. num_classes: int = 1000,
  450. global_pool: str = 'avg',
  451. depths: Tuple[int, ...] = (2, 2, 6, 2),
  452. dims: Tuple[int, ...] = (64, 128, 320, 512),
  453. token_mixers: Union[Type[nn.Module], List[Type[nn.Module]]] = Pooling,
  454. mlp_act: Type[nn.Module] = StarReLU,
  455. mlp_bias: bool = False,
  456. drop_path_rate: float = 0.,
  457. proj_drop_rate: float = 0.,
  458. drop_rate: float = 0.0,
  459. layer_scale_init_values: Optional[Union[float, List[float]]] = None,
  460. res_scale_init_values: Union[Tuple[Optional[float], ...], List[Optional[float]]] = (None, None, 1.0, 1.0),
  461. downsample_norm: Optional[Type[nn.Module]] = LayerNorm2dNoBias,
  462. norm_layers: Union[Type[nn.Module], List[Type[nn.Module]]] = LayerNorm2dNoBias,
  463. output_norm: Type[nn.Module] = LayerNorm2d,
  464. use_mlp_head: bool = True,
  465. device=None,
  466. dtype=None,
  467. **kwargs,
  468. ):
  469. super().__init__()
  470. dd = {'device': device, 'dtype': dtype}
  471. # Bind dd kwargs to activation layers that need them
  472. if mlp_act in (StarReLU,):
  473. mlp_act = partial(mlp_act, **dd)
  474. self.num_classes = num_classes
  475. self.num_features = dims[-1]
  476. self.drop_rate = drop_rate
  477. self.use_mlp_head = use_mlp_head
  478. self.num_stages = len(depths)
  479. # convert everything to lists if they aren't indexable
  480. if not isinstance(depths, (list, tuple)):
  481. depths = [depths] # it means the model has only one stage
  482. if not isinstance(dims, (list, tuple)):
  483. dims = [dims]
  484. if not isinstance(token_mixers, (list, tuple)):
  485. token_mixers = [token_mixers] * self.num_stages
  486. if not isinstance(norm_layers, (list, tuple)):
  487. norm_layers = [norm_layers] * self.num_stages
  488. if not isinstance(layer_scale_init_values, (list, tuple)):
  489. layer_scale_init_values = [layer_scale_init_values] * self.num_stages
  490. if not isinstance(res_scale_init_values, (list, tuple)):
  491. res_scale_init_values = [res_scale_init_values] * self.num_stages
  492. self.grad_checkpointing = False
  493. self.feature_info = []
  494. self.stem = Stem(
  495. in_chans,
  496. dims[0],
  497. norm_layer=downsample_norm,
  498. **dd,
  499. )
  500. stages = []
  501. prev_dim = dims[0]
  502. dp_rates = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)
  503. for i in range(self.num_stages):
  504. stages += [MetaFormerStage(
  505. prev_dim,
  506. dims[i],
  507. depth=depths[i],
  508. token_mixer=token_mixers[i],
  509. mlp_act=mlp_act,
  510. mlp_bias=mlp_bias,
  511. proj_drop=proj_drop_rate,
  512. dp_rates=dp_rates[i],
  513. layer_scale_init_value=layer_scale_init_values[i],
  514. res_scale_init_value=res_scale_init_values[i],
  515. downsample_norm=downsample_norm,
  516. norm_layer=norm_layers[i],
  517. **dd,
  518. **kwargs,
  519. )]
  520. prev_dim = dims[i]
  521. self.feature_info += [dict(num_chs=dims[i], reduction=2**(i+2), module=f'stages.{i}')]
  522. self.stages = nn.Sequential(*stages)
  523. # if using MlpHead, dropout is handled by MlpHead
  524. if num_classes > 0:
  525. if self.use_mlp_head:
  526. # FIXME not actually returning mlp hidden state right now as pre-logits.
  527. final = MlpHead(self.num_features, num_classes, drop_rate=self.drop_rate, **dd)
  528. self.head_hidden_size = self.num_features
  529. else:
  530. final = nn.Linear(self.num_features, num_classes, **dd)
  531. self.head_hidden_size = self.num_features
  532. else:
  533. final = nn.Identity()
  534. self.head = nn.Sequential(OrderedDict([
  535. ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
  536. ('norm', output_norm(self.num_features, **dd)),
  537. ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
  538. ('drop', nn.Dropout(drop_rate) if self.use_mlp_head else nn.Identity()),
  539. ('fc', final)
  540. ]))
  541. self.apply(self._init_weights)
  542. def _init_weights(self, m):
  543. if isinstance(m, (nn.Conv2d, nn.Linear)):
  544. trunc_normal_(m.weight, std=.02)
  545. if m.bias is not None:
  546. nn.init.constant_(m.bias, 0)
  547. @torch.jit.ignore
  548. def set_grad_checkpointing(self, enable=True):
  549. self.grad_checkpointing = enable
  550. for stage in self.stages:
  551. stage.set_grad_checkpointing(enable=enable)
  552. @torch.jit.ignore
  553. def get_classifier(self) -> nn.Module:
  554. return self.head.fc
  555. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, device=None, dtype=None):
  556. dd = {'device': device, 'dtype': dtype}
  557. self.num_classes = num_classes
  558. if global_pool is not None:
  559. self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
  560. self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
  561. if num_classes > 0:
  562. if self.use_mlp_head:
  563. final = MlpHead(self.num_features, num_classes, drop_rate=self.drop_rate, **dd)
  564. else:
  565. final = nn.Linear(self.num_features, num_classes, **dd)
  566. else:
  567. final = nn.Identity()
  568. self.head.fc = final
  569. def forward_intermediates(
  570. self,
  571. x: torch.Tensor,
  572. indices: Optional[Union[int, List[int]]] = None,
  573. norm: bool = False,
  574. stop_early: bool = False,
  575. output_fmt: str = 'NCHW',
  576. intermediates_only: bool = False,
  577. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  578. """ Forward features that returns intermediates.
  579. Args:
  580. x: Input image tensor
  581. indices: Take last n blocks if int, all if None, select matching indices if sequence
  582. norm: Apply norm layer to compatible intermediates
  583. stop_early: Stop iterating over blocks when last desired intermediate hit
  584. output_fmt: Shape of intermediate feature outputs
  585. intermediates_only: Only return intermediate features
  586. Returns:
  587. """
  588. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  589. intermediates = []
  590. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  591. # forward pass
  592. x = self.stem(x)
  593. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  594. stages = self.stages
  595. else:
  596. stages = self.stages[:max_index + 1]
  597. for feat_idx, stage in enumerate(stages):
  598. if self.grad_checkpointing and not torch.jit.is_scripting():
  599. x = checkpoint(stage, x)
  600. else:
  601. x = stage(x)
  602. if feat_idx in take_indices:
  603. intermediates.append(x)
  604. if intermediates_only:
  605. return intermediates
  606. return x, intermediates
  607. def prune_intermediate_layers(
  608. self,
  609. indices: Union[int, List[int]] = 1,
  610. prune_norm: bool = False,
  611. prune_head: bool = True,
  612. ):
  613. """ Prune layers not required for specified intermediates.
  614. """
  615. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  616. self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
  617. if prune_head:
  618. self.reset_classifier(0, '')
  619. return take_indices
  620. def forward_head(self, x: Tensor, pre_logits: bool = False):
  621. # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
  622. x = self.head.global_pool(x)
  623. x = self.head.norm(x)
  624. x = self.head.flatten(x)
  625. x = self.head.drop(x)
  626. return x if pre_logits else self.head.fc(x)
  627. def forward_features(self, x: Tensor):
  628. x = self.stem(x)
  629. if self.grad_checkpointing and not torch.jit.is_scripting():
  630. x = checkpoint_seq(self.stages, x)
  631. else:
  632. x = self.stages(x)
  633. return x
  634. def forward(self, x: Tensor):
  635. x = self.forward_features(x)
  636. x = self.forward_head(x)
  637. return x
  638. # this works but it's long and breaks backwards compatibility with weights from the poolformer-only impl
  639. def checkpoint_filter_fn(state_dict, model):
  640. if 'stem.conv.weight' in state_dict:
  641. return state_dict
  642. import re
  643. out_dict = {}
  644. is_poolformerv1 = 'network.0.0.mlp.fc1.weight' in state_dict
  645. model_state_dict = model.state_dict()
  646. for k, v in state_dict.items():
  647. if is_poolformerv1:
  648. k = re.sub(r'layer_scale_([0-9]+)', r'layer_scale\1.scale', k)
  649. k = k.replace('network.1', 'downsample_layers.1')
  650. k = k.replace('network.3', 'downsample_layers.2')
  651. k = k.replace('network.5', 'downsample_layers.3')
  652. k = k.replace('network.2', 'network.1')
  653. k = k.replace('network.4', 'network.2')
  654. k = k.replace('network.6', 'network.3')
  655. k = k.replace('network', 'stages')
  656. k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
  657. k = k.replace('downsample.proj', 'downsample.conv')
  658. k = k.replace('patch_embed.proj', 'patch_embed.conv')
  659. k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k)
  660. k = k.replace('stages.0.downsample', 'patch_embed')
  661. k = k.replace('patch_embed', 'stem')
  662. k = k.replace('post_norm', 'norm')
  663. k = k.replace('pre_norm', 'norm')
  664. k = re.sub(r'^head', 'head.fc', k)
  665. k = re.sub(r'^norm', 'head.norm', k)
  666. if v.shape != model_state_dict[k] and v.numel() == model_state_dict[k].numel():
  667. v = v.reshape(model_state_dict[k].shape)
  668. out_dict[k] = v
  669. return out_dict
  670. def _create_metaformer(variant, pretrained=False, **kwargs):
  671. default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (2, 2, 6, 2))))
  672. out_indices = kwargs.pop('out_indices', default_out_indices)
  673. model = build_model_with_cfg(
  674. MetaFormer,
  675. variant,
  676. pretrained,
  677. pretrained_filter_fn=checkpoint_filter_fn,
  678. feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
  679. **kwargs,
  680. )
  681. return model
  682. def _cfg(url='', **kwargs):
  683. return {
  684. 'url': url,
  685. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  686. 'crop_pct': 1.0, 'interpolation': 'bicubic',
  687. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  688. 'classifier': 'head.fc', 'first_conv': 'stem.conv',
  689. 'license': 'apache-2.0',
  690. **kwargs
  691. }
  692. default_cfgs = generate_default_cfgs({
  693. 'poolformer_s12.sail_in1k': _cfg(
  694. hf_hub_id='timm/',
  695. crop_pct=0.9),
  696. 'poolformer_s24.sail_in1k': _cfg(
  697. hf_hub_id='timm/',
  698. crop_pct=0.9),
  699. 'poolformer_s36.sail_in1k': _cfg(
  700. hf_hub_id='timm/',
  701. crop_pct=0.9),
  702. 'poolformer_m36.sail_in1k': _cfg(
  703. hf_hub_id='timm/',
  704. crop_pct=0.95),
  705. 'poolformer_m48.sail_in1k': _cfg(
  706. hf_hub_id='timm/',
  707. crop_pct=0.95),
  708. 'poolformerv2_s12.sail_in1k': _cfg(hf_hub_id='timm/'),
  709. 'poolformerv2_s24.sail_in1k': _cfg(hf_hub_id='timm/'),
  710. 'poolformerv2_s36.sail_in1k': _cfg(hf_hub_id='timm/'),
  711. 'poolformerv2_m36.sail_in1k': _cfg(hf_hub_id='timm/'),
  712. 'poolformerv2_m48.sail_in1k': _cfg(hf_hub_id='timm/'),
  713. 'convformer_s18.sail_in1k': _cfg(
  714. hf_hub_id='timm/',
  715. classifier='head.fc.fc2'),
  716. 'convformer_s18.sail_in1k_384': _cfg(
  717. hf_hub_id='timm/',
  718. classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
  719. 'convformer_s18.sail_in22k_ft_in1k': _cfg(
  720. hf_hub_id='timm/',
  721. classifier='head.fc.fc2'),
  722. 'convformer_s18.sail_in22k_ft_in1k_384': _cfg(
  723. hf_hub_id='timm/',
  724. classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
  725. 'convformer_s18.sail_in22k': _cfg(
  726. hf_hub_id='timm/',
  727. classifier='head.fc.fc2', num_classes=21841),
  728. 'convformer_s36.sail_in1k': _cfg(
  729. hf_hub_id='timm/',
  730. classifier='head.fc.fc2'),
  731. 'convformer_s36.sail_in1k_384': _cfg(
  732. hf_hub_id='timm/',
  733. classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
  734. 'convformer_s36.sail_in22k_ft_in1k': _cfg(
  735. hf_hub_id='timm/',
  736. classifier='head.fc.fc2'),
  737. 'convformer_s36.sail_in22k_ft_in1k_384': _cfg(
  738. hf_hub_id='timm/',
  739. classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
  740. 'convformer_s36.sail_in22k': _cfg(
  741. hf_hub_id='timm/',
  742. classifier='head.fc.fc2', num_classes=21841),
  743. 'convformer_m36.sail_in1k': _cfg(
  744. hf_hub_id='timm/',
  745. classifier='head.fc.fc2'),
  746. 'convformer_m36.sail_in1k_384': _cfg(
  747. hf_hub_id='timm/',
  748. classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
  749. 'convformer_m36.sail_in22k_ft_in1k': _cfg(
  750. hf_hub_id='timm/',
  751. classifier='head.fc.fc2'),
  752. 'convformer_m36.sail_in22k_ft_in1k_384': _cfg(
  753. hf_hub_id='timm/',
  754. classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
  755. 'convformer_m36.sail_in22k': _cfg(
  756. hf_hub_id='timm/',
  757. classifier='head.fc.fc2', num_classes=21841),
  758. 'convformer_b36.sail_in1k': _cfg(
  759. hf_hub_id='timm/',
  760. classifier='head.fc.fc2'),
  761. 'convformer_b36.sail_in1k_384': _cfg(
  762. hf_hub_id='timm/',
  763. classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
  764. 'convformer_b36.sail_in22k_ft_in1k': _cfg(
  765. hf_hub_id='timm/',
  766. classifier='head.fc.fc2'),
  767. 'convformer_b36.sail_in22k_ft_in1k_384': _cfg(
  768. hf_hub_id='timm/',
  769. classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
  770. 'convformer_b36.sail_in22k': _cfg(
  771. hf_hub_id='timm/',
  772. classifier='head.fc.fc2', num_classes=21841),
  773. 'caformer_s18.sail_in1k': _cfg(
  774. hf_hub_id='timm/',
  775. classifier='head.fc.fc2'),
  776. 'caformer_s18.sail_in1k_384': _cfg(
  777. hf_hub_id='timm/',
  778. classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
  779. 'caformer_s18.sail_in22k_ft_in1k': _cfg(
  780. hf_hub_id='timm/',
  781. classifier='head.fc.fc2'),
  782. 'caformer_s18.sail_in22k_ft_in1k_384': _cfg(
  783. hf_hub_id='timm/',
  784. classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
  785. 'caformer_s18.sail_in22k': _cfg(
  786. hf_hub_id='timm/',
  787. classifier='head.fc.fc2', num_classes=21841),
  788. 'caformer_s36.sail_in1k': _cfg(
  789. hf_hub_id='timm/',
  790. classifier='head.fc.fc2'),
  791. 'caformer_s36.sail_in1k_384': _cfg(
  792. hf_hub_id='timm/',
  793. classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
  794. 'caformer_s36.sail_in22k_ft_in1k': _cfg(
  795. hf_hub_id='timm/',
  796. classifier='head.fc.fc2'),
  797. 'caformer_s36.sail_in22k_ft_in1k_384': _cfg(
  798. hf_hub_id='timm/',
  799. classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
  800. 'caformer_s36.sail_in22k': _cfg(
  801. hf_hub_id='timm/',
  802. classifier='head.fc.fc2', num_classes=21841),
  803. 'caformer_m36.sail_in1k': _cfg(
  804. hf_hub_id='timm/',
  805. classifier='head.fc.fc2'),
  806. 'caformer_m36.sail_in1k_384': _cfg(
  807. hf_hub_id='timm/',
  808. classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
  809. 'caformer_m36.sail_in22k_ft_in1k': _cfg(
  810. hf_hub_id='timm/',
  811. classifier='head.fc.fc2'),
  812. 'caformer_m36.sail_in22k_ft_in1k_384': _cfg(
  813. hf_hub_id='timm/',
  814. classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
  815. 'caformer_m36.sail_in22k': _cfg(
  816. hf_hub_id='timm/',
  817. classifier='head.fc.fc2', num_classes=21841),
  818. 'caformer_b36.sail_in1k': _cfg(
  819. hf_hub_id='timm/',
  820. classifier='head.fc.fc2'),
  821. 'caformer_b36.sail_in1k_384': _cfg(
  822. hf_hub_id='timm/',
  823. classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
  824. 'caformer_b36.sail_in22k_ft_in1k': _cfg(
  825. hf_hub_id='timm/',
  826. classifier='head.fc.fc2'),
  827. 'caformer_b36.sail_in22k_ft_in1k_384': _cfg(
  828. hf_hub_id='timm/',
  829. classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
  830. 'caformer_b36.sail_in22k': _cfg(
  831. hf_hub_id='timm/',
  832. classifier='head.fc.fc2', num_classes=21841),
  833. })
  834. @register_model
  835. def poolformer_s12(pretrained=False, **kwargs) -> MetaFormer:
  836. model_kwargs = dict(
  837. depths=[2, 2, 6, 2],
  838. dims=[64, 128, 320, 512],
  839. downsample_norm=None,
  840. mlp_act=nn.GELU,
  841. mlp_bias=True,
  842. norm_layers=GroupNorm1,
  843. layer_scale_init_values=1e-5,
  844. res_scale_init_values=None,
  845. use_mlp_head=False,
  846. **kwargs)
  847. return _create_metaformer('poolformer_s12', pretrained=pretrained, **model_kwargs)
  848. @register_model
  849. def poolformer_s24(pretrained=False, **kwargs) -> MetaFormer:
  850. model_kwargs = dict(
  851. depths=[4, 4, 12, 4],
  852. dims=[64, 128, 320, 512],
  853. downsample_norm=None,
  854. mlp_act=nn.GELU,
  855. mlp_bias=True,
  856. norm_layers=GroupNorm1,
  857. layer_scale_init_values=1e-5,
  858. res_scale_init_values=None,
  859. use_mlp_head=False,
  860. **kwargs)
  861. return _create_metaformer('poolformer_s24', pretrained=pretrained, **model_kwargs)
  862. @register_model
  863. def poolformer_s36(pretrained=False, **kwargs) -> MetaFormer:
  864. model_kwargs = dict(
  865. depths=[6, 6, 18, 6],
  866. dims=[64, 128, 320, 512],
  867. downsample_norm=None,
  868. mlp_act=nn.GELU,
  869. mlp_bias=True,
  870. norm_layers=GroupNorm1,
  871. layer_scale_init_values=1e-6,
  872. res_scale_init_values=None,
  873. use_mlp_head=False,
  874. **kwargs)
  875. return _create_metaformer('poolformer_s36', pretrained=pretrained, **model_kwargs)
  876. @register_model
  877. def poolformer_m36(pretrained=False, **kwargs) -> MetaFormer:
  878. model_kwargs = dict(
  879. depths=[6, 6, 18, 6],
  880. dims=[96, 192, 384, 768],
  881. downsample_norm=None,
  882. mlp_act=nn.GELU,
  883. mlp_bias=True,
  884. norm_layers=GroupNorm1,
  885. layer_scale_init_values=1e-6,
  886. res_scale_init_values=None,
  887. use_mlp_head=False,
  888. **kwargs)
  889. return _create_metaformer('poolformer_m36', pretrained=pretrained, **model_kwargs)
  890. @register_model
  891. def poolformer_m48(pretrained=False, **kwargs) -> MetaFormer:
  892. model_kwargs = dict(
  893. depths=[8, 8, 24, 8],
  894. dims=[96, 192, 384, 768],
  895. downsample_norm=None,
  896. mlp_act=nn.GELU,
  897. mlp_bias=True,
  898. norm_layers=GroupNorm1,
  899. layer_scale_init_values=1e-6,
  900. res_scale_init_values=None,
  901. use_mlp_head=False,
  902. **kwargs)
  903. return _create_metaformer('poolformer_m48', pretrained=pretrained, **model_kwargs)
  904. @register_model
  905. def poolformerv2_s12(pretrained=False, **kwargs) -> MetaFormer:
  906. model_kwargs = dict(
  907. depths=[2, 2, 6, 2],
  908. dims=[64, 128, 320, 512],
  909. norm_layers=GroupNorm1NoBias,
  910. use_mlp_head=False,
  911. **kwargs)
  912. return _create_metaformer('poolformerv2_s12', pretrained=pretrained, **model_kwargs)
  913. @register_model
  914. def poolformerv2_s24(pretrained=False, **kwargs) -> MetaFormer:
  915. model_kwargs = dict(
  916. depths=[4, 4, 12, 4],
  917. dims=[64, 128, 320, 512],
  918. norm_layers=GroupNorm1NoBias,
  919. use_mlp_head=False,
  920. **kwargs)
  921. return _create_metaformer('poolformerv2_s24', pretrained=pretrained, **model_kwargs)
  922. @register_model
  923. def poolformerv2_s36(pretrained=False, **kwargs) -> MetaFormer:
  924. model_kwargs = dict(
  925. depths=[6, 6, 18, 6],
  926. dims=[64, 128, 320, 512],
  927. norm_layers=GroupNorm1NoBias,
  928. use_mlp_head=False,
  929. **kwargs)
  930. return _create_metaformer('poolformerv2_s36', pretrained=pretrained, **model_kwargs)
  931. @register_model
  932. def poolformerv2_m36(pretrained=False, **kwargs) -> MetaFormer:
  933. model_kwargs = dict(
  934. depths=[6, 6, 18, 6],
  935. dims=[96, 192, 384, 768],
  936. norm_layers=GroupNorm1NoBias,
  937. use_mlp_head=False,
  938. **kwargs)
  939. return _create_metaformer('poolformerv2_m36', pretrained=pretrained, **model_kwargs)
  940. @register_model
  941. def poolformerv2_m48(pretrained=False, **kwargs) -> MetaFormer:
  942. model_kwargs = dict(
  943. depths=[8, 8, 24, 8],
  944. dims=[96, 192, 384, 768],
  945. norm_layers=GroupNorm1NoBias,
  946. use_mlp_head=False,
  947. **kwargs)
  948. return _create_metaformer('poolformerv2_m48', pretrained=pretrained, **model_kwargs)
  949. @register_model
  950. def convformer_s18(pretrained=False, **kwargs) -> MetaFormer:
  951. model_kwargs = dict(
  952. depths=[3, 3, 9, 3],
  953. dims=[64, 128, 320, 512],
  954. token_mixers=SepConv,
  955. norm_layers=LayerNorm2dNoBias,
  956. **kwargs)
  957. return _create_metaformer('convformer_s18', pretrained=pretrained, **model_kwargs)
  958. @register_model
  959. def convformer_s36(pretrained=False, **kwargs) -> MetaFormer:
  960. model_kwargs = dict(
  961. depths=[3, 12, 18, 3],
  962. dims=[64, 128, 320, 512],
  963. token_mixers=SepConv,
  964. norm_layers=LayerNorm2dNoBias,
  965. **kwargs)
  966. return _create_metaformer('convformer_s36', pretrained=pretrained, **model_kwargs)
  967. @register_model
  968. def convformer_m36(pretrained=False, **kwargs) -> MetaFormer:
  969. model_kwargs = dict(
  970. depths=[3, 12, 18, 3],
  971. dims=[96, 192, 384, 576],
  972. token_mixers=SepConv,
  973. norm_layers=LayerNorm2dNoBias,
  974. **kwargs)
  975. return _create_metaformer('convformer_m36', pretrained=pretrained, **model_kwargs)
  976. @register_model
  977. def convformer_b36(pretrained=False, **kwargs) -> MetaFormer:
  978. model_kwargs = dict(
  979. depths=[3, 12, 18, 3],
  980. dims=[128, 256, 512, 768],
  981. token_mixers=SepConv,
  982. norm_layers=LayerNorm2dNoBias,
  983. **kwargs)
  984. return _create_metaformer('convformer_b36', pretrained=pretrained, **model_kwargs)
  985. @register_model
  986. def caformer_s18(pretrained=False, **kwargs) -> MetaFormer:
  987. model_kwargs = dict(
  988. depths=[3, 3, 9, 3],
  989. dims=[64, 128, 320, 512],
  990. token_mixers=[SepConv, SepConv, Attention, Attention],
  991. norm_layers=[LayerNorm2dNoBias] * 2 + [LayerNormNoBias] * 2,
  992. **kwargs)
  993. return _create_metaformer('caformer_s18', pretrained=pretrained, **model_kwargs)
  994. @register_model
  995. def caformer_s36(pretrained=False, **kwargs) -> MetaFormer:
  996. model_kwargs = dict(
  997. depths=[3, 12, 18, 3],
  998. dims=[64, 128, 320, 512],
  999. token_mixers=[SepConv, SepConv, Attention, Attention],
  1000. norm_layers=[LayerNorm2dNoBias] * 2 + [LayerNormNoBias] * 2,
  1001. **kwargs)
  1002. return _create_metaformer('caformer_s36', pretrained=pretrained, **model_kwargs)
  1003. @register_model
  1004. def caformer_m36(pretrained=False, **kwargs) -> MetaFormer:
  1005. model_kwargs = dict(
  1006. depths=[3, 12, 18, 3],
  1007. dims=[96, 192, 384, 576],
  1008. token_mixers=[SepConv, SepConv, Attention, Attention],
  1009. norm_layers=[LayerNorm2dNoBias] * 2 + [LayerNormNoBias] * 2,
  1010. **kwargs)
  1011. return _create_metaformer('caformer_m36', pretrained=pretrained, **model_kwargs)
  1012. @register_model
  1013. def caformer_b36(pretrained=False, **kwargs) -> MetaFormer:
  1014. model_kwargs = dict(
  1015. depths=[3, 12, 18, 3],
  1016. dims=[128, 256, 512, 768],
  1017. token_mixers=[SepConv, SepConv, Attention, Attention],
  1018. norm_layers=[LayerNorm2dNoBias] * 2 + [LayerNormNoBias] * 2,
  1019. **kwargs)
  1020. return _create_metaformer('caformer_b36', pretrained=pretrained, **model_kwargs)