mlp_mixer.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879
  1. """ MLP-Mixer, ResMLP, and gMLP in PyTorch
  2. This impl originally based on MLP-Mixer paper.
  3. Official JAX impl: https://github.com/google-research/vision_transformer/blob/linen/vit_jax/models_mixer.py
  4. Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
  5. @article{tolstikhin2021,
  6. title={MLP-Mixer: An all-MLP Architecture for Vision},
  7. author={Tolstikhin, Ilya and Houlsby, Neil and Kolesnikov, Alexander and Beyer, Lucas and Zhai, Xiaohua and Unterthiner,
  8. Thomas and Yung, Jessica and Keysers, Daniel and Uszkoreit, Jakob and Lucic, Mario and Dosovitskiy, Alexey},
  9. journal={arXiv preprint arXiv:2105.01601},
  10. year={2021}
  11. }
  12. Also supporting ResMLP, and a preliminary (not verified) implementation of gMLP
  13. Code: https://github.com/facebookresearch/deit
  14. Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
  15. @misc{touvron2021resmlp,
  16. title={ResMLP: Feedforward networks for image classification with data-efficient training},
  17. author={Hugo Touvron and Piotr Bojanowski and Mathilde Caron and Matthieu Cord and Alaaeldin El-Nouby and
  18. Edouard Grave and Armand Joulin and Gabriel Synnaeve and Jakob Verbeek and Hervé Jégou},
  19. year={2021},
  20. eprint={2105.03404},
  21. }
  22. Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
  23. @misc{liu2021pay,
  24. title={Pay Attention to MLPs},
  25. author={Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le},
  26. year={2021},
  27. eprint={2105.08050},
  28. }
  29. A thank you to paper authors for releasing code and weights.
  30. Hacked together by / Copyright 2021 Ross Wightman
  31. """
  32. import math
  33. from functools import partial
  34. from typing import Any, Dict, List, Optional, Type, Union, Tuple
  35. import torch
  36. import torch.nn as nn
  37. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  38. from timm.layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple
  39. from ._builder import build_model_with_cfg
  40. from ._features import feature_take_indices
  41. from ._manipulate import named_apply, checkpoint, checkpoint_seq
  42. from ._registry import generate_default_cfgs, register_model, register_model_deprecations
# Classes exported explicitly by this module.
__all__ = ['MixerBlock', 'MlpMixer']  # model_registry will add each entrypoint fn to this
  44. class MixerBlock(nn.Module):
  45. """Residual Block w/ token mixing and channel MLPs.
  46. Based on: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
  47. """
  48. def __init__(
  49. self,
  50. dim: int,
  51. seq_len: int,
  52. mlp_ratio: Union[float, Tuple[float, float]] = (0.5, 4.0),
  53. mlp_layer: Type[nn.Module] = Mlp,
  54. norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6),
  55. act_layer: Type[nn.Module] = nn.GELU,
  56. drop: float = 0.,
  57. drop_path: float = 0.,
  58. device=None,
  59. dtype=None,
  60. ) -> None:
  61. """Initialize MixerBlock.
  62. Args:
  63. dim: Dimension of input features.
  64. seq_len: Sequence length.
  65. mlp_ratio: Expansion ratios for token mixing and channel MLPs.
  66. mlp_layer: MLP layer class.
  67. norm_layer: Normalization layer.
  68. act_layer: Activation layer.
  69. drop: Dropout rate.
  70. drop_path: Drop path rate.
  71. """
  72. dd = {'device': device, 'dtype': dtype}
  73. super().__init__()
  74. tokens_dim, channels_dim = [int(x * dim) for x in to_2tuple(mlp_ratio)]
  75. self.norm1 = norm_layer(dim, **dd)
  76. self.mlp_tokens = mlp_layer(seq_len, tokens_dim, act_layer=act_layer, drop=drop, **dd)
  77. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  78. self.norm2 = norm_layer(dim, **dd)
  79. self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop, **dd)
  80. def forward(self, x: torch.Tensor) -> torch.Tensor:
  81. """Forward pass."""
  82. x = x + self.drop_path(self.mlp_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2))
  83. x = x + self.drop_path(self.mlp_channels(self.norm2(x)))
  84. return x
  85. class Affine(nn.Module):
  86. """Affine transformation layer."""
  87. def __init__(self, dim: int, device=None, dtype=None) -> None:
  88. """Initialize Affine layer.
  89. Args:
  90. dim: Dimension of features.
  91. """
  92. dd = {'device': device, 'dtype': dtype}
  93. super().__init__()
  94. self.alpha = nn.Parameter(torch.ones((1, 1, dim), **dd))
  95. self.beta = nn.Parameter(torch.zeros((1, 1, dim), **dd))
  96. def forward(self, x: torch.Tensor) -> torch.Tensor:
  97. """Apply affine transformation."""
  98. return torch.addcmul(self.beta, self.alpha, x)
  99. class ResBlock(nn.Module):
  100. """Residual MLP block w/ LayerScale and Affine 'norm'.
  101. Based on: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
  102. """
  103. def __init__(
  104. self,
  105. dim: int,
  106. seq_len: int,
  107. mlp_ratio: float = 4,
  108. mlp_layer: Type[nn.Module] = Mlp,
  109. norm_layer: Type[nn.Module] = Affine,
  110. act_layer: Type[nn.Module] = nn.GELU,
  111. init_values: float = 1e-4,
  112. drop: float = 0.,
  113. drop_path: float = 0.,
  114. device=None,
  115. dtype=None,
  116. ) -> None:
  117. """Initialize ResBlock.
  118. Args:
  119. dim: Dimension of input features.
  120. seq_len: Sequence length.
  121. mlp_ratio: Channel MLP expansion ratio.
  122. mlp_layer: MLP layer class.
  123. norm_layer: Normalization layer.
  124. act_layer: Activation layer.
  125. init_values: Initial values for layer scale.
  126. drop: Dropout rate.
  127. drop_path: Drop path rate.
  128. """
  129. dd = {'device': device, 'dtype': dtype}
  130. super().__init__()
  131. channel_dim = int(dim * mlp_ratio)
  132. self.norm1 = norm_layer(dim, **dd)
  133. self.linear_tokens = nn.Linear(seq_len, seq_len, **dd)
  134. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  135. self.norm2 = norm_layer(dim, **dd)
  136. self.mlp_channels = mlp_layer(dim, channel_dim, act_layer=act_layer, drop=drop, **dd)
  137. self.ls1 = nn.Parameter(init_values * torch.ones(dim, **dd))
  138. self.ls2 = nn.Parameter(init_values * torch.ones(dim, **dd))
  139. def forward(self, x: torch.Tensor) -> torch.Tensor:
  140. """Forward pass."""
  141. x = x + self.drop_path(self.ls1 * self.linear_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2))
  142. x = x + self.drop_path(self.ls2 * self.mlp_channels(self.norm2(x)))
  143. return x
  144. class SpatialGatingUnit(nn.Module):
  145. """Spatial Gating Unit.
  146. Based on: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
  147. """
  148. def __init__(
  149. self,
  150. dim: int,
  151. seq_len: int,
  152. norm_layer: Type[nn.Module] = nn.LayerNorm,
  153. device=None,
  154. dtype=None,
  155. ) -> None:
  156. """Initialize Spatial Gating Unit.
  157. Args:
  158. dim: Dimension of input features.
  159. seq_len: Sequence length.
  160. norm_layer: Normalization layer.
  161. """
  162. dd = {'device': device, 'dtype': dtype}
  163. super().__init__()
  164. gate_dim = dim // 2
  165. self.norm = norm_layer(gate_dim, **dd)
  166. self.proj = nn.Linear(seq_len, seq_len, **dd)
  167. def init_weights(self) -> None:
  168. """Initialize weights for projection gate."""
  169. # special init for the projection gate, called as override by base model init
  170. nn.init.normal_(self.proj.weight, std=1e-6)
  171. nn.init.ones_(self.proj.bias)
  172. def forward(self, x: torch.Tensor) -> torch.Tensor:
  173. """Apply spatial gating."""
  174. u, v = x.chunk(2, dim=-1)
  175. v = self.norm(v)
  176. v = self.proj(v.transpose(-1, -2))
  177. return u * v.transpose(-1, -2)
  178. class SpatialGatingBlock(nn.Module):
  179. """Residual Block w/ Spatial Gating.
  180. Based on: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
  181. """
  182. def __init__(
  183. self,
  184. dim: int,
  185. seq_len: int,
  186. mlp_ratio: float = 4,
  187. mlp_layer: Type[nn.Module] = GatedMlp,
  188. norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6),
  189. act_layer: Type[nn.Module] = nn.GELU,
  190. drop: float = 0.,
  191. drop_path: float = 0.,
  192. device=None,
  193. dtype=None,
  194. ) -> None:
  195. """Initialize SpatialGatingBlock.
  196. Args:
  197. dim: Dimension of input features.
  198. seq_len: Sequence length.
  199. mlp_ratio: Channel MLP expansion ratio.
  200. mlp_layer: MLP layer class.
  201. norm_layer: Normalization layer.
  202. act_layer: Activation layer.
  203. drop: Dropout rate.
  204. drop_path: Drop path rate.
  205. """
  206. dd = {'device': device, 'dtype': dtype}
  207. super().__init__()
  208. channel_dim = int(dim * mlp_ratio)
  209. self.norm = norm_layer(dim, **dd)
  210. sgu = partial(SpatialGatingUnit, seq_len=seq_len, **dd)
  211. self.mlp_channels = mlp_layer(
  212. dim,
  213. channel_dim,
  214. act_layer=act_layer,
  215. gate_layer=sgu,
  216. drop=drop,
  217. **dd,
  218. )
  219. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  220. def forward(self, x: torch.Tensor) -> torch.Tensor:
  221. """Forward pass."""
  222. x = x + self.drop_path(self.mlp_channels(self.norm(x)))
  223. return x
  224. class MlpMixer(nn.Module):
  225. """MLP-Mixer model architecture.
  226. Based on: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
  227. """
  228. def __init__(
  229. self,
  230. num_classes: int = 1000,
  231. img_size: int = 224,
  232. in_chans: int = 3,
  233. patch_size: int = 16,
  234. num_blocks: int = 8,
  235. embed_dim: int = 512,
  236. mlp_ratio: Union[float, Tuple[float, float]] = (0.5, 4.0),
  237. block_layer: Type[nn.Module] = MixerBlock,
  238. mlp_layer: Type[nn.Module] = Mlp,
  239. norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6),
  240. act_layer: Type[nn.Module] = nn.GELU,
  241. drop_rate: float = 0.,
  242. proj_drop_rate: float = 0.,
  243. drop_path_rate: float = 0.,
  244. nlhb: bool = False,
  245. stem_norm: bool = False,
  246. global_pool: str = 'avg',
  247. device=None,
  248. dtype=None,
  249. ) -> None:
  250. """Initialize MLP-Mixer.
  251. Args:
  252. num_classes: Number of classes for classification.
  253. img_size: Input image size.
  254. in_chans: Number of input channels.
  255. patch_size: Patch size.
  256. num_blocks: Number of mixer blocks.
  257. embed_dim: Embedding dimension.
  258. mlp_ratio: MLP expansion ratio(s).
  259. block_layer: Block layer class.
  260. mlp_layer: MLP layer class.
  261. norm_layer: Normalization layer.
  262. act_layer: Activation layer.
  263. drop_rate: Head dropout rate.
  264. proj_drop_rate: Projection dropout rate.
  265. drop_path_rate: Drop path rate.
  266. nlhb: Use negative log bias initialization.
  267. stem_norm: Apply normalization to stem.
  268. global_pool: Global pooling type.
  269. """
  270. super().__init__()
  271. dd = {'device': device, 'dtype': dtype}
  272. self.num_classes = num_classes
  273. self.global_pool = global_pool
  274. self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models
  275. self.grad_checkpointing = False
  276. self.stem = PatchEmbed(
  277. img_size=img_size,
  278. patch_size=patch_size,
  279. in_chans=in_chans,
  280. embed_dim=embed_dim,
  281. norm_layer=norm_layer if stem_norm else None,
  282. **dd,
  283. )
  284. reduction = self.stem.feat_ratio() if hasattr(self.stem, 'feat_ratio') else patch_size
  285. # FIXME drop_path (stochastic depth scaling rule or all the same?)
  286. self.blocks = nn.Sequential(*[
  287. block_layer(
  288. embed_dim,
  289. self.stem.num_patches,
  290. mlp_ratio,
  291. mlp_layer=mlp_layer,
  292. norm_layer=norm_layer,
  293. act_layer=act_layer,
  294. drop=proj_drop_rate,
  295. drop_path=drop_path_rate,
  296. **dd,
  297. )
  298. for _ in range(num_blocks)])
  299. self.feature_info = [
  300. dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction) for i in range(num_blocks)]
  301. self.norm = norm_layer(embed_dim, **dd)
  302. self.head_drop = nn.Dropout(drop_rate)
  303. self.head = nn.Linear(embed_dim, self.num_classes, **dd) if num_classes > 0 else nn.Identity()
  304. self.init_weights(nlhb=nlhb)
  305. @torch.jit.ignore
  306. def init_weights(self, nlhb: bool = False) -> None:
  307. """Initialize model weights.
  308. Args:
  309. nlhb: Use negative log bias initialization for head.
  310. """
  311. head_bias = -math.log(self.num_classes) if nlhb else 0.
  312. named_apply(partial(_init_weights, head_bias=head_bias), module=self) # depth-first
  313. @torch.jit.ignore
  314. def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
  315. """Create regex patterns for parameter grouping.
  316. Args:
  317. coarse: Use coarse grouping.
  318. Returns:
  319. Dictionary mapping group names to regex patterns.
  320. """
  321. return dict(
  322. stem=r'^stem', # stem and embed
  323. blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
  324. )
  325. @torch.jit.ignore
  326. def set_grad_checkpointing(self, enable: bool = True) -> None:
  327. """Enable or disable gradient checkpointing.
  328. Args:
  329. enable: Whether to enable gradient checkpointing.
  330. """
  331. self.grad_checkpointing = enable
  332. @torch.jit.ignore
  333. def get_classifier(self) -> nn.Module:
  334. """Get the classifier module."""
  335. return self.head
  336. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
  337. """Reset the classifier head.
  338. Args:
  339. num_classes: Number of classes for new classifier.
  340. global_pool: Global pooling type.
  341. """
  342. self.num_classes = num_classes
  343. if global_pool is not None:
  344. assert global_pool in ('', 'avg')
  345. self.global_pool = global_pool
  346. device, dtype = self.head.weight.device, self.head.weight.dtype if hasattr(self.head, 'weight') else (None, None)
  347. self.head = nn.Linear(self.embed_dim, num_classes, device=device, dtype=dtype) if num_classes > 0 else nn.Identity()
  348. def forward_intermediates(
  349. self,
  350. x: torch.Tensor,
  351. indices: Optional[Union[int, List[int]]] = None,
  352. norm: bool = False,
  353. stop_early: bool = False,
  354. output_fmt: str = 'NCHW',
  355. intermediates_only: bool = False,
  356. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  357. """Forward features that returns intermediates.
  358. Args:
  359. x: Input image tensor.
  360. indices: Take last n blocks if int, all if None, select matching indices if sequence.
  361. norm: Apply norm layer to all intermediates.
  362. stop_early: Stop iterating over blocks when last desired intermediate hit.
  363. output_fmt: Shape of intermediate feature outputs ('NCHW' or 'NLC').
  364. intermediates_only: Only return intermediate features.
  365. Returns:
  366. List of intermediate features or tuple of (final features, intermediates).
  367. """
  368. assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
  369. reshape = output_fmt == 'NCHW'
  370. intermediates = []
  371. take_indices, max_index = feature_take_indices(len(self.blocks), indices)
  372. # forward pass
  373. B, _, height, width = x.shape
  374. x = self.stem(x)
  375. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  376. blocks = self.blocks
  377. else:
  378. blocks = self.blocks[:max_index + 1]
  379. for i, blk in enumerate(blocks):
  380. if self.grad_checkpointing and not torch.jit.is_scripting():
  381. x = checkpoint(blk, x)
  382. else:
  383. x = blk(x)
  384. if i in take_indices:
  385. # normalize intermediates with final norm layer if enabled
  386. intermediates.append(self.norm(x) if norm else x)
  387. # process intermediates
  388. if reshape:
  389. # reshape to BCHW output format
  390. H, W = self.stem.dynamic_feat_size((height, width))
  391. intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
  392. if intermediates_only:
  393. return intermediates
  394. x = self.norm(x)
  395. return x, intermediates
  396. def prune_intermediate_layers(
  397. self,
  398. indices: Union[int, List[int]] = 1,
  399. prune_norm: bool = False,
  400. prune_head: bool = True,
  401. ) -> List[int]:
  402. """Prune layers not required for specified intermediates.
  403. Args:
  404. indices: Indices of intermediate layers to keep.
  405. prune_norm: Whether to prune normalization layer.
  406. prune_head: Whether to prune the classifier head.
  407. Returns:
  408. List of indices that were kept.
  409. """
  410. take_indices, max_index = feature_take_indices(len(self.blocks), indices)
  411. self.blocks = self.blocks[:max_index + 1] # truncate blocks
  412. if prune_norm:
  413. self.norm = nn.Identity()
  414. if prune_head:
  415. self.reset_classifier(0, '')
  416. return take_indices
  417. def forward_features(self, x: torch.Tensor) -> torch.Tensor:
  418. """Forward pass through feature extraction layers."""
  419. x = self.stem(x)
  420. if self.grad_checkpointing and not torch.jit.is_scripting():
  421. x = checkpoint_seq(self.blocks, x)
  422. else:
  423. x = self.blocks(x)
  424. x = self.norm(x)
  425. return x
  426. def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
  427. """Forward pass through classifier head.
  428. Args:
  429. x: Feature tensor.
  430. pre_logits: Return features before final classifier.
  431. Returns:
  432. Output tensor.
  433. """
  434. if self.global_pool == 'avg':
  435. x = x.mean(dim=1)
  436. x = self.head_drop(x)
  437. return x if pre_logits else self.head(x)
  438. def forward(self, x: torch.Tensor) -> torch.Tensor:
  439. """Forward pass."""
  440. x = self.forward_features(x)
  441. x = self.forward_head(x)
  442. return x
  443. def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax: bool = False) -> None:
  444. """Mixer weight initialization (trying to match Flax defaults).
  445. Args:
  446. module: Module to initialize.
  447. name: Module name.
  448. head_bias: Bias value for head layer.
  449. flax: Use Flax-style initialization.
  450. """
  451. if isinstance(module, nn.Linear):
  452. if name.startswith('head'):
  453. nn.init.zeros_(module.weight)
  454. nn.init.constant_(module.bias, head_bias)
  455. else:
  456. if flax:
  457. # Flax defaults
  458. lecun_normal_(module.weight)
  459. if module.bias is not None:
  460. nn.init.zeros_(module.bias)
  461. else:
  462. # like MLP init in vit (my original init)
  463. nn.init.xavier_uniform_(module.weight)
  464. if module.bias is not None:
  465. if 'mlp' in name:
  466. nn.init.normal_(module.bias, std=1e-6)
  467. else:
  468. nn.init.zeros_(module.bias)
  469. elif isinstance(module, nn.Conv2d):
  470. lecun_normal_(module.weight)
  471. if module.bias is not None:
  472. nn.init.zeros_(module.bias)
  473. elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
  474. nn.init.ones_(module.weight)
  475. nn.init.zeros_(module.bias)
  476. elif hasattr(module, 'init_weights'):
  477. # NOTE if a parent module contains init_weights method, it can override the init of the
  478. # child modules as this will be called in depth-first order.
  479. module.init_weights()
  480. def checkpoint_filter_fn(state_dict, model):
  481. """ Remap checkpoints if needed """
  482. if 'patch_embed.proj.weight' in state_dict:
  483. # Remap FB ResMlp models -> timm
  484. out_dict = {}
  485. for k, v in state_dict.items():
  486. k = k.replace('patch_embed.', 'stem.')
  487. k = k.replace('attn.', 'linear_tokens.')
  488. k = k.replace('mlp.', 'mlp_channels.')
  489. k = k.replace('gamma_', 'ls')
  490. if k.endswith('.alpha') or k.endswith('.beta'):
  491. v = v.reshape(1, 1, -1)
  492. out_dict[k] = v
  493. return out_dict
  494. return state_dict
  495. def _create_mixer(variant, pretrained=False, **kwargs) -> MlpMixer:
  496. out_indices = kwargs.pop('out_indices', 3)
  497. model = build_model_with_cfg(
  498. MlpMixer,
  499. variant,
  500. pretrained,
  501. pretrained_filter_fn=checkpoint_filter_fn,
  502. feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
  503. **kwargs,
  504. )
  505. return model
  506. def _cfg(url='', **kwargs) -> Dict[str, Any]:
  507. return {
  508. 'url': url,
  509. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
  510. 'crop_pct': 0.875, 'interpolation': 'bicubic', 'fixed_input_size': True,
  511. 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
  512. 'first_conv': 'stem.proj', 'classifier': 'head',
  513. 'license': 'apache-2.0',
  514. **kwargs
  515. }
# Pretrained configurations, keyed as '<entrypoint name>.<tag>'. Each value starts from
# the `_cfg()` defaults; entries tagged `.untrained` specify no weight URL.
default_cfgs = generate_default_cfgs({
    # MLP-Mixer
    'mixer_s32_224.untrained': _cfg(),
    'mixer_s16_224.untrained': _cfg(),
    'mixer_b32_224.untrained': _cfg(),
    'mixer_b16_224.goog_in21k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224-76587d61.pth',
    ),
    'mixer_b16_224.goog_in21k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224_in21k-617b3de2.pth',
        num_classes=21843
    ),
    'mixer_l32_224.untrained': _cfg(),
    'mixer_l16_224.goog_in21k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224-92f9adc4.pth',
    ),
    'mixer_l16_224.goog_in21k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224_in21k-846aa33c.pth',
        num_classes=21843
    ),

    # Mixer ImageNet-21K-P pretraining
    'mixer_b16_224.miil_in21k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mixer_b16_224_miil_in21k-2a558a71.pth',
        mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
    ),
    'mixer_b16_224.miil_in21k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mixer_b16_224_miil-9229a591.pth',
        mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear',
    ),

    # Glu-Mixer (uses ImageNet default mean/std rather than the 0.5 defaults of _cfg)
    'gmixer_12_224.untrained': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'gmixer_24_224.ra3_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gmixer_24_224_raa-7daf7ae6.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),

    # ResMLP (all entries use ImageNet default mean/std)
    'resmlp_12_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/resmlp_12_no_dist.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'resmlp_24_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/resmlp_24_no_dist.pth',
        #url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resmlp_24_224_raa-a8256759.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'resmlp_36_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/resmlp_36_no_dist.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'resmlp_big_24_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_no_dist.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),

    'resmlp_12_224.fb_distilled_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/resmlp_12_dist.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'resmlp_24_224.fb_distilled_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/resmlp_24_dist.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'resmlp_36_224.fb_distilled_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/resmlp_36_dist.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'resmlp_big_24_224.fb_distilled_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_dist.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),

    'resmlp_big_24_224.fb_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_22k.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),

    'resmlp_12_224.fb_dino': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/resmlp_12_dino.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'resmlp_24_224.fb_dino': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/resmlp_24_dino.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),

    # gMLP
    'gmlp_ti16_224.untrained': _cfg(),
    'gmlp_s16_224.ra3_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gmlp_s16_224_raa-10536d42.pth',
    ),
    'gmlp_b16_224.untrained': _cfg(),
})
  607. @register_model
  608. def mixer_s32_224(pretrained=False, **kwargs) -> MlpMixer:
  609. """ Mixer-S/32 224x224
  610. Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
  611. """
  612. model_args = dict(patch_size=32, num_blocks=8, embed_dim=512, **kwargs)
  613. model = _create_mixer('mixer_s32_224', pretrained=pretrained, **model_args)
  614. return model
  615. @register_model
  616. def mixer_s16_224(pretrained=False, **kwargs) -> MlpMixer:
  617. """ Mixer-S/16 224x224
  618. Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
  619. """
  620. model_args = dict(patch_size=16, num_blocks=8, embed_dim=512, **kwargs)
  621. model = _create_mixer('mixer_s16_224', pretrained=pretrained, **model_args)
  622. return model
  623. @register_model
  624. def mixer_b32_224(pretrained=False, **kwargs) -> MlpMixer:
  625. """ Mixer-B/32 224x224
  626. Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
  627. """
  628. model_args = dict(patch_size=32, num_blocks=12, embed_dim=768, **kwargs)
  629. model = _create_mixer('mixer_b32_224', pretrained=pretrained, **model_args)
  630. return model
  631. @register_model
  632. def mixer_b16_224(pretrained=False, **kwargs) -> MlpMixer:
  633. """ Mixer-B/16 224x224. ImageNet-1k pretrained weights.
  634. Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
  635. """
  636. model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs)
  637. model = _create_mixer('mixer_b16_224', pretrained=pretrained, **model_args)
  638. return model
  639. @register_model
  640. def mixer_l32_224(pretrained=False, **kwargs) -> MlpMixer:
  641. """ Mixer-L/32 224x224.
  642. Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
  643. """
  644. model_args = dict(patch_size=32, num_blocks=24, embed_dim=1024, **kwargs)
  645. model = _create_mixer('mixer_l32_224', pretrained=pretrained, **model_args)
  646. return model
  647. @register_model
  648. def mixer_l16_224(pretrained=False, **kwargs) -> MlpMixer:
  649. """ Mixer-L/16 224x224. ImageNet-1k pretrained weights.
  650. Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
  651. """
  652. model_args = dict(patch_size=16, num_blocks=24, embed_dim=1024, **kwargs)
  653. model = _create_mixer('mixer_l16_224', pretrained=pretrained, **model_args)
  654. return model
  655. @register_model
  656. def gmixer_12_224(pretrained=False, **kwargs) -> MlpMixer:
  657. """ Glu-Mixer-12 224x224
  658. Experiment by Ross Wightman, adding SwiGLU to MLP-Mixer
  659. """
  660. model_args = dict(
  661. patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=(1.0, 4.0),
  662. mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs)
  663. model = _create_mixer('gmixer_12_224', pretrained=pretrained, **model_args)
  664. return model
  665. @register_model
  666. def gmixer_24_224(pretrained=False, **kwargs) -> MlpMixer:
  667. """ Glu-Mixer-24 224x224
  668. Experiment by Ross Wightman, adding SwiGLU to MLP-Mixer
  669. """
  670. model_args = dict(
  671. patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=(1.0, 4.0),
  672. mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs)
  673. model = _create_mixer('gmixer_24_224', pretrained=pretrained, **model_args)
  674. return model
  675. @register_model
  676. def resmlp_12_224(pretrained=False, **kwargs) -> MlpMixer:
  677. """ ResMLP-12
  678. Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
  679. """
  680. model_args = dict(
  681. patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs)
  682. model = _create_mixer('resmlp_12_224', pretrained=pretrained, **model_args)
  683. return model
  684. @register_model
  685. def resmlp_24_224(pretrained=False, **kwargs) -> MlpMixer:
  686. """ ResMLP-24
  687. Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
  688. """
  689. model_args = dict(
  690. patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=4,
  691. block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs)
  692. model = _create_mixer('resmlp_24_224', pretrained=pretrained, **model_args)
  693. return model
  694. @register_model
  695. def resmlp_36_224(pretrained=False, **kwargs) -> MlpMixer:
  696. """ ResMLP-36
  697. Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
  698. """
  699. model_args = dict(
  700. patch_size=16, num_blocks=36, embed_dim=384, mlp_ratio=4,
  701. block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
  702. model = _create_mixer('resmlp_36_224', pretrained=pretrained, **model_args)
  703. return model
  704. @register_model
  705. def resmlp_big_24_224(pretrained=False, **kwargs) -> MlpMixer:
  706. """ ResMLP-B-24
  707. Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
  708. """
  709. model_args = dict(
  710. patch_size=8, num_blocks=24, embed_dim=768, mlp_ratio=4,
  711. block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
  712. model = _create_mixer('resmlp_big_24_224', pretrained=pretrained, **model_args)
  713. return model
  714. @register_model
  715. def gmlp_ti16_224(pretrained=False, **kwargs) -> MlpMixer:
  716. """ gMLP-Tiny
  717. Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
  718. """
  719. model_args = dict(
  720. patch_size=16, num_blocks=30, embed_dim=128, mlp_ratio=6, block_layer=SpatialGatingBlock,
  721. mlp_layer=GatedMlp, **kwargs)
  722. model = _create_mixer('gmlp_ti16_224', pretrained=pretrained, **model_args)
  723. return model
  724. @register_model
  725. def gmlp_s16_224(pretrained=False, **kwargs) -> MlpMixer:
  726. """ gMLP-Small
  727. Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
  728. """
  729. model_args = dict(
  730. patch_size=16, num_blocks=30, embed_dim=256, mlp_ratio=6, block_layer=SpatialGatingBlock,
  731. mlp_layer=GatedMlp, **kwargs)
  732. model = _create_mixer('gmlp_s16_224', pretrained=pretrained, **model_args)
  733. return model
  734. @register_model
  735. def gmlp_b16_224(pretrained=False, **kwargs) -> MlpMixer:
  736. """ gMLP-Base
  737. Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
  738. """
  739. model_args = dict(
  740. patch_size=16, num_blocks=30, embed_dim=512, mlp_ratio=6, block_layer=SpatialGatingBlock,
  741. mlp_layer=GatedMlp, **kwargs)
  742. model = _create_mixer('gmlp_b16_224', pretrained=pretrained, **model_args)
  743. return model
# Map old model names to their current '<entrypoint name>[.<tag>]' replacements.
register_model_deprecations(__name__, {
    # NOTE(review): the '*_in21k' names map to the 'goog_in21k_ft_in1k' tag although a
    # 'goog_in21k' tag also exists in default_cfgs - confirm this is intended.
    'mixer_b16_224_in21k': 'mixer_b16_224.goog_in21k_ft_in1k',
    'mixer_l16_224_in21k': 'mixer_l16_224.goog_in21k_ft_in1k',
    'mixer_b16_224_miil': 'mixer_b16_224.miil_in21k_ft_in1k',
    'mixer_b16_224_miil_in21k': 'mixer_b16_224.miil_in21k',
    'resmlp_12_distilled_224': 'resmlp_12_224.fb_distilled_in1k',
    'resmlp_24_distilled_224': 'resmlp_24_224.fb_distilled_in1k',
    'resmlp_36_distilled_224': 'resmlp_36_224.fb_distilled_in1k',
    'resmlp_big_24_distilled_224': 'resmlp_big_24_224.fb_distilled_in1k',
    'resmlp_big_24_224_in22ft1k': 'resmlp_big_24_224.fb_in22k_ft_in1k',
    # NOTE(review): the two '*_dino' names map to an untagged target, unlike every other
    # entry; presumably '.fb_dino' was meant - verify how untagged targets are resolved.
    'resmlp_12_224_dino': 'resmlp_12_224',
    'resmlp_24_224_dino': 'resmlp_24_224',
})