fasternet.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505
  1. """FasterNet
  2. Run, Don't Walk: Chasing Higher FLOPS for Faster Neural Networks
  3. - paper: https://arxiv.org/abs/2303.03667
  4. - code: https://github.com/JierunChen/FasterNet
  5. @article{chen2023run,
  6. title={Run, Don't Walk: Chasing Higher FLOPS for Faster Neural Networks},
  7. author={Chen, Jierun and Kao, Shiu-hong and He, Hao and Zhuo, Weipeng and Wen, Song and Lee, Chul-Ho and Chan, S-H Gary},
  8. journal={arXiv preprint arXiv:2303.03667},
  9. year={2023}
  10. }
  11. Modifications by / Copyright 2025 Ryan Hou & Ross Wightman, original copyrights below
  12. """
  13. # Copyright (c) Microsoft Corporation.
  14. # Licensed under the MIT License.
  15. from functools import partial
  16. from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
  17. import torch
  18. import torch.nn as nn
  19. import torch.nn.functional as F
  20. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  21. from timm.layers import SelectAdaptivePool2d, Linear, DropPath, trunc_normal_, LayerType, calculate_drop_path_rates
  22. from ._builder import build_model_with_cfg
  23. from ._features import feature_take_indices
  24. from ._manipulate import checkpoint_seq
  25. from ._registry import register_model, generate_default_cfgs
# Public API for `from <module> import *`: only the model class is listed here.
__all__ = ['FasterNet']
  27. class Partial_conv3(nn.Module):
  28. def __init__(self, dim: int, n_div: int, forward: str, device=None, dtype=None):
  29. dd = {'device': device, 'dtype': dtype}
  30. super().__init__()
  31. self.dim_conv3 = dim // n_div
  32. self.dim_untouched = dim - self.dim_conv3
  33. self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False, **dd)
  34. if forward == 'slicing':
  35. self.forward = self.forward_slicing
  36. elif forward == 'split_cat':
  37. self.forward = self.forward_split_cat
  38. else:
  39. raise NotImplementedError
  40. def forward_slicing(self, x: torch.Tensor) -> torch.Tensor:
  41. # only for inference
  42. x = x.clone() # !!! Keep the original input intact for the residual connection later
  43. x[:, :self.dim_conv3, :, :] = self.partial_conv3(x[:, :self.dim_conv3, :, :])
  44. return x
  45. def forward_split_cat(self, x: torch.Tensor) -> torch.Tensor:
  46. # for training/inference
  47. x1, x2 = torch.split(x, [self.dim_conv3, self.dim_untouched], dim=1)
  48. x1 = self.partial_conv3(x1)
  49. x = torch.cat((x1, x2), 1)
  50. return x
  51. class MLPBlock(nn.Module):
  52. def __init__(
  53. self,
  54. dim: int,
  55. n_div: int,
  56. mlp_ratio: float,
  57. drop_path: float,
  58. layer_scale_init_value: float,
  59. act_layer: Type[nn.Module] = partial(nn.ReLU, inplace=True),
  60. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  61. pconv_fw_type: str = 'split_cat',
  62. device=None,
  63. dtype=None,
  64. ):
  65. dd = {'device': device, 'dtype': dtype}
  66. super().__init__()
  67. mlp_hidden_dim = int(dim * mlp_ratio)
  68. self.mlp = nn.Sequential(*[
  69. nn.Conv2d(dim, mlp_hidden_dim, 1, bias=False, **dd),
  70. norm_layer(mlp_hidden_dim, **dd),
  71. act_layer(),
  72. nn.Conv2d(mlp_hidden_dim, dim, 1, bias=False, **dd),
  73. ])
  74. self.spatial_mixing = Partial_conv3(dim, n_div, pconv_fw_type, **dd)
  75. if layer_scale_init_value > 0:
  76. self.layer_scale = nn.Parameter(
  77. layer_scale_init_value * torch.ones((dim), **dd), requires_grad=True)
  78. else:
  79. self.layer_scale = None
  80. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  81. def forward(self, x: torch.Tensor) -> torch.Tensor:
  82. shortcut = x
  83. x = self.spatial_mixing(x)
  84. if self.layer_scale is not None:
  85. x = shortcut + self.drop_path(
  86. self.layer_scale.unsqueeze(-1).unsqueeze(-1) * self.mlp(x))
  87. else:
  88. x = shortcut + self.drop_path(self.mlp(x))
  89. return x
  90. class Block(nn.Module):
  91. def __init__(
  92. self,
  93. dim: int,
  94. depth: int,
  95. n_div: int,
  96. mlp_ratio: float,
  97. drop_path: float,
  98. layer_scale_init_value: float,
  99. act_layer: Type[nn.Module] = partial(nn.ReLU, inplace=True),
  100. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  101. pconv_fw_type: str = 'split_cat',
  102. use_merge: bool = True,
  103. merge_size: Union[int, Tuple[int, int]] = 2,
  104. device=None,
  105. dtype=None,
  106. ):
  107. dd = {'device': device, 'dtype': dtype}
  108. super().__init__()
  109. self.grad_checkpointing = False
  110. self.blocks = nn.Sequential(*[
  111. MLPBlock(
  112. dim=dim,
  113. n_div=n_div,
  114. mlp_ratio=mlp_ratio,
  115. drop_path=drop_path[i],
  116. layer_scale_init_value=layer_scale_init_value,
  117. norm_layer=norm_layer,
  118. act_layer=act_layer,
  119. pconv_fw_type=pconv_fw_type,
  120. **dd,
  121. )
  122. for i in range(depth)
  123. ])
  124. self.downsample = PatchMerging(
  125. dim=dim // 2,
  126. patch_size=merge_size,
  127. norm_layer=norm_layer,
  128. **dd,
  129. ) if use_merge else nn.Identity()
  130. def forward(self, x: torch.Tensor) -> torch.Tensor:
  131. x = self.downsample(x)
  132. if self.grad_checkpointing and not torch.jit.is_scripting():
  133. x = checkpoint_seq(self.blocks, x)
  134. else:
  135. x = self.blocks(x)
  136. return x
  137. class PatchEmbed(nn.Module):
  138. def __init__(
  139. self,
  140. in_chans: int,
  141. embed_dim: int,
  142. patch_size: Union[int, Tuple[int, int]] = 4,
  143. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  144. device=None,
  145. dtype=None,
  146. ):
  147. dd = {'device': device, 'dtype': dtype}
  148. super().__init__()
  149. self.proj = nn.Conv2d(in_chans, embed_dim, patch_size, patch_size, bias=False, **dd)
  150. self.norm = norm_layer(embed_dim, **dd)
  151. def forward(self, x: torch.Tensor) -> torch.Tensor:
  152. return self.norm(self.proj(x))
  153. class PatchMerging(nn.Module):
  154. def __init__(
  155. self,
  156. dim: int,
  157. patch_size: Union[int, Tuple[int, int]] = 2,
  158. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  159. device=None,
  160. dtype=None,
  161. ):
  162. dd = {'device': device, 'dtype': dtype}
  163. super().__init__()
  164. self.reduction = nn.Conv2d(dim, 2 * dim, patch_size, patch_size, bias=False, **dd)
  165. self.norm = norm_layer(2 * dim, **dd)
  166. def forward(self, x: torch.Tensor) -> torch.Tensor:
  167. return self.norm(self.reduction(x))
  168. class FasterNet(nn.Module):
  169. def __init__(
  170. self,
  171. in_chans: int = 3,
  172. num_classes: int = 1000,
  173. global_pool: str = 'avg',
  174. embed_dim: int = 96,
  175. depths: Union[int, Tuple[int, ...]] = (1, 2, 8, 2),
  176. mlp_ratio: float = 2.,
  177. n_div: int = 4,
  178. patch_size: Union[int, Tuple[int, int]] = 4,
  179. merge_size: Union[int, Tuple[int, int]] = 2,
  180. patch_norm: bool = True,
  181. feature_dim: int = 1280,
  182. drop_rate: float = 0.,
  183. drop_path_rate: float = 0.1,
  184. layer_scale_init_value: float = 0.,
  185. act_layer: Type[nn.Module] = partial(nn.ReLU, inplace=True),
  186. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  187. pconv_fw_type: str = 'split_cat',
  188. device=None,
  189. dtype=None,
  190. ):
  191. super().__init__()
  192. dd = {'device': device, 'dtype': dtype}
  193. assert pconv_fw_type in ('split_cat', 'slicing',)
  194. self.num_classes = num_classes
  195. self.drop_rate = drop_rate
  196. if not isinstance(depths, (list, tuple)):
  197. depths = (depths) # it means the model has only one stage
  198. self.num_stages = len(depths)
  199. self.feature_info = []
  200. self.patch_embed = PatchEmbed(
  201. in_chans=in_chans,
  202. embed_dim=embed_dim,
  203. patch_size=patch_size,
  204. norm_layer=norm_layer if patch_norm else nn.Identity,
  205. **dd,
  206. )
  207. # stochastic depth decay rule
  208. dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)
  209. # build layers
  210. stages_list = []
  211. for i in range(self.num_stages):
  212. dim = int(embed_dim * 2 ** i)
  213. stage = Block(
  214. dim=dim,
  215. depth=depths[i],
  216. n_div=n_div,
  217. mlp_ratio=mlp_ratio,
  218. drop_path=dpr[i],
  219. layer_scale_init_value=layer_scale_init_value,
  220. norm_layer=norm_layer,
  221. act_layer=act_layer,
  222. pconv_fw_type=pconv_fw_type,
  223. use_merge=False if i == 0 else True,
  224. merge_size=merge_size,
  225. **dd,
  226. )
  227. stages_list.append(stage)
  228. self.feature_info += [dict(num_chs=dim, reduction=2**(i+2), module=f'stages.{i}')]
  229. self.stages = nn.Sequential(*stages_list)
  230. # building last several layers
  231. self.num_features = prev_chs = int(embed_dim * 2 ** (self.num_stages - 1))
  232. self.head_hidden_size = out_chs = feature_dim # 1280
  233. self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
  234. self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=False, **dd)
  235. self.act = act_layer()
  236. self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
  237. self.classifier = Linear(out_chs, num_classes, bias=True, **dd) if num_classes > 0 else nn.Identity()
  238. self._initialize_weights()
  239. def _initialize_weights(self):
  240. for name, m in self.named_modules():
  241. if isinstance(m, nn.Linear):
  242. trunc_normal_(m.weight, std=.02)
  243. if isinstance(m, nn.Linear) and m.bias is not None:
  244. nn.init.constant_(m.bias, 0)
  245. elif isinstance(m, nn.Conv2d):
  246. trunc_normal_(m.weight, std=.02)
  247. if m.bias is not None:
  248. nn.init.constant_(m.bias, 0)
  249. @torch.jit.ignore
  250. def no_weight_decay(self) -> Set:
  251. return set()
  252. @torch.jit.ignore
  253. def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
  254. matcher = dict(
  255. stem=r'^patch_embed', # stem and embed
  256. blocks=r'^stages\.(\d+)' if coarse else [
  257. (r'^stages\.(\d+).downsample', (0,)),
  258. (r'^stages\.(\d+)\.blocks\.(\d+)', None),
  259. (r'^conv_head', (99999,)),
  260. ]
  261. )
  262. return matcher
  263. @torch.jit.ignore
  264. def set_grad_checkpointing(self, enable=True):
  265. for s in self.stages:
  266. s.grad_checkpointing = enable
  267. @torch.jit.ignore
  268. def get_classifier(self) -> nn.Module:
  269. return self.classifier
  270. def reset_classifier(self, num_classes: int, global_pool: str = 'avg', device=None, dtype=None):
  271. dd = {'device': device, 'dtype': dtype}
  272. self.num_classes = num_classes
  273. # cannot meaningfully change pooling of efficient head after creation
  274. self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
  275. self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
  276. self.classifier = Linear(self.head_hidden_size, num_classes, **dd) if num_classes > 0 else nn.Identity()
  277. def forward_intermediates(
  278. self,
  279. x: torch.Tensor,
  280. indices: Optional[Union[int, List[int]]] = None,
  281. norm: bool = False,
  282. stop_early: bool = False,
  283. output_fmt: str = 'NCHW',
  284. intermediates_only: bool = False,
  285. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  286. """ Forward features that returns intermediates.
  287. Args:
  288. x: Input image tensor
  289. indices: Take last n blocks if int, all if None, select matching indices if sequence
  290. norm: Apply norm layer to compatible intermediates
  291. stop_early: Stop iterating over blocks when last desired intermediate hit
  292. output_fmt: Shape of intermediate feature outputs
  293. intermediates_only: Only return intermediate features
  294. Returns:
  295. """
  296. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  297. intermediates = []
  298. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  299. # forward pass
  300. x = self.patch_embed(x)
  301. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  302. stages = self.stages
  303. else:
  304. stages = self.stages[:max_index + 1]
  305. for feat_idx, stage in enumerate(stages):
  306. x = stage(x)
  307. if feat_idx in take_indices:
  308. intermediates.append(x)
  309. if intermediates_only:
  310. return intermediates
  311. return x, intermediates
  312. def prune_intermediate_layers(
  313. self,
  314. indices: Union[int, List[int]] = 1,
  315. prune_norm: bool = False,
  316. prune_head: bool = True,
  317. ):
  318. """ Prune layers not required for specified intermediates.
  319. """
  320. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  321. self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
  322. if prune_head:
  323. self.reset_classifier(0, '')
  324. return take_indices
  325. def forward_features(self, x: torch.Tensor) -> torch.Tensor:
  326. x = self.patch_embed(x)
  327. x = self.stages(x)
  328. return x
  329. def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
  330. x = self.global_pool(x)
  331. x = self.conv_head(x)
  332. x = self.act(x)
  333. x = self.flatten(x)
  334. if self.drop_rate > 0.:
  335. x = F.dropout(x, p=self.drop_rate, training=self.training)
  336. return x if pre_logits else self.classifier(x)
  337. def forward(self, x: torch.Tensor) -> torch.Tensor:
  338. x = self.forward_features(x)
  339. x = self.forward_head(x)
  340. return x
  341. def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module) -> Dict[str, torch.Tensor]:
  342. # if 'avgpool_pre_head' in state_dict:
  343. # return state_dict
  344. #
  345. # out_dict = {
  346. # 'conv_head.weight': state_dict.pop('avgpool_pre_head.1.weight'),
  347. # 'classifier.weight': state_dict.pop('head.weight'),
  348. # 'classifier.bias': state_dict.pop('head.bias')
  349. # }
  350. #
  351. # stage_mapping = {
  352. # 'stages.1.': 'stages.1.downsample.',
  353. # 'stages.2.': 'stages.1.',
  354. # 'stages.3.': 'stages.2.downsample.',
  355. # 'stages.4.': 'stages.2.',
  356. # 'stages.5.': 'stages.3.downsample.',
  357. # 'stages.6.': 'stages.3.'
  358. # }
  359. #
  360. # for k, v in state_dict.items():
  361. # for old_prefix, new_prefix in stage_mapping.items():
  362. # if k.startswith(old_prefix):
  363. # k = k.replace(old_prefix, new_prefix)
  364. # break
  365. # out_dict[k] = v
  366. return state_dict
  367. def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]:
  368. return {
  369. 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  370. 'crop_pct': 1.0, 'interpolation': 'bicubic', 'test_crop_pct': 0.9,
  371. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  372. 'first_conv': 'patch_embed.proj', 'classifier': 'classifier',
  373. 'paper_ids': 'arXiv:2303.03667',
  374. 'paper_name': "Run, Don't Walk: Chasing Higher FLOPS for Faster Neural Networks",
  375. 'origin_url': 'https://github.com/JierunChen/FasterNet',
  376. 'license': 'apache-2.0',
  377. **kwargs
  378. }
  379. default_cfgs = generate_default_cfgs({
  380. 'fasternet_t0.in1k': _cfg(
  381. hf_hub_id='timm/',
  382. #url='https://github.com/JierunChen/FasterNet/releases/download/v1.0/fasternet_t0-epoch.281-val_acc1.71.9180.pth',
  383. ),
  384. 'fasternet_t1.in1k': _cfg(
  385. hf_hub_id='timm/',
  386. #url='https://github.com/JierunChen/FasterNet/releases/download/v1.0/fasternet_t1-epoch.291-val_acc1.76.2180.pth',
  387. ),
  388. 'fasternet_t2.in1k': _cfg(
  389. hf_hub_id='timm/',
  390. #url='https://github.com/JierunChen/FasterNet/releases/download/v1.0/fasternet_t2-epoch.289-val_acc1.78.8860.pth',
  391. ),
  392. 'fasternet_s.in1k': _cfg(
  393. hf_hub_id='timm/',
  394. #url='https://github.com/JierunChen/FasterNet/releases/download/v1.0/fasternet_s-epoch.299-val_acc1.81.2840.pth',
  395. ),
  396. 'fasternet_m.in1k': _cfg(
  397. hf_hub_id='timm/',
  398. #url='https://github.com/JierunChen/FasterNet/releases/download/v1.0/fasternet_m-epoch.291-val_acc1.82.9620.pth',
  399. ),
  400. 'fasternet_l.in1k': _cfg(
  401. hf_hub_id='timm/',
  402. #url='https://github.com/JierunChen/FasterNet/releases/download/v1.0/fasternet_l-epoch.299-val_acc1.83.5060.pth',
  403. ),
  404. })
  405. def _create_fasternet(variant: str, pretrained: bool = False, **kwargs: Any) -> FasterNet:
  406. model = build_model_with_cfg(
  407. FasterNet, variant, pretrained,
  408. pretrained_filter_fn=checkpoint_filter_fn,
  409. feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
  410. **kwargs,
  411. )
  412. return model
  413. @register_model
  414. def fasternet_t0(pretrained: bool = False, **kwargs: Any) -> FasterNet:
  415. model_args = dict(embed_dim=40, depths=(1, 2, 8, 2), drop_path_rate=0.0, act_layer=nn.GELU)
  416. return _create_fasternet('fasternet_t0', pretrained=pretrained, **dict(model_args, **kwargs))
  417. @register_model
  418. def fasternet_t1(pretrained: bool = False, **kwargs: Any) -> FasterNet:
  419. model_args = dict(embed_dim=64, depths=(1, 2, 8, 2), drop_path_rate=0.02, act_layer=nn.GELU)
  420. return _create_fasternet('fasternet_t1', pretrained=pretrained, **dict(model_args, **kwargs))
  421. @register_model
  422. def fasternet_t2(pretrained: bool = False, **kwargs: Any) -> FasterNet:
  423. model_args = dict(embed_dim=96, depths=(1, 2, 8, 2), drop_path_rate=0.05)
  424. return _create_fasternet('fasternet_t2', pretrained=pretrained, **dict(model_args, **kwargs))
  425. @register_model
  426. def fasternet_s(pretrained: bool = False, **kwargs: Any) -> FasterNet:
  427. model_args = dict(embed_dim=128, depths=(1, 2, 13, 2), drop_path_rate=0.1)
  428. return _create_fasternet('fasternet_s', pretrained=pretrained, **dict(model_args, **kwargs))
  429. @register_model
  430. def fasternet_m(pretrained: bool = False, **kwargs: Any) -> FasterNet:
  431. model_args = dict(embed_dim=144, depths=(3, 4, 18, 3), drop_path_rate=0.2)
  432. return _create_fasternet('fasternet_m', pretrained=pretrained, **dict(model_args, **kwargs))
  433. @register_model
  434. def fasternet_l(pretrained: bool = False, **kwargs: Any) -> FasterNet:
  435. model_args = dict(embed_dim=192, depths=(3, 4, 18, 3), drop_path_rate=0.3)
  436. return _create_fasternet('fasternet_l', pretrained=pretrained, **dict(model_args, **kwargs))