tresnet.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447
  1. """
  2. TResNet: High Performance GPU-Dedicated Architecture
  3. https://arxiv.org/pdf/2003.13630.pdf
  4. Original model: https://github.com/mrT23/TResNet
  5. """
  6. from collections import OrderedDict
  7. from functools import partial
  8. from typing import List, Optional, Tuple, Union, Type
  9. import torch
  10. import torch.nn as nn
  11. from timm.layers import SpaceToDepth, BlurPool2d, ClassifierHead, SEModule, ConvNormAct, DropPath, calculate_drop_path_rates
  12. from ._builder import build_model_with_cfg
  13. from ._features import feature_take_indices
  14. from ._manipulate import checkpoint, checkpoint_seq
  15. from ._registry import register_model, generate_default_cfgs, register_model_deprecations
  16. __all__ = ['TResNet'] # model_registry will add each entrypoint fn to this
  17. class BasicBlock(nn.Module):
  18. expansion = 1
  19. def __init__(
  20. self,
  21. inplanes: int,
  22. planes: int,
  23. stride: int = 1,
  24. downsample: Optional[nn.Module] = None,
  25. use_se: bool = True,
  26. aa_layer: Optional[Type[nn.Module]] = None,
  27. drop_path_rate: float = 0.,
  28. device=None,
  29. dtype=None,
  30. ) -> None:
  31. dd = {'device': device, 'dtype': dtype}
  32. super().__init__()
  33. self.downsample = downsample
  34. self.stride = stride
  35. act_layer = partial(nn.LeakyReLU, negative_slope=1e-3)
  36. self.conv1 = ConvNormAct(
  37. inplanes,
  38. planes,
  39. kernel_size=3,
  40. stride=stride,
  41. act_layer=act_layer,
  42. aa_layer=aa_layer,
  43. **dd,
  44. )
  45. self.conv2 = ConvNormAct(planes, planes, kernel_size=3, stride=1, apply_act=False, **dd)
  46. self.act = nn.ReLU(inplace=True)
  47. rd_chs = max(planes * self.expansion // 4, 64)
  48. self.se = SEModule(planes * self.expansion, rd_channels=rd_chs, **dd) if use_se else None
  49. self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
  50. def forward(self, x):
  51. if self.downsample is not None:
  52. shortcut = self.downsample(x)
  53. else:
  54. shortcut = x
  55. out = self.conv1(x)
  56. out = self.conv2(out)
  57. if self.se is not None:
  58. out = self.se(out)
  59. out = self.drop_path(out) + shortcut
  60. out = self.act(out)
  61. return out
  62. class Bottleneck(nn.Module):
  63. expansion = 4
  64. def __init__(
  65. self,
  66. inplanes: int,
  67. planes: int,
  68. stride: int = 1,
  69. downsample: Optional[nn.Module] = None,
  70. use_se: bool = True,
  71. act_layer: Optional[Type[nn.Module]] = None,
  72. aa_layer: Optional[Type[nn.Module]] = None,
  73. drop_path_rate: float = 0.,
  74. device=None,
  75. dtype=None,
  76. ) -> None:
  77. dd = {'device': device, 'dtype': dtype}
  78. super().__init__()
  79. self.downsample = downsample
  80. self.stride = stride
  81. act_layer = act_layer or partial(nn.LeakyReLU, negative_slope=1e-3)
  82. self.conv1 = ConvNormAct(inplanes, planes, kernel_size=1, stride=1, act_layer=act_layer, **dd)
  83. self.conv2 = ConvNormAct(
  84. planes,
  85. planes,
  86. kernel_size=3,
  87. stride=stride,
  88. act_layer=act_layer,
  89. aa_layer=aa_layer,
  90. **dd,
  91. )
  92. reduction_chs = max(planes * self.expansion // 8, 64)
  93. self.se = SEModule(planes, rd_channels=reduction_chs, **dd) if use_se else None
  94. self.conv3 = ConvNormAct(planes, planes * self.expansion, kernel_size=1, stride=1, apply_act=False, **dd)
  95. self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
  96. self.act = nn.ReLU(inplace=True)
  97. def forward(self, x):
  98. if self.downsample is not None:
  99. shortcut = self.downsample(x)
  100. else:
  101. shortcut = x
  102. out = self.conv1(x)
  103. out = self.conv2(out)
  104. if self.se is not None:
  105. out = self.se(out)
  106. out = self.conv3(out)
  107. out = self.drop_path(out) + shortcut
  108. out = self.act(out)
  109. return out
  110. class TResNet(nn.Module):
  111. def __init__(
  112. self,
  113. layers: List[int],
  114. in_chans: int = 3,
  115. num_classes: int = 1000,
  116. width_factor: float = 1.0,
  117. v2: bool = False,
  118. global_pool: str = 'fast',
  119. drop_rate: float = 0.,
  120. drop_path_rate: float = 0.,
  121. device=None,
  122. dtype=None,
  123. ) -> None:
  124. super().__init__()
  125. dd = {'device': device, 'dtype': dtype}
  126. self.num_classes = num_classes
  127. self.drop_rate = drop_rate
  128. self.grad_checkpointing = False
  129. aa_layer = BlurPool2d
  130. act_layer = nn.LeakyReLU
  131. # TResnet stages
  132. self.inplanes = int(64 * width_factor)
  133. self.planes = int(64 * width_factor)
  134. if v2:
  135. self.inplanes = self.inplanes // 8 * 8
  136. self.planes = self.planes // 8 * 8
  137. dpr = calculate_drop_path_rates(drop_path_rate, layers, stagewise=True)
  138. conv1 = ConvNormAct(in_chans * 16, self.planes, stride=1, kernel_size=3, act_layer=act_layer, **dd)
  139. layer1 = self._make_layer(
  140. Bottleneck if v2 else BasicBlock,
  141. self.planes, layers[0], stride=1, use_se=True, aa_layer=aa_layer, drop_path_rate=dpr[0], **dd)
  142. layer2 = self._make_layer(
  143. Bottleneck if v2 else BasicBlock,
  144. self.planes * 2, layers[1], stride=2, use_se=True, aa_layer=aa_layer, drop_path_rate=dpr[1], **dd)
  145. layer3 = self._make_layer(
  146. Bottleneck,
  147. self.planes * 4, layers[2], stride=2, use_se=True, aa_layer=aa_layer, drop_path_rate=dpr[2], **dd)
  148. layer4 = self._make_layer(
  149. Bottleneck,
  150. self.planes * 8, layers[3], stride=2, use_se=False, aa_layer=aa_layer, drop_path_rate=dpr[3], **dd)
  151. # body
  152. self.body = nn.Sequential(OrderedDict([
  153. ('s2d', SpaceToDepth()),
  154. ('conv1', conv1),
  155. ('layer1', layer1),
  156. ('layer2', layer2),
  157. ('layer3', layer3),
  158. ('layer4', layer4),
  159. ]))
  160. self.feature_info = [
  161. dict(num_chs=self.planes, reduction=2, module=''), # Not with S2D?
  162. dict(num_chs=self.planes * (Bottleneck.expansion if v2 else 1), reduction=4, module='body.layer1'),
  163. dict(num_chs=self.planes * 2 * (Bottleneck.expansion if v2 else 1), reduction=8, module='body.layer2'),
  164. dict(num_chs=self.planes * 4 * Bottleneck.expansion, reduction=16, module='body.layer3'),
  165. dict(num_chs=self.planes * 8 * Bottleneck.expansion, reduction=32, module='body.layer4'),
  166. ]
  167. # head
  168. self.num_features = self.head_hidden_size = (self.planes * 8) * Bottleneck.expansion
  169. self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate, **dd)
  170. # model initialization
  171. for m in self.modules():
  172. if isinstance(m, nn.Conv2d):
  173. nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
  174. if isinstance(m, nn.Linear):
  175. m.weight.data.normal_(0, 0.01)
  176. # residual connections special initialization
  177. for m in self.modules():
  178. if isinstance(m, BasicBlock):
  179. nn.init.zeros_(m.conv2.bn.weight)
  180. if isinstance(m, Bottleneck):
  181. nn.init.zeros_(m.conv3.bn.weight)
  182. def _make_layer(
  183. self,
  184. block,
  185. planes,
  186. blocks,
  187. stride=1,
  188. use_se=True,
  189. aa_layer=None,
  190. drop_path_rate=0.,
  191. device=None,
  192. dtype=None,
  193. ):
  194. dd = {'device': device, 'dtype': dtype}
  195. downsample = None
  196. if stride != 1 or self.inplanes != planes * block.expansion:
  197. layers = []
  198. if stride == 2:
  199. # avg pooling before 1x1 conv
  200. layers.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True, count_include_pad=False))
  201. layers += [ConvNormAct(
  202. self.inplanes, planes * block.expansion, kernel_size=1, stride=1, apply_act=False, **dd)]
  203. downsample = nn.Sequential(*layers)
  204. layers = []
  205. for i in range(blocks):
  206. layers.append(block(
  207. self.inplanes,
  208. planes,
  209. stride=stride if i == 0 else 1,
  210. downsample=downsample if i == 0 else None,
  211. use_se=use_se,
  212. aa_layer=aa_layer,
  213. drop_path_rate=drop_path_rate[i] if isinstance(drop_path_rate, list) else drop_path_rate,
  214. **dd,
  215. ))
  216. self.inplanes = planes * block.expansion
  217. return nn.Sequential(*layers)
  218. @torch.jit.ignore
  219. def group_matcher(self, coarse=False):
  220. matcher = dict(stem=r'^body\.conv1', blocks=r'^body\.layer(\d+)' if coarse else r'^body\.layer(\d+)\.(\d+)')
  221. return matcher
  222. @torch.jit.ignore
  223. def set_grad_checkpointing(self, enable=True):
  224. self.grad_checkpointing = enable
  225. @torch.jit.ignore
  226. def get_classifier(self) -> nn.Module:
  227. return self.head.fc
  228. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  229. self.num_classes = num_classes
  230. self.head.reset(num_classes, pool_type=global_pool)
  231. def forward_intermediates(
  232. self,
  233. x: torch.Tensor,
  234. indices: Optional[Union[int, List[int]]] = None,
  235. norm: bool = False,
  236. stop_early: bool = False,
  237. output_fmt: str = 'NCHW',
  238. intermediates_only: bool = False,
  239. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  240. """ Forward features that returns intermediates.
  241. Args:
  242. x: Input image tensor
  243. indices: Take last n blocks if int, all if None, select matching indices if sequence
  244. norm: Apply norm layer to compatible intermediates
  245. stop_early: Stop iterating over blocks when last desired intermediate hit
  246. output_fmt: Shape of intermediate feature outputs
  247. intermediates_only: Only return intermediate features
  248. Returns:
  249. """
  250. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  251. intermediates = []
  252. stage_ends = [1, 2, 3, 4, 5]
  253. take_indices, max_index = feature_take_indices(len(stage_ends), indices)
  254. take_indices = [stage_ends[i] for i in take_indices]
  255. max_index = stage_ends[max_index]
  256. # forward pass
  257. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  258. stages = self.body
  259. else:
  260. stages = self.body[:max_index + 1]
  261. for feat_idx, stage in enumerate(stages):
  262. if self.grad_checkpointing and not torch.jit.is_scripting():
  263. x = checkpoint(stage, x)
  264. else:
  265. x = stage(x)
  266. if feat_idx in take_indices:
  267. intermediates.append(x)
  268. if intermediates_only:
  269. return intermediates
  270. return x, intermediates
  271. def prune_intermediate_layers(
  272. self,
  273. indices: Union[int, List[int]] = 1,
  274. prune_norm: bool = False,
  275. prune_head: bool = True,
  276. ):
  277. """ Prune layers not required for specified intermediates.
  278. """
  279. stage_ends = [1, 2, 3, 4, 5]
  280. take_indices, max_index = feature_take_indices(len(stage_ends), indices)
  281. max_index = stage_ends[max_index]
  282. self.body = self.body[:max_index + 1] # truncate blocks w/ stem as idx 0
  283. if prune_head:
  284. self.reset_classifier(0, '')
  285. return take_indices
  286. def forward_features(self, x):
  287. if self.grad_checkpointing and not torch.jit.is_scripting():
  288. x = self.body.s2d(x)
  289. x = self.body.conv1(x)
  290. x = checkpoint_seq([
  291. self.body.layer1,
  292. self.body.layer2,
  293. self.body.layer3,
  294. self.body.layer4],
  295. x, flatten=True)
  296. else:
  297. x = self.body(x)
  298. return x
  299. def forward_head(self, x, pre_logits: bool = False):
  300. return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
  301. def forward(self, x):
  302. x = self.forward_features(x)
  303. x = self.forward_head(x)
  304. return x
  305. def checkpoint_filter_fn(state_dict, model):
  306. if 'body.conv1.conv.weight' in state_dict:
  307. return state_dict
  308. import re
  309. state_dict = state_dict.get('model', state_dict)
  310. state_dict = state_dict.get('state_dict', state_dict)
  311. out_dict = {}
  312. for k, v in state_dict.items():
  313. k = re.sub(r'conv(\d+)\.0.0', lambda x: f'conv{int(x.group(1))}.conv', k)
  314. k = re.sub(r'conv(\d+)\.0.1', lambda x: f'conv{int(x.group(1))}.bn', k)
  315. k = re.sub(r'conv(\d+)\.0', lambda x: f'conv{int(x.group(1))}.conv', k)
  316. k = re.sub(r'conv(\d+)\.1', lambda x: f'conv{int(x.group(1))}.bn', k)
  317. k = re.sub(r'downsample\.(\d+)\.0', lambda x: f'downsample.{int(x.group(1))}.conv', k)
  318. k = re.sub(r'downsample\.(\d+)\.1', lambda x: f'downsample.{int(x.group(1))}.bn', k)
  319. if k.endswith('bn.weight'):
  320. # convert weight from inplace_abn to batchnorm
  321. v = v.abs().add(1e-5)
  322. out_dict[k] = v
  323. return out_dict
  324. def _create_tresnet(variant, pretrained=False, **kwargs):
  325. return build_model_with_cfg(
  326. TResNet,
  327. variant,
  328. pretrained,
  329. pretrained_filter_fn=checkpoint_filter_fn,
  330. feature_cfg=dict(out_indices=(1, 2, 3, 4), flatten_sequential=True),
  331. **kwargs,
  332. )
  333. def _cfg(url='', **kwargs):
  334. return {
  335. 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  336. 'crop_pct': 0.875, 'interpolation': 'bilinear',
  337. 'mean': (0., 0., 0.), 'std': (1., 1., 1.),
  338. 'first_conv': 'body.conv1.conv', 'classifier': 'head.fc',
  339. 'license': 'apache-2.0',
  340. **kwargs
  341. }
  342. default_cfgs = generate_default_cfgs({
  343. 'tresnet_m.miil_in21k_ft_in1k': _cfg(hf_hub_id='timm/'),
  344. 'tresnet_m.miil_in21k': _cfg(hf_hub_id='timm/', num_classes=11221),
  345. 'tresnet_m.miil_in1k': _cfg(hf_hub_id='timm/'),
  346. 'tresnet_l.miil_in1k': _cfg(hf_hub_id='timm/'),
  347. 'tresnet_xl.miil_in1k': _cfg(hf_hub_id='timm/'),
  348. 'tresnet_m.miil_in1k_448': _cfg(
  349. input_size=(3, 448, 448), pool_size=(14, 14),
  350. hf_hub_id='timm/'),
  351. 'tresnet_l.miil_in1k_448': _cfg(
  352. input_size=(3, 448, 448), pool_size=(14, 14),
  353. hf_hub_id='timm/'),
  354. 'tresnet_xl.miil_in1k_448': _cfg(
  355. input_size=(3, 448, 448), pool_size=(14, 14),
  356. hf_hub_id='timm/'),
  357. 'tresnet_v2_l.miil_in21k_ft_in1k': _cfg(hf_hub_id='timm/'),
  358. 'tresnet_v2_l.miil_in21k': _cfg(hf_hub_id='timm/', num_classes=11221),
  359. })
  360. @register_model
  361. def tresnet_m(pretrained=False, **kwargs) -> TResNet:
  362. model_args = dict(layers=[3, 4, 11, 3])
  363. return _create_tresnet('tresnet_m', pretrained=pretrained, **dict(model_args, **kwargs))
  364. @register_model
  365. def tresnet_l(pretrained=False, **kwargs) -> TResNet:
  366. model_args = dict(layers=[4, 5, 18, 3], width_factor=1.2)
  367. return _create_tresnet('tresnet_l', pretrained=pretrained, **dict(model_args, **kwargs))
  368. @register_model
  369. def tresnet_xl(pretrained=False, **kwargs) -> TResNet:
  370. model_args = dict(layers=[4, 5, 24, 3], width_factor=1.3)
  371. return _create_tresnet('tresnet_xl', pretrained=pretrained, **dict(model_args, **kwargs))
  372. @register_model
  373. def tresnet_v2_l(pretrained=False, **kwargs) -> TResNet:
  374. model_args = dict(layers=[3, 4, 23, 3], width_factor=1.0, v2=True)
  375. return _create_tresnet('tresnet_v2_l', pretrained=pretrained, **dict(model_args, **kwargs))
  376. register_model_deprecations(__name__, {
  377. 'tresnet_m_miil_in21k': 'tresnet_m.miil_in21k',
  378. 'tresnet_m_448': 'tresnet_m.miil_in1k_448',
  379. 'tresnet_l_448': 'tresnet_l.miil_in1k_448',
  380. 'tresnet_xl_448': 'tresnet_xl.miil_in1k_448',
  381. })