starnet.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. """
  2. Implementation of the Proof-of-Concept Network: StarNet.
  3. We make StarNet as simple as possible [to show the key contribution of element-wise multiplication]:
  4. - e.g., NO layer-scale in the network design,
  5. - and NO EMA during training,
  6. - both of which would improve the performance further.
  7. Created by: Xu Ma (Email: ma.xu1@northeastern.edu)
  8. Modified Date: Mar/29/2024
  9. """
  10. from typing import Any, Dict, List, Optional, Set, Tuple, Union, Type
  11. import torch
  12. import torch.nn as nn
  13. import torch.nn.functional as F
  14. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  15. from timm.layers import DropPath, SelectAdaptivePool2d, Linear, LayerType, trunc_normal_, calculate_drop_path_rates
  16. from ._builder import build_model_with_cfg
  17. from ._features import feature_take_indices
  18. from ._manipulate import checkpoint_seq
  19. from ._registry import register_model, generate_default_cfgs
  20. __all__ = ['StarNet']
  21. class ConvBN(nn.Sequential):
  22. def __init__(
  23. self,
  24. in_channels: int,
  25. out_channels: int,
  26. kernel_size: int = 1,
  27. stride: int = 1,
  28. padding: int = 0,
  29. with_bn: bool = True,
  30. device=None,
  31. dtype=None,
  32. **kwargs,
  33. ):
  34. dd = {'device': device, 'dtype': dtype}
  35. super().__init__()
  36. self.add_module('conv', nn.Conv2d(
  37. in_channels, out_channels, kernel_size, stride=stride, padding=padding, **dd, **kwargs))
  38. if with_bn:
  39. self.add_module('bn', nn.BatchNorm2d(out_channels, **dd))
  40. nn.init.constant_(self.bn.weight, 1)
  41. nn.init.constant_(self.bn.bias, 0)
  42. class Block(nn.Module):
  43. def __init__(
  44. self,
  45. dim: int,
  46. mlp_ratio: int = 3,
  47. drop_path: float = 0.,
  48. act_layer: Type[nn.Module] = nn.ReLU6,
  49. device=None,
  50. dtype=None,
  51. ):
  52. dd = {'device': device, 'dtype': dtype}
  53. super().__init__()
  54. self.dwconv = ConvBN(dim, dim, 7, 1, 3, groups=dim, with_bn=True, **dd)
  55. self.f1 = ConvBN(dim, mlp_ratio * dim, 1, with_bn=False, **dd)
  56. self.f2 = ConvBN(dim, mlp_ratio * dim, 1, with_bn=False, **dd)
  57. self.g = ConvBN(mlp_ratio * dim, dim, 1, with_bn=True, **dd)
  58. self.dwconv2 = ConvBN(dim, dim, 7, 1, 3, groups=dim, with_bn=False, **dd)
  59. self.act = act_layer()
  60. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  61. def forward(self, x: torch.Tensor) -> torch.Tensor:
  62. residual = x
  63. x = self.dwconv(x)
  64. x1, x2 = self.f1(x), self.f2(x)
  65. x = self.act(x1) * x2
  66. x = self.dwconv2(self.g(x))
  67. x = residual + self.drop_path(x)
  68. return x
  69. class StarNet(nn.Module):
  70. def __init__(
  71. self,
  72. base_dim: int = 32,
  73. depths: List[int] = [3, 3, 12, 5],
  74. mlp_ratio: int = 4,
  75. drop_rate: float = 0.,
  76. drop_path_rate: float = 0.,
  77. act_layer: Type[nn.Module] = nn.ReLU6,
  78. num_classes: int = 1000,
  79. in_chans: int = 3,
  80. global_pool: str = 'avg',
  81. output_stride: int = 32,
  82. device=None,
  83. dtype=None,
  84. **kwargs,
  85. ):
  86. dd = {'device': device, 'dtype': dtype}
  87. super().__init__()
  88. assert output_stride == 32
  89. self.num_classes = num_classes
  90. self.drop_rate = drop_rate
  91. self.grad_checkpointing = False
  92. self.feature_info = []
  93. stem_chs = 32
  94. # stem layer
  95. self.stem = nn.Sequential(
  96. ConvBN(in_chans, stem_chs, kernel_size=3, stride=2, padding=1, **dd),
  97. act_layer(),
  98. )
  99. prev_chs = stem_chs
  100. # build stages
  101. dpr = calculate_drop_path_rates(drop_path_rate, sum(depths)) # stochastic depth
  102. stages = []
  103. cur = 0
  104. for i_layer in range(len(depths)):
  105. embed_dim = base_dim * 2 ** i_layer
  106. down_sampler = ConvBN(prev_chs, embed_dim, 3, stride=2, padding=1, **dd)
  107. blocks = [Block(embed_dim, mlp_ratio, dpr[cur + i], act_layer, **dd) for i in range(depths[i_layer])]
  108. cur += depths[i_layer]
  109. prev_chs = embed_dim
  110. stages.append(nn.Sequential(down_sampler, *blocks))
  111. self.feature_info.append(dict(
  112. num_chs=prev_chs, reduction=2**(i_layer+2), module=f'stages.{i_layer}'))
  113. self.stages = nn.Sequential(*stages)
  114. # head
  115. self.num_features = self.head_hidden_size = prev_chs
  116. self.norm = nn.BatchNorm2d(self.num_features, **dd)
  117. self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
  118. self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
  119. self.head = Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity()
  120. self.apply(self._init_weights)
  121. def _init_weights(self, m):
  122. if isinstance(m, (nn.Linear, nn.Conv2d)):
  123. trunc_normal_(m.weight, std=.02)
  124. if isinstance(m, nn.Linear) and m.bias is not None:
  125. nn.init.constant_(m.bias, 0)
  126. elif isinstance(m, nn.BatchNorm2d):
  127. nn.init.constant_(m.bias, 0)
  128. nn.init.constant_(m.weight, 1.0)
  129. @torch.jit.ignore
  130. def no_weight_decay(self) -> Set:
  131. return set()
  132. @torch.jit.ignore
  133. def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
  134. matcher = dict(
  135. stem=r'^stem\.\d+',
  136. blocks=[
  137. (r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.(\d+)', None),
  138. (r'norm', (99999,))
  139. ]
  140. )
  141. return matcher
  142. @torch.jit.ignore
  143. def set_grad_checkpointing(self, enable: bool = True):
  144. self.grad_checkpointing = enable
  145. @torch.jit.ignore
  146. def get_classifier(self) -> nn.Module:
  147. return self.head
  148. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  149. self.num_classes = num_classes
  150. if global_pool is not None:
  151. # NOTE: cannot meaningfully change pooling of efficient head after creation
  152. self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
  153. self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
  154. self.head = Linear(
  155. self.head_hidden_size, num_classes,
  156. device=self.head.weight.device if isinstance(self.head, nn.Linear) else None,
  157. dtype=self.head.weight.dtype if isinstance(self.head, nn.Linear) else None,
  158. ) if num_classes > 0 else nn.Identity()
  159. def forward_intermediates(
  160. self,
  161. x: torch.Tensor,
  162. indices: Optional[Union[int, List[int]]] = None,
  163. norm: bool = False,
  164. stop_early: bool = False,
  165. output_fmt: str = 'NCHW',
  166. intermediates_only: bool = False,
  167. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  168. """ Forward features that returns intermediates.
  169. Args:
  170. x: Input image tensor
  171. indices: Take last n blocks if int, all if None, select matching indices if sequence
  172. norm: Apply norm layer to compatible intermediates
  173. stop_early: Stop iterating over blocks when last desired intermediate hit
  174. output_fmt: Shape of intermediate feature outputs
  175. intermediates_only: Only return intermediate features
  176. Returns:
  177. """
  178. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  179. intermediates = []
  180. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  181. last_idx = len(self.stages) - 1
  182. # forward pass
  183. x = self.stem(x)
  184. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  185. stages = self.stages
  186. else:
  187. stages = self.stages[:max_index + 1]
  188. for feat_idx, stage in enumerate(stages):
  189. if self.grad_checkpointing and not torch.jit.is_scripting():
  190. x = checkpoint_seq(stage, x)
  191. else:
  192. x = stage(x)
  193. if feat_idx in take_indices:
  194. if norm and feat_idx == last_idx:
  195. x_inter = self.norm(x) # applying final norm last intermediate
  196. else:
  197. x_inter = x
  198. intermediates.append(x_inter)
  199. if intermediates_only:
  200. return intermediates
  201. if feat_idx == last_idx:
  202. x = self.norm(x)
  203. return x, intermediates
  204. def prune_intermediate_layers(
  205. self,
  206. indices: Union[int, List[int]] = 1,
  207. prune_norm: bool = False,
  208. prune_head: bool = True,
  209. ):
  210. """ Prune layers not required for specified intermediates.
  211. """
  212. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  213. self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
  214. if prune_norm:
  215. self.norm = nn.Identity()
  216. if prune_head:
  217. self.reset_classifier(0, '')
  218. return take_indices
  219. def forward_features(self, x: torch.Tensor) -> torch.Tensor:
  220. x = self.stem(x)
  221. if self.grad_checkpointing and not torch.jit.is_scripting():
  222. x = checkpoint_seq(self.stages, x)
  223. else:
  224. x = self.stages(x)
  225. x = self.norm(x)
  226. return x
  227. def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
  228. x = self.global_pool(x)
  229. x = self.flatten(x)
  230. if self.drop_rate > 0.:
  231. x = F.dropout(x, p=self.drop_rate, training=self.training)
  232. return x if pre_logits else self.head(x)
  233. def forward(self, x: torch.Tensor) -> torch.Tensor:
  234. x = self.forward_features(x)
  235. x = self.forward_head(x)
  236. return x
  237. def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module) -> Dict[str, torch.Tensor]:
  238. return state_dict.get('state_dict', state_dict)
  239. def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]:
  240. return {
  241. 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  242. 'crop_pct': 0.875, 'interpolation': 'bicubic',
  243. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  244. 'first_conv': 'stem.0.conv', 'classifier': 'head',
  245. 'paper_ids': 'arXiv:2403.19967',
  246. 'paper_name': 'Rewrite the Stars',
  247. 'origin_url': 'https://github.com/ma-xu/Rewrite-the-Stars', 'license': 'apache-2.0',
  248. **kwargs
  249. }
  250. default_cfgs = generate_default_cfgs({
  251. 'starnet_s1.in1k': _cfg(
  252. hf_hub_id='timm/',
  253. #url='https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s1.pth.tar',
  254. ),
  255. 'starnet_s2.in1k': _cfg(
  256. hf_hub_id='timm/',
  257. #url='https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s2.pth.tar',
  258. ),
  259. 'starnet_s3.in1k': _cfg(
  260. hf_hub_id='timm/',
  261. #url='https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s3.pth.tar',
  262. ),
  263. 'starnet_s4.in1k': _cfg(
  264. hf_hub_id='timm/',
  265. #url='https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s4.pth.tar',
  266. ),
  267. 'starnet_s050.untrained': _cfg(),
  268. 'starnet_s100.untrained': _cfg(),
  269. 'starnet_s150.untrained': _cfg(),
  270. })
  271. def _create_starnet(variant: str, pretrained: bool = False, **kwargs: Any) -> StarNet:
  272. model = build_model_with_cfg(
  273. StarNet, variant, pretrained,
  274. pretrained_filter_fn=checkpoint_filter_fn,
  275. feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
  276. **kwargs,
  277. )
  278. return model
  279. @register_model
  280. def starnet_s1(pretrained: bool = False, **kwargs: Any) -> StarNet:
  281. model_args = dict(base_dim=24, depths=[2, 2, 8, 3])
  282. return _create_starnet('starnet_s1', pretrained=pretrained, **dict(model_args, **kwargs))
  283. @register_model
  284. def starnet_s2(pretrained: bool = False, **kwargs: Any) -> StarNet:
  285. model_args = dict(base_dim=32, depths=[1, 2, 6, 2])
  286. return _create_starnet('starnet_s2', pretrained=pretrained, **dict(model_args, **kwargs))
  287. @register_model
  288. def starnet_s3(pretrained: bool = False, **kwargs: Any) -> StarNet:
  289. model_args = dict(base_dim=32, depths=[2, 2, 8, 4])
  290. return _create_starnet('starnet_s3', pretrained=pretrained, **dict(model_args, **kwargs))
  291. @register_model
  292. def starnet_s4(pretrained: bool = False, **kwargs: Any) -> StarNet:
  293. model_args = dict(base_dim=32, depths=[3, 3, 12, 5])
  294. return _create_starnet('starnet_s4', pretrained=pretrained, **dict(model_args, **kwargs))
  295. # very small networks #
  296. @register_model
  297. def starnet_s050(pretrained: bool = False, **kwargs: Any) -> StarNet:
  298. model_args = dict(base_dim=16, depths=[1, 1, 3, 1], mlp_ratio=3)
  299. return _create_starnet('starnet_s050', pretrained=pretrained, **dict(model_args, **kwargs))
  300. @register_model
  301. def starnet_s100(pretrained: bool = False, **kwargs: Any) -> StarNet:
  302. model_args = dict(base_dim=20, depths=[1, 2, 4, 1], mlp_ratio=4)
  303. return _create_starnet('starnet_s100', pretrained=pretrained, **dict(model_args, **kwargs))
  304. @register_model
  305. def starnet_s150(pretrained: bool = False, **kwargs: Any) -> StarNet:
  306. model_args = dict(base_dim=24, depths=[1, 2, 4, 2], mlp_ratio=3)
  307. return _create_starnet('starnet_s150', pretrained=pretrained, **dict(model_args, **kwargs))