dpn.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406
  1. """ PyTorch implementation of DualPathNetworks
  2. Based on original MXNet implementation https://github.com/cypw/DPNs with
  3. many ideas from another PyTorch implementation https://github.com/oyam/pytorch-DPNs.
  4. This implementation is compatible with the pretrained weights from cypw's MXNet implementation.
  5. Hacked together by / Copyright 2020 Ross Wightman
  6. """
  7. from collections import OrderedDict
  8. from functools import partial
  9. from typing import Tuple, Type, Optional
  10. import torch
  11. import torch.nn as nn
  12. import torch.nn.functional as F
  13. from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  14. from timm.layers import BatchNormAct2d, ConvNormAct, create_conv2d, create_classifier, get_norm_act_layer
  15. from ._builder import build_model_with_cfg
  16. from ._registry import register_model, generate_default_cfgs
  17. __all__ = ['DPN']
  18. class CatBnAct(nn.Module):
  19. def __init__(
  20. self,
  21. in_chs: int,
  22. norm_layer: Type[nn.Module] = BatchNormAct2d,
  23. device=None,
  24. dtype=None,
  25. ):
  26. dd = {'device': device, 'dtype': dtype}
  27. super().__init__()
  28. self.bn = norm_layer(in_chs, eps=0.001, **dd)
  29. @torch.jit._overload_method # noqa: F811
  30. def forward(self, x):
  31. # type: (Tuple[torch.Tensor, torch.Tensor]) -> (torch.Tensor)
  32. pass
  33. @torch.jit._overload_method # noqa: F811
  34. def forward(self, x):
  35. # type: (torch.Tensor) -> (torch.Tensor)
  36. pass
  37. def forward(self, x):
  38. if isinstance(x, tuple):
  39. x = torch.cat(x, dim=1)
  40. return self.bn(x)
  41. class BnActConv2d(nn.Module):
  42. def __init__(
  43. self,
  44. in_chs: int,
  45. out_chs: int,
  46. kernel_size: int,
  47. stride: int,
  48. groups: int = 1,
  49. norm_layer: Type[nn.Module] = BatchNormAct2d,
  50. device=None,
  51. dtype=None,
  52. ):
  53. dd = {'device': device, 'dtype': dtype}
  54. super().__init__()
  55. self.bn = norm_layer(in_chs, eps=0.001, **dd)
  56. self.conv = create_conv2d(in_chs, out_chs, kernel_size, stride=stride, groups=groups, **dd)
  57. def forward(self, x):
  58. return self.conv(self.bn(x))
  59. class DualPathBlock(nn.Module):
  60. def __init__(
  61. self,
  62. in_chs: int,
  63. num_1x1_a: int,
  64. num_3x3_b: int,
  65. num_1x1_c: int,
  66. inc: int,
  67. groups: int,
  68. block_type: str = 'normal',
  69. b: bool = False,
  70. device=None,
  71. dtype=None,
  72. ):
  73. dd = {'device': device, 'dtype': dtype}
  74. super().__init__()
  75. self.num_1x1_c = num_1x1_c
  76. self.inc = inc
  77. self.b = b
  78. if block_type == 'proj':
  79. self.key_stride = 1
  80. self.has_proj = True
  81. elif block_type == 'down':
  82. self.key_stride = 2
  83. self.has_proj = True
  84. else:
  85. assert block_type == 'normal'
  86. self.key_stride = 1
  87. self.has_proj = False
  88. self.c1x1_w_s1 = None
  89. self.c1x1_w_s2 = None
  90. if self.has_proj:
  91. # Using different member names here to allow easier parameter key matching for conversion
  92. if self.key_stride == 2:
  93. self.c1x1_w_s2 = BnActConv2d(
  94. in_chs=in_chs, out_chs=num_1x1_c + 2 * inc, kernel_size=1, stride=2, **dd)
  95. else:
  96. self.c1x1_w_s1 = BnActConv2d(
  97. in_chs=in_chs, out_chs=num_1x1_c + 2 * inc, kernel_size=1, stride=1, **dd)
  98. self.c1x1_a = BnActConv2d(in_chs=in_chs, out_chs=num_1x1_a, kernel_size=1, stride=1, **dd)
  99. self.c3x3_b = BnActConv2d(
  100. in_chs=num_1x1_a, out_chs=num_3x3_b, kernel_size=3, stride=self.key_stride, groups=groups, **dd)
  101. if b:
  102. self.c1x1_c = CatBnAct(in_chs=num_3x3_b, **dd)
  103. self.c1x1_c1 = create_conv2d(num_3x3_b, num_1x1_c, kernel_size=1, **dd)
  104. self.c1x1_c2 = create_conv2d(num_3x3_b, inc, kernel_size=1, **dd)
  105. else:
  106. self.c1x1_c = BnActConv2d(in_chs=num_3x3_b, out_chs=num_1x1_c + inc, kernel_size=1, stride=1, **dd)
  107. self.c1x1_c1 = None
  108. self.c1x1_c2 = None
  109. @torch.jit._overload_method # noqa: F811
  110. def forward(self, x):
  111. # type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
  112. pass
  113. @torch.jit._overload_method # noqa: F811
  114. def forward(self, x):
  115. # type: (torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
  116. pass
  117. def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
  118. if isinstance(x, tuple):
  119. x_in = torch.cat(x, dim=1)
  120. else:
  121. x_in = x
  122. if self.c1x1_w_s1 is None and self.c1x1_w_s2 is None:
  123. # self.has_proj == False, torchscript requires condition on module == None
  124. x_s1 = x[0]
  125. x_s2 = x[1]
  126. else:
  127. # self.has_proj == True
  128. if self.c1x1_w_s1 is not None:
  129. # self.key_stride = 1
  130. x_s = self.c1x1_w_s1(x_in)
  131. else:
  132. # self.key_stride = 2
  133. x_s = self.c1x1_w_s2(x_in)
  134. x_s1 = x_s[:, :self.num_1x1_c, :, :]
  135. x_s2 = x_s[:, self.num_1x1_c:, :, :]
  136. x_in = self.c1x1_a(x_in)
  137. x_in = self.c3x3_b(x_in)
  138. x_in = self.c1x1_c(x_in)
  139. if self.c1x1_c1 is not None:
  140. # self.b == True, using None check for torchscript compat
  141. out1 = self.c1x1_c1(x_in)
  142. out2 = self.c1x1_c2(x_in)
  143. else:
  144. out1 = x_in[:, :self.num_1x1_c, :, :]
  145. out2 = x_in[:, self.num_1x1_c:, :, :]
  146. resid = x_s1 + out1
  147. dense = torch.cat([x_s2, out2], dim=1)
  148. return resid, dense
  149. class DPN(nn.Module):
  150. def __init__(
  151. self,
  152. k_sec: Tuple[int, ...] = (3, 4, 20, 3),
  153. inc_sec: Tuple[int, ...] = (16, 32, 24, 128),
  154. k_r: int = 96,
  155. groups: int = 32,
  156. num_classes: int = 1000,
  157. in_chans: int = 3,
  158. output_stride: int = 32,
  159. global_pool: str = 'avg',
  160. small: bool = False,
  161. num_init_features: int = 64,
  162. b: bool = False,
  163. drop_rate: float = 0.,
  164. norm_layer: str = 'batchnorm2d',
  165. act_layer: str = 'relu',
  166. fc_act_layer: str = 'elu',
  167. device=None,
  168. dtype=None,
  169. ):
  170. super().__init__()
  171. dd = {'device': device, 'dtype': dtype}
  172. self.num_classes = num_classes
  173. self.drop_rate = drop_rate
  174. self.b = b
  175. assert output_stride == 32 # FIXME look into dilation support
  176. norm_layer = partial(get_norm_act_layer(norm_layer, act_layer=act_layer), eps=.001)
  177. fc_norm_layer = partial(get_norm_act_layer(norm_layer, act_layer=fc_act_layer), eps=.001, inplace=False)
  178. bw_factor = 1 if small else 4
  179. blocks = OrderedDict()
  180. # conv1
  181. blocks['conv1_1'] = ConvNormAct(
  182. in_chans,
  183. num_init_features,
  184. kernel_size=3 if small else 7,
  185. stride=2,
  186. norm_layer=norm_layer,
  187. **dd,
  188. )
  189. blocks['conv1_pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  190. self.feature_info = [dict(num_chs=num_init_features, reduction=2, module='features.conv1_1')]
  191. # conv2
  192. bw = 64 * bw_factor
  193. inc = inc_sec[0]
  194. r = (k_r * bw) // (64 * bw_factor)
  195. blocks['conv2_1'] = DualPathBlock(num_init_features, r, r, bw, inc, groups, 'proj', b, **dd)
  196. in_chs = bw + 3 * inc
  197. for i in range(2, k_sec[0] + 1):
  198. blocks['conv2_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b, **dd)
  199. in_chs += inc
  200. self.feature_info += [dict(num_chs=in_chs, reduction=4, module=f'features.conv2_{k_sec[0]}')]
  201. # conv3
  202. bw = 128 * bw_factor
  203. inc = inc_sec[1]
  204. r = (k_r * bw) // (64 * bw_factor)
  205. blocks['conv3_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b, **dd)
  206. in_chs = bw + 3 * inc
  207. for i in range(2, k_sec[1] + 1):
  208. blocks['conv3_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b, **dd)
  209. in_chs += inc
  210. self.feature_info += [dict(num_chs=in_chs, reduction=8, module=f'features.conv3_{k_sec[1]}')]
  211. # conv4
  212. bw = 256 * bw_factor
  213. inc = inc_sec[2]
  214. r = (k_r * bw) // (64 * bw_factor)
  215. blocks['conv4_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b, **dd)
  216. in_chs = bw + 3 * inc
  217. for i in range(2, k_sec[2] + 1):
  218. blocks['conv4_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b, **dd)
  219. in_chs += inc
  220. self.feature_info += [dict(num_chs=in_chs, reduction=16, module=f'features.conv4_{k_sec[2]}')]
  221. # conv5
  222. bw = 512 * bw_factor
  223. inc = inc_sec[3]
  224. r = (k_r * bw) // (64 * bw_factor)
  225. blocks['conv5_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b, **dd)
  226. in_chs = bw + 3 * inc
  227. for i in range(2, k_sec[3] + 1):
  228. blocks['conv5_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b, **dd)
  229. in_chs += inc
  230. self.feature_info += [dict(num_chs=in_chs, reduction=32, module=f'features.conv5_{k_sec[3]}')]
  231. blocks['conv5_bn_ac'] = CatBnAct(in_chs, norm_layer=fc_norm_layer, **dd)
  232. self.num_features = self.head_hidden_size = in_chs
  233. self.features = nn.Sequential(blocks)
  234. # Using 1x1 conv for the FC layer to allow the extra pooling scheme
  235. self.global_pool, self.classifier = create_classifier(
  236. self.num_features,
  237. self.num_classes,
  238. pool_type=global_pool,
  239. use_conv=True,
  240. **dd,
  241. )
  242. self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
  243. @torch.jit.ignore
  244. def group_matcher(self, coarse=False):
  245. matcher = dict(
  246. stem=r'^features\.conv1',
  247. blocks=[
  248. (r'^features\.conv(\d+)' if coarse else r'^features\.conv(\d+)_(\d+)', None),
  249. (r'^features\.conv5_bn_ac', (99999,))
  250. ]
  251. )
  252. return matcher
  253. @torch.jit.ignore
  254. def set_grad_checkpointing(self, enable=True):
  255. assert not enable, 'gradient checkpointing not supported'
  256. @torch.jit.ignore
  257. def get_classifier(self) -> nn.Module:
  258. return self.classifier
  259. def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
  260. self.num_classes = num_classes
  261. self.global_pool, self.classifier = create_classifier(
  262. self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
  263. self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
  264. def forward_features(self, x):
  265. return self.features(x)
  266. def forward_head(self, x, pre_logits: bool = False):
  267. x = self.global_pool(x)
  268. if self.drop_rate > 0.:
  269. x = F.dropout(x, p=self.drop_rate, training=self.training)
  270. if pre_logits:
  271. return self.flatten(x)
  272. x = self.classifier(x)
  273. return self.flatten(x)
  274. def forward(self, x):
  275. x = self.forward_features(x)
  276. x = self.forward_head(x)
  277. return x
  278. def _create_dpn(variant, pretrained=False, **kwargs):
  279. return build_model_with_cfg(
  280. DPN,
  281. variant,
  282. pretrained,
  283. feature_cfg=dict(feature_concat=True, flatten_sequential=True),
  284. **kwargs,
  285. )
  286. def _cfg(url='', **kwargs):
  287. return {
  288. 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  289. 'crop_pct': 0.875, 'interpolation': 'bicubic',
  290. 'mean': IMAGENET_DPN_MEAN, 'std': IMAGENET_DPN_STD,
  291. 'first_conv': 'features.conv1_1.conv', 'classifier': 'classifier', 'license': 'apache-2.0',
  292. **kwargs
  293. }
  294. default_cfgs = generate_default_cfgs({
  295. 'dpn48b.untrained': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
  296. 'dpn68.mx_in1k': _cfg(hf_hub_id='timm/'),
  297. 'dpn68b.ra_in1k': _cfg(
  298. hf_hub_id='timm/',
  299. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
  300. crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
  301. 'dpn68b.mx_in1k': _cfg(hf_hub_id='timm/'),
  302. 'dpn92.mx_in1k': _cfg(hf_hub_id='timm/'),
  303. 'dpn98.mx_in1k': _cfg(hf_hub_id='timm/'),
  304. 'dpn131.mx_in1k': _cfg(hf_hub_id='timm/'),
  305. 'dpn107.mx_in1k': _cfg(hf_hub_id='timm/')
  306. })
  307. @register_model
  308. def dpn48b(pretrained=False, **kwargs) -> DPN:
  309. model_args = dict(
  310. small=True, num_init_features=10, k_r=128, groups=32,
  311. b=True, k_sec=(3, 4, 6, 3), inc_sec=(16, 32, 32, 64), act_layer='silu')
  312. return _create_dpn('dpn48b', pretrained=pretrained, **dict(model_args, **kwargs))
  313. @register_model
  314. def dpn68(pretrained=False, **kwargs) -> DPN:
  315. model_args = dict(
  316. small=True, num_init_features=10, k_r=128, groups=32,
  317. k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64))
  318. return _create_dpn('dpn68', pretrained=pretrained, **dict(model_args, **kwargs))
  319. @register_model
  320. def dpn68b(pretrained=False, **kwargs) -> DPN:
  321. model_args = dict(
  322. small=True, num_init_features=10, k_r=128, groups=32,
  323. b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64))
  324. return _create_dpn('dpn68b', pretrained=pretrained, **dict(model_args, **kwargs))
  325. @register_model
  326. def dpn92(pretrained=False, **kwargs) -> DPN:
  327. model_args = dict(
  328. num_init_features=64, k_r=96, groups=32,
  329. k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128))
  330. return _create_dpn('dpn92', pretrained=pretrained, **dict(model_args, **kwargs))
  331. @register_model
  332. def dpn98(pretrained=False, **kwargs) -> DPN:
  333. model_args = dict(
  334. num_init_features=96, k_r=160, groups=40,
  335. k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128))
  336. return _create_dpn('dpn98', pretrained=pretrained, **dict(model_args, **kwargs))
  337. @register_model
  338. def dpn131(pretrained=False, **kwargs) -> DPN:
  339. model_args = dict(
  340. num_init_features=128, k_r=160, groups=40,
  341. k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128))
  342. return _create_dpn('dpn131', pretrained=pretrained, **dict(model_args, **kwargs))
  343. @register_model
  344. def dpn107(pretrained=False, **kwargs) -> DPN:
  345. model_args = dict(
  346. num_init_features=128, k_r=200, groups=50,
  347. k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128))
  348. return _create_dpn('dpn107', pretrained=pretrained, **dict(model_args, **kwargs))