inception_v4.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444
  1. """ Pytorch Inception-V4 implementation
  2. Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is
  3. based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License)
  4. """
  5. from functools import partial
  6. from typing import List, Optional, Tuple, Union, Type
  7. import torch
  8. import torch.nn as nn
  9. from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
  10. from timm.layers import create_classifier, ConvNormAct
  11. from ._builder import build_model_with_cfg
  12. from ._features import feature_take_indices
  13. from ._registry import register_model, generate_default_cfgs
# Names exported by `from <this module> import *`: only the top-level model class.
__all__ = ['InceptionV4']
  15. class Mixed3a(nn.Module):
  16. def __init__(
  17. self,
  18. conv_block: Type[nn.Module] = ConvNormAct,
  19. device=None,
  20. dtype=None,
  21. ):
  22. dd = {'device': device, 'dtype': dtype}
  23. super().__init__()
  24. self.maxpool = nn.MaxPool2d(3, stride=2)
  25. self.conv = conv_block(64, 96, kernel_size=3, stride=2, **dd)
  26. def forward(self, x):
  27. x0 = self.maxpool(x)
  28. x1 = self.conv(x)
  29. out = torch.cat((x0, x1), 1)
  30. return out
  31. class Mixed4a(nn.Module):
  32. def __init__(
  33. self,
  34. conv_block: Type[nn.Module] = ConvNormAct,
  35. device=None,
  36. dtype=None,
  37. ):
  38. dd = {'device': device, 'dtype': dtype}
  39. super().__init__()
  40. self.branch0 = nn.Sequential(
  41. conv_block(160, 64, kernel_size=1, stride=1, **dd),
  42. conv_block(64, 96, kernel_size=3, stride=1, **dd)
  43. )
  44. self.branch1 = nn.Sequential(
  45. conv_block(160, 64, kernel_size=1, stride=1, **dd),
  46. conv_block(64, 64, kernel_size=(1, 7), stride=1, padding=(0, 3), **dd),
  47. conv_block(64, 64, kernel_size=(7, 1), stride=1, padding=(3, 0), **dd),
  48. conv_block(64, 96, kernel_size=(3, 3), stride=1, **dd)
  49. )
  50. def forward(self, x):
  51. x0 = self.branch0(x)
  52. x1 = self.branch1(x)
  53. out = torch.cat((x0, x1), 1)
  54. return out
  55. class Mixed5a(nn.Module):
  56. def __init__(
  57. self,
  58. conv_block: Type[nn.Module] = ConvNormAct,
  59. device=None,
  60. dtype=None,
  61. ):
  62. dd = {'device': device, 'dtype': dtype}
  63. super().__init__()
  64. self.conv = conv_block(192, 192, kernel_size=3, stride=2, **dd)
  65. self.maxpool = nn.MaxPool2d(3, stride=2)
  66. def forward(self, x):
  67. x0 = self.conv(x)
  68. x1 = self.maxpool(x)
  69. out = torch.cat((x0, x1), 1)
  70. return out
  71. class InceptionA(nn.Module):
  72. def __init__(
  73. self,
  74. conv_block: Type[nn.Module] = ConvNormAct,
  75. device=None,
  76. dtype=None,
  77. ):
  78. dd = {'device': device, 'dtype': dtype}
  79. super().__init__()
  80. self.branch0 = conv_block(384, 96, kernel_size=1, stride=1, **dd)
  81. self.branch1 = nn.Sequential(
  82. conv_block(384, 64, kernel_size=1, stride=1, **dd),
  83. conv_block(64, 96, kernel_size=3, stride=1, padding=1, **dd)
  84. )
  85. self.branch2 = nn.Sequential(
  86. conv_block(384, 64, kernel_size=1, stride=1, **dd),
  87. conv_block(64, 96, kernel_size=3, stride=1, padding=1, **dd),
  88. conv_block(96, 96, kernel_size=3, stride=1, padding=1, **dd)
  89. )
  90. self.branch3 = nn.Sequential(
  91. nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
  92. conv_block(384, 96, kernel_size=1, stride=1, **dd)
  93. )
  94. def forward(self, x):
  95. x0 = self.branch0(x)
  96. x1 = self.branch1(x)
  97. x2 = self.branch2(x)
  98. x3 = self.branch3(x)
  99. out = torch.cat((x0, x1, x2, x3), 1)
  100. return out
  101. class ReductionA(nn.Module):
  102. def __init__(
  103. self,
  104. conv_block: Type[nn.Module] = ConvNormAct,
  105. device=None,
  106. dtype=None,
  107. ):
  108. dd = {'device': device, 'dtype': dtype}
  109. super().__init__()
  110. self.branch0 = conv_block(384, 384, kernel_size=3, stride=2, **dd)
  111. self.branch1 = nn.Sequential(
  112. conv_block(384, 192, kernel_size=1, stride=1, **dd),
  113. conv_block(192, 224, kernel_size=3, stride=1, padding=1, **dd),
  114. conv_block(224, 256, kernel_size=3, stride=2, **dd)
  115. )
  116. self.branch2 = nn.MaxPool2d(3, stride=2)
  117. def forward(self, x):
  118. x0 = self.branch0(x)
  119. x1 = self.branch1(x)
  120. x2 = self.branch2(x)
  121. out = torch.cat((x0, x1, x2), 1)
  122. return out
  123. class InceptionB(nn.Module):
  124. def __init__(
  125. self,
  126. conv_block: Type[nn.Module] = ConvNormAct,
  127. device=None,
  128. dtype=None,
  129. ):
  130. dd = {'device': device, 'dtype': dtype}
  131. super().__init__()
  132. self.branch0 = conv_block(1024, 384, kernel_size=1, stride=1, **dd)
  133. self.branch1 = nn.Sequential(
  134. conv_block(1024, 192, kernel_size=1, stride=1, **dd),
  135. conv_block(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3), **dd),
  136. conv_block(224, 256, kernel_size=(7, 1), stride=1, padding=(3, 0), **dd)
  137. )
  138. self.branch2 = nn.Sequential(
  139. conv_block(1024, 192, kernel_size=1, stride=1, **dd),
  140. conv_block(192, 192, kernel_size=(7, 1), stride=1, padding=(3, 0), **dd),
  141. conv_block(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3), **dd),
  142. conv_block(224, 224, kernel_size=(7, 1), stride=1, padding=(3, 0), **dd),
  143. conv_block(224, 256, kernel_size=(1, 7), stride=1, padding=(0, 3), **dd)
  144. )
  145. self.branch3 = nn.Sequential(
  146. nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
  147. conv_block(1024, 128, kernel_size=1, stride=1, **dd)
  148. )
  149. def forward(self, x):
  150. x0 = self.branch0(x)
  151. x1 = self.branch1(x)
  152. x2 = self.branch2(x)
  153. x3 = self.branch3(x)
  154. out = torch.cat((x0, x1, x2, x3), 1)
  155. return out
  156. class ReductionB(nn.Module):
  157. def __init__(
  158. self,
  159. conv_block: Type[nn.Module] = ConvNormAct,
  160. device=None,
  161. dtype=None,
  162. ):
  163. dd = {'device': device, 'dtype': dtype}
  164. super().__init__()
  165. self.branch0 = nn.Sequential(
  166. conv_block(1024, 192, kernel_size=1, stride=1, **dd),
  167. conv_block(192, 192, kernel_size=3, stride=2, **dd)
  168. )
  169. self.branch1 = nn.Sequential(
  170. conv_block(1024, 256, kernel_size=1, stride=1, **dd),
  171. conv_block(256, 256, kernel_size=(1, 7), stride=1, padding=(0, 3), **dd),
  172. conv_block(256, 320, kernel_size=(7, 1), stride=1, padding=(3, 0), **dd),
  173. conv_block(320, 320, kernel_size=3, stride=2, **dd)
  174. )
  175. self.branch2 = nn.MaxPool2d(3, stride=2)
  176. def forward(self, x):
  177. x0 = self.branch0(x)
  178. x1 = self.branch1(x)
  179. x2 = self.branch2(x)
  180. out = torch.cat((x0, x1, x2), 1)
  181. return out
  182. class InceptionC(nn.Module):
  183. def __init__(
  184. self,
  185. conv_block: Type[nn.Module] = ConvNormAct,
  186. device=None,
  187. dtype=None,
  188. ):
  189. dd = {'device': device, 'dtype': dtype}
  190. super().__init__()
  191. self.branch0 = conv_block(1536, 256, kernel_size=1, stride=1, **dd)
  192. self.branch1_0 = conv_block(1536, 384, kernel_size=1, stride=1, **dd)
  193. self.branch1_1a = conv_block(384, 256, kernel_size=(1, 3), stride=1, padding=(0, 1), **dd)
  194. self.branch1_1b = conv_block(384, 256, kernel_size=(3, 1), stride=1, padding=(1, 0), **dd)
  195. self.branch2_0 = conv_block(1536, 384, kernel_size=1, stride=1, **dd)
  196. self.branch2_1 = conv_block(384, 448, kernel_size=(3, 1), stride=1, padding=(1, 0), **dd)
  197. self.branch2_2 = conv_block(448, 512, kernel_size=(1, 3), stride=1, padding=(0, 1), **dd)
  198. self.branch2_3a = conv_block(512, 256, kernel_size=(1, 3), stride=1, padding=(0, 1), **dd)
  199. self.branch2_3b = conv_block(512, 256, kernel_size=(3, 1), stride=1, padding=(1, 0), **dd)
  200. self.branch3 = nn.Sequential(
  201. nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
  202. conv_block(1536, 256, kernel_size=1, stride=1, **dd)
  203. )
  204. def forward(self, x):
  205. x0 = self.branch0(x)
  206. x1_0 = self.branch1_0(x)
  207. x1_1a = self.branch1_1a(x1_0)
  208. x1_1b = self.branch1_1b(x1_0)
  209. x1 = torch.cat((x1_1a, x1_1b), 1)
  210. x2_0 = self.branch2_0(x)
  211. x2_1 = self.branch2_1(x2_0)
  212. x2_2 = self.branch2_2(x2_1)
  213. x2_3a = self.branch2_3a(x2_2)
  214. x2_3b = self.branch2_3b(x2_2)
  215. x2 = torch.cat((x2_3a, x2_3b), 1)
  216. x3 = self.branch3(x)
  217. out = torch.cat((x0, x1, x2, x3), 1)
  218. return out
  219. class InceptionV4(nn.Module):
  220. def __init__(
  221. self,
  222. num_classes: int = 1000,
  223. in_chans: int = 3,
  224. output_stride: int = 32,
  225. drop_rate: float = 0.,
  226. global_pool: str = 'avg',
  227. norm_layer: str = 'batchnorm2d',
  228. norm_eps: float = 1e-3,
  229. act_layer: str = 'relu',
  230. device=None,
  231. dtype=None,
  232. ) -> None:
  233. dd = {'device': device, 'dtype': dtype}
  234. super().__init__()
  235. assert output_stride == 32
  236. self.num_classes = num_classes
  237. self.num_features = self.head_hidden_size = 1536
  238. conv_block = partial(
  239. ConvNormAct,
  240. padding=0,
  241. norm_layer=norm_layer,
  242. act_layer=act_layer,
  243. norm_kwargs=dict(eps=norm_eps),
  244. act_kwargs=dict(inplace=True),
  245. )
  246. features = [
  247. conv_block(in_chans, 32, kernel_size=3, stride=2, **dd),
  248. conv_block(32, 32, kernel_size=3, stride=1, **dd),
  249. conv_block(32, 64, kernel_size=3, stride=1, padding=1, **dd),
  250. Mixed3a(conv_block, **dd),
  251. Mixed4a(conv_block, **dd),
  252. Mixed5a(conv_block, **dd),
  253. ]
  254. features += [InceptionA(conv_block, **dd) for _ in range(4)]
  255. features += [ReductionA(conv_block, **dd)] # Mixed6a
  256. features += [InceptionB(conv_block, **dd) for _ in range(7)]
  257. features += [ReductionB(conv_block, **dd)] # Mixed7a
  258. features += [InceptionC(conv_block, **dd) for _ in range(3)]
  259. self.features = nn.Sequential(*features)
  260. self.feature_info = [
  261. dict(num_chs=64, reduction=2, module='features.2'),
  262. dict(num_chs=160, reduction=4, module='features.3'),
  263. dict(num_chs=384, reduction=8, module='features.9'),
  264. dict(num_chs=1024, reduction=16, module='features.17'),
  265. dict(num_chs=1536, reduction=32, module='features.21'),
  266. ]
  267. self.global_pool, self.head_drop, self.last_linear = create_classifier(
  268. self.num_features,
  269. self.num_classes,
  270. pool_type=global_pool,
  271. drop_rate=drop_rate,
  272. **dd,
  273. )
  274. @torch.jit.ignore
  275. def group_matcher(self, coarse=False):
  276. return dict(
  277. stem=r'^features\.[012]\.',
  278. blocks=r'^features\.(\d+)'
  279. )
  280. @torch.jit.ignore
  281. def set_grad_checkpointing(self, enable=True):
  282. assert not enable, 'gradient checkpointing not supported'
  283. @torch.jit.ignore
  284. def get_classifier(self) -> nn.Module:
  285. return self.last_linear
  286. def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
  287. self.num_classes = num_classes
  288. self.global_pool, self.last_linear = create_classifier(
  289. self.num_features, self.num_classes, pool_type=global_pool)
  290. def forward_intermediates(
  291. self,
  292. x: torch.Tensor,
  293. indices: Optional[Union[int, List[int]]] = None,
  294. norm: bool = False,
  295. stop_early: bool = False,
  296. output_fmt: str = 'NCHW',
  297. intermediates_only: bool = False,
  298. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  299. """ Forward features that returns intermediates.
  300. Args:
  301. x: Input image tensor
  302. indices: Take last n blocks if int, all if None, select matching indices if sequence
  303. norm: Apply norm layer to compatible intermediates
  304. stop_early: Stop iterating over blocks when last desired intermediate hit
  305. output_fmt: Shape of intermediate feature outputs
  306. intermediates_only: Only return intermediate features
  307. Returns:
  308. """
  309. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  310. intermediates = []
  311. stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
  312. take_indices, max_index = feature_take_indices(len(stage_ends), indices)
  313. take_indices = [stage_ends[i] for i in take_indices]
  314. max_index = stage_ends[max_index]
  315. # forward pass
  316. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  317. stages = self.features
  318. else:
  319. stages = self.features[:max_index + 1]
  320. for feat_idx, stage in enumerate(stages):
  321. x = stage(x)
  322. if feat_idx in take_indices:
  323. intermediates.append(x)
  324. if intermediates_only:
  325. return intermediates
  326. return x, intermediates
  327. def prune_intermediate_layers(
  328. self,
  329. indices: Union[int, List[int]] = 1,
  330. prune_norm: bool = False,
  331. prune_head: bool = True,
  332. ):
  333. """ Prune layers not required for specified intermediates.
  334. """
  335. stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
  336. take_indices, max_index = feature_take_indices(len(stage_ends), indices)
  337. max_index = stage_ends[max_index]
  338. self.features = self.features[:max_index + 1] # truncate blocks w/ stem as idx 0
  339. if prune_head:
  340. self.reset_classifier(0, '')
  341. return take_indices
  342. def forward_features(self, x):
  343. return self.features(x)
  344. def forward_head(self, x, pre_logits: bool = False):
  345. x = self.global_pool(x)
  346. x = self.head_drop(x)
  347. return x if pre_logits else self.last_linear(x)
  348. def forward(self, x):
  349. x = self.forward_features(x)
  350. x = self.forward_head(x)
  351. return x
  352. def _create_inception_v4(variant, pretrained=False, **kwargs) -> InceptionV4:
  353. return build_model_with_cfg(
  354. InceptionV4,
  355. variant,
  356. pretrained,
  357. feature_cfg=dict(flatten_sequential=True),
  358. **kwargs,
  359. )
  360. default_cfgs = generate_default_cfgs({
  361. 'inception_v4.tf_in1k': {
  362. 'hf_hub_id': 'timm/',
  363. 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
  364. 'crop_pct': 0.875, 'interpolation': 'bicubic',
  365. 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
  366. 'first_conv': 'features.0.conv', 'classifier': 'last_linear',
  367. 'license': 'apache-2.0',
  368. }
  369. })
  370. @register_model
  371. def inception_v4(pretrained=False, **kwargs):
  372. return _create_inception_v4('inception_v4', pretrained, **kwargs)