xception_aligned.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482
  1. """Pytorch impl of Aligned Xception 41, 65, 71
  2. This is a correct, from scratch impl of Aligned Xception (Deeplab) models compatible with TF weights at
  3. https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zoo.md
  4. Hacked together by / Copyright 2020 Ross Wightman
  5. """
  6. from functools import partial
  7. from typing import List, Dict, Type, Optional
  8. import torch
  9. import torch.nn as nn
  10. from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
  11. from timm.layers import ClassifierHead, ConvNormAct, DropPath, PadType, create_conv2d, get_norm_act_layer
  12. from timm.layers.helpers import to_3tuple
  13. from ._builder import build_model_with_cfg
  14. from ._manipulate import checkpoint_seq
  15. from ._registry import register_model, generate_default_cfgs
  16. __all__ = ['XceptionAligned']
  17. class SeparableConv2d(nn.Module):
  18. def __init__(
  19. self,
  20. in_chs: int,
  21. out_chs: int,
  22. kernel_size: int = 3,
  23. stride: int = 1,
  24. dilation: int = 1,
  25. padding: PadType = '',
  26. act_layer: Type[nn.Module] = nn.ReLU,
  27. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  28. device=None,
  29. dtype=None,
  30. ):
  31. dd = {'device': device, 'dtype': dtype}
  32. super().__init__()
  33. self.kernel_size = kernel_size
  34. self.dilation = dilation
  35. # depthwise convolution
  36. self.conv_dw = create_conv2d(
  37. in_chs,
  38. in_chs,
  39. kernel_size,
  40. stride=stride,
  41. padding=padding,
  42. dilation=dilation,
  43. depthwise=True,
  44. **dd,
  45. )
  46. self.bn_dw = norm_layer(in_chs, **dd)
  47. self.act_dw = act_layer(inplace=True) if act_layer is not None else nn.Identity()
  48. # pointwise convolution
  49. self.conv_pw = create_conv2d(in_chs, out_chs, kernel_size=1, **dd)
  50. self.bn_pw = norm_layer(out_chs, **dd)
  51. self.act_pw = act_layer(inplace=True) if act_layer is not None else nn.Identity()
  52. def forward(self, x):
  53. x = self.conv_dw(x)
  54. x = self.bn_dw(x)
  55. x = self.act_dw(x)
  56. x = self.conv_pw(x)
  57. x = self.bn_pw(x)
  58. x = self.act_pw(x)
  59. return x
  60. class PreSeparableConv2d(nn.Module):
  61. def __init__(
  62. self,
  63. in_chs: int,
  64. out_chs: int,
  65. kernel_size: int = 3,
  66. stride: int = 1,
  67. dilation: int = 1,
  68. padding: PadType = '',
  69. act_layer: Type[nn.Module] = nn.ReLU,
  70. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  71. first_act: bool = True,
  72. device=None,
  73. dtype=None,
  74. ):
  75. dd = {'device': device, 'dtype': dtype}
  76. super().__init__()
  77. norm_act_layer = get_norm_act_layer(norm_layer, act_layer=act_layer)
  78. self.kernel_size = kernel_size
  79. self.dilation = dilation
  80. self.norm = norm_act_layer(in_chs, inplace=True, **dd) if first_act else nn.Identity()
  81. # depthwise convolution
  82. self.conv_dw = create_conv2d(
  83. in_chs,
  84. in_chs,
  85. kernel_size,
  86. stride=stride,
  87. padding=padding,
  88. dilation=dilation,
  89. depthwise=True,
  90. **dd,
  91. )
  92. # pointwise convolution
  93. self.conv_pw = create_conv2d(in_chs, out_chs, kernel_size=1, **dd)
  94. def forward(self, x):
  95. x = self.norm(x)
  96. x = self.conv_dw(x)
  97. x = self.conv_pw(x)
  98. return x
  99. class XceptionModule(nn.Module):
  100. def __init__(
  101. self,
  102. in_chs: int,
  103. out_chs: int,
  104. stride: int = 1,
  105. dilation: int = 1,
  106. pad_type: PadType = '',
  107. start_with_relu: bool = True,
  108. no_skip: bool = False,
  109. act_layer: Type[nn.Module] = nn.ReLU,
  110. norm_layer: Optional[Type[nn.Module]] = None,
  111. drop_path: Optional[nn.Module] = None,
  112. device=None,
  113. dtype=None,
  114. ):
  115. dd = {'device': device, 'dtype': dtype}
  116. super().__init__()
  117. out_chs = to_3tuple(out_chs)
  118. self.in_channels = in_chs
  119. self.out_channels = out_chs[-1]
  120. self.no_skip = no_skip
  121. if not no_skip and (self.out_channels != self.in_channels or stride != 1):
  122. self.shortcut = ConvNormAct(
  123. in_chs,
  124. self.out_channels,
  125. 1,
  126. stride=stride,
  127. norm_layer=norm_layer,
  128. apply_act=False,
  129. **dd,
  130. )
  131. else:
  132. self.shortcut = None
  133. separable_act_layer = None if start_with_relu else act_layer
  134. self.stack = nn.Sequential()
  135. for i in range(3):
  136. if start_with_relu:
  137. self.stack.add_module(f'act{i + 1}', act_layer(inplace=i > 0))
  138. self.stack.add_module(f'conv{i + 1}', SeparableConv2d(
  139. in_chs,
  140. out_chs[i],
  141. 3,
  142. stride=stride if i == 2 else 1,
  143. dilation=dilation,
  144. padding=pad_type,
  145. act_layer=separable_act_layer,
  146. norm_layer=norm_layer,
  147. **dd,
  148. ))
  149. in_chs = out_chs[i]
  150. self.drop_path = drop_path
  151. def forward(self, x):
  152. skip = x
  153. x = self.stack(x)
  154. if self.shortcut is not None:
  155. skip = self.shortcut(skip)
  156. if not self.no_skip:
  157. if self.drop_path is not None:
  158. x = self.drop_path(x)
  159. x = x + skip
  160. return x
  161. class PreXceptionModule(nn.Module):
  162. def __init__(
  163. self,
  164. in_chs: int,
  165. out_chs: int,
  166. stride: int = 1,
  167. dilation: int = 1,
  168. pad_type: PadType = '',
  169. no_skip: bool = False,
  170. act_layer: Type[nn.Module] = nn.ReLU,
  171. norm_layer: Optional[Type[nn.Module]] = None,
  172. drop_path: Optional[nn.Module] = None,
  173. device=None,
  174. dtype=None,
  175. ):
  176. dd = {'device': device, 'dtype': dtype}
  177. super().__init__()
  178. out_chs = to_3tuple(out_chs)
  179. self.in_channels = in_chs
  180. self.out_channels = out_chs[-1]
  181. self.no_skip = no_skip
  182. if not no_skip and (self.out_channels != self.in_channels or stride != 1):
  183. self.shortcut = create_conv2d(in_chs, self.out_channels, 1, stride=stride, **dd)
  184. else:
  185. self.shortcut = nn.Identity()
  186. self.norm = get_norm_act_layer(norm_layer, act_layer=act_layer)(in_chs, inplace=True, **dd)
  187. self.stack = nn.Sequential()
  188. for i in range(3):
  189. self.stack.add_module(f'conv{i + 1}', PreSeparableConv2d(
  190. in_chs,
  191. out_chs[i],
  192. 3,
  193. stride=stride if i == 2 else 1,
  194. dilation=dilation,
  195. padding=pad_type,
  196. act_layer=act_layer,
  197. norm_layer=norm_layer,
  198. first_act=i > 0,
  199. **dd,
  200. ))
  201. in_chs = out_chs[i]
  202. self.drop_path = drop_path
  203. def forward(self, x):
  204. x = self.norm(x)
  205. skip = x
  206. x = self.stack(x)
  207. if not self.no_skip:
  208. if self.drop_path is not None:
  209. x = self.drop_path(x)
  210. x = x + self.shortcut(skip)
  211. return x
  212. class XceptionAligned(nn.Module):
  213. """Modified Aligned Xception
  214. """
  215. def __init__(
  216. self,
  217. block_cfg: List[Dict],
  218. num_classes: int = 1000,
  219. in_chans: int = 3,
  220. output_stride: int = 32,
  221. preact: bool = False,
  222. act_layer: Type[nn.Module] = nn.ReLU,
  223. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  224. drop_rate: float = 0.,
  225. drop_path_rate: float = 0.,
  226. global_pool: str = 'avg',
  227. device=None,
  228. dtype=None,
  229. ):
  230. super().__init__()
  231. dd = {'device': device, 'dtype': dtype}
  232. assert output_stride in (8, 16, 32)
  233. self.num_classes = num_classes
  234. self.drop_rate = drop_rate
  235. self.grad_checkpointing = False
  236. layer_args = dict(act_layer=act_layer, norm_layer=norm_layer, **dd)
  237. self.stem = nn.Sequential(*[
  238. ConvNormAct(in_chans, 32, kernel_size=3, stride=2, **layer_args),
  239. create_conv2d(32, 64, kernel_size=3, stride=1, **dd) if preact else
  240. ConvNormAct(32, 64, kernel_size=3, stride=1, **layer_args)
  241. ])
  242. curr_dilation = 1
  243. curr_stride = 2
  244. self.feature_info = []
  245. self.blocks = nn.Sequential()
  246. module_fn = PreXceptionModule if preact else XceptionModule
  247. net_num_blocks = len(block_cfg)
  248. net_block_idx = 0
  249. for i, b in enumerate(block_cfg):
  250. block_dpr = drop_path_rate * net_block_idx / (net_num_blocks - 1) # stochastic depth linear decay rule
  251. b['drop_path'] = DropPath(block_dpr) if block_dpr > 0. else None
  252. b['dilation'] = curr_dilation
  253. if b['stride'] > 1:
  254. name = f'blocks.{i}.stack.conv2' if preact else f'blocks.{i}.stack.act3'
  255. self.feature_info += [dict(num_chs=to_3tuple(b['out_chs'])[-2], reduction=curr_stride, module=name)]
  256. next_stride = curr_stride * b['stride']
  257. if next_stride > output_stride:
  258. curr_dilation *= b['stride']
  259. b['stride'] = 1
  260. else:
  261. curr_stride = next_stride
  262. self.blocks.add_module(str(i), module_fn(**b, **layer_args))
  263. self.num_features = self.blocks[-1].out_channels
  264. net_block_idx += 1
  265. self.feature_info += [dict(
  266. num_chs=self.num_features, reduction=curr_stride, module='blocks.' + str(len(self.blocks) - 1))]
  267. self.act = act_layer(inplace=True) if preact else nn.Identity()
  268. self.head_hidden_size = self.num_features
  269. self.head = ClassifierHead(
  270. in_features=self.num_features,
  271. num_classes=num_classes,
  272. pool_type=global_pool,
  273. drop_rate=drop_rate,
  274. **dd,
  275. )
  276. @torch.jit.ignore
  277. def group_matcher(self, coarse=False):
  278. return dict(
  279. stem=r'^stem',
  280. blocks=r'^blocks\.(\d+)',
  281. )
  282. @torch.jit.ignore
  283. def set_grad_checkpointing(self, enable=True):
  284. self.grad_checkpointing = enable
  285. @torch.jit.ignore
  286. def get_classifier(self) -> nn.Module:
  287. return self.head.fc
  288. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  289. self.num_classes = num_classes
  290. self.head.reset(num_classes, pool_type=global_pool)
  291. def forward_features(self, x):
  292. x = self.stem(x)
  293. if self.grad_checkpointing and not torch.jit.is_scripting():
  294. x = checkpoint_seq(self.blocks, x)
  295. else:
  296. x = self.blocks(x)
  297. x = self.act(x)
  298. return x
  299. def forward_head(self, x, pre_logits: bool = False):
  300. return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
  301. def forward(self, x):
  302. x = self.forward_features(x)
  303. x = self.forward_head(x)
  304. return x
  305. def _xception(variant, pretrained=False, **kwargs):
  306. return build_model_with_cfg(
  307. XceptionAligned,
  308. variant,
  309. pretrained,
  310. feature_cfg=dict(flatten_sequential=True, feature_cls='hook'),
  311. **kwargs,
  312. )
  313. def _cfg(url='', **kwargs):
  314. return {
  315. 'url': url,
  316. 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (10, 10),
  317. 'crop_pct': 0.903, 'interpolation': 'bicubic',
  318. 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
  319. 'first_conv': 'stem.0.conv', 'classifier': 'head.fc', 'license': 'apache-2.0',
  320. **kwargs
  321. }
  322. default_cfgs = generate_default_cfgs({
  323. 'xception65.ra3_in1k': _cfg(
  324. hf_hub_id='timm/',
  325. crop_pct=0.94,
  326. ),
  327. 'xception41.tf_in1k': _cfg(hf_hub_id='timm/'),
  328. 'xception65.tf_in1k': _cfg(hf_hub_id='timm/'),
  329. 'xception71.tf_in1k': _cfg(hf_hub_id='timm/'),
  330. 'xception41p.ra3_in1k': _cfg(
  331. hf_hub_id='timm/',
  332. crop_pct=0.94,
  333. ),
  334. 'xception65p.ra3_in1k': _cfg(
  335. hf_hub_id='timm/',
  336. crop_pct=0.94,
  337. ),
  338. })
  339. @register_model
  340. def xception41(pretrained=False, **kwargs) -> XceptionAligned:
  341. """ Modified Aligned Xception-41
  342. """
  343. block_cfg = [
  344. # entry flow
  345. dict(in_chs=64, out_chs=128, stride=2),
  346. dict(in_chs=128, out_chs=256, stride=2),
  347. dict(in_chs=256, out_chs=728, stride=2),
  348. # middle flow
  349. *([dict(in_chs=728, out_chs=728, stride=1)] * 8),
  350. # exit flow
  351. dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2),
  352. dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False),
  353. ]
  354. model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1))
  355. return _xception('xception41', pretrained=pretrained, **dict(model_args, **kwargs))
  356. @register_model
  357. def xception65(pretrained=False, **kwargs) -> XceptionAligned:
  358. """ Modified Aligned Xception-65
  359. """
  360. block_cfg = [
  361. # entry flow
  362. dict(in_chs=64, out_chs=128, stride=2),
  363. dict(in_chs=128, out_chs=256, stride=2),
  364. dict(in_chs=256, out_chs=728, stride=2),
  365. # middle flow
  366. *([dict(in_chs=728, out_chs=728, stride=1)] * 16),
  367. # exit flow
  368. dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2),
  369. dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False),
  370. ]
  371. model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1))
  372. return _xception('xception65', pretrained=pretrained, **dict(model_args, **kwargs))
  373. @register_model
  374. def xception71(pretrained=False, **kwargs) -> XceptionAligned:
  375. """ Modified Aligned Xception-71
  376. """
  377. block_cfg = [
  378. # entry flow
  379. dict(in_chs=64, out_chs=128, stride=2),
  380. dict(in_chs=128, out_chs=256, stride=1),
  381. dict(in_chs=256, out_chs=256, stride=2),
  382. dict(in_chs=256, out_chs=728, stride=1),
  383. dict(in_chs=728, out_chs=728, stride=2),
  384. # middle flow
  385. *([dict(in_chs=728, out_chs=728, stride=1)] * 16),
  386. # exit flow
  387. dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2),
  388. dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False),
  389. ]
  390. model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1))
  391. return _xception('xception71', pretrained=pretrained, **dict(model_args, **kwargs))
  392. @register_model
  393. def xception41p(pretrained=False, **kwargs) -> XceptionAligned:
  394. """ Modified Aligned Xception-41 w/ Pre-Act
  395. """
  396. block_cfg = [
  397. # entry flow
  398. dict(in_chs=64, out_chs=128, stride=2),
  399. dict(in_chs=128, out_chs=256, stride=2),
  400. dict(in_chs=256, out_chs=728, stride=2),
  401. # middle flow
  402. *([dict(in_chs=728, out_chs=728, stride=1)] * 8),
  403. # exit flow
  404. dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2),
  405. dict(in_chs=1024, out_chs=(1536, 1536, 2048), no_skip=True, stride=1),
  406. ]
  407. model_args = dict(block_cfg=block_cfg, preact=True, norm_layer=nn.BatchNorm2d)
  408. return _xception('xception41p', pretrained=pretrained, **dict(model_args, **kwargs))
  409. @register_model
  410. def xception65p(pretrained=False, **kwargs) -> XceptionAligned:
  411. """ Modified Aligned Xception-65 w/ Pre-Act
  412. """
  413. block_cfg = [
  414. # entry flow
  415. dict(in_chs=64, out_chs=128, stride=2),
  416. dict(in_chs=128, out_chs=256, stride=2),
  417. dict(in_chs=256, out_chs=728, stride=2),
  418. # middle flow
  419. *([dict(in_chs=728, out_chs=728, stride=1)] * 16),
  420. # exit flow
  421. dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2),
  422. dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True),
  423. ]
  424. model_args = dict(
  425. block_cfg=block_cfg, preact=True, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1))
  426. return _xception('xception65p', pretrained=pretrained, **dict(model_args, **kwargs))