ghostnet.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015
"""
An implementation of the GhostNet, GhostNetV2 & GhostNetV3 models as defined in:
GhostNet: More Features from Cheap Operations. https://arxiv.org/abs/1911.11907
GhostNetV2: Enhance Cheap Operation with Long-Range Attention. https://proceedings.neurips.cc/paper_files/paper/2022/file/40b60852a4abdaa696b5a1a78da34635-Paper-Conference.pdf
GhostNetV3: Exploring the Training Strategies for Compact Models. https://arxiv.org/abs/2404.11202

The training scripts & model code are available at:
Original model (V1): https://github.com/huawei-noah/CV-backbones/tree/master/ghostnet_pytorch
Original model (V2): https://github.com/huawei-noah/Efficient-AI-Backbones/blob/master/ghostnetv2_pytorch/model/ghostnetv2_torch.py
Original model (V3): https://github.com/huawei-noah/Efficient-AI-Backbones/blob/master/ghostnetv3_pytorch/ghostnetv3.py
"""
  11. import math
  12. from functools import partial
  13. from typing import Any, Dict, List, Set, Optional, Tuple, Union, Type
  14. import torch
  15. import torch.nn as nn
  16. import torch.nn.functional as F
  17. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  18. from timm.layers import SelectAdaptivePool2d, Linear, make_divisible
  19. from timm.utils.model import reparameterize_model
  20. from ._builder import build_model_with_cfg
  21. from ._efficientnet_blocks import SqueezeExcite, ConvBnAct
  22. from ._features import feature_take_indices
  23. from ._manipulate import checkpoint_seq
  24. from ._registry import register_model, generate_default_cfgs
  25. __all__ = ['GhostNet']
  26. _SE_LAYER = partial(SqueezeExcite, gate_layer='hard_sigmoid', rd_round_fn=partial(make_divisible, divisor=4))
  27. class GhostModule(nn.Module):
  28. def __init__(
  29. self,
  30. in_chs: int,
  31. out_chs: int,
  32. kernel_size: int = 1,
  33. ratio: int = 2,
  34. dw_size: int = 3,
  35. stride: int = 1,
  36. act_layer: Type[nn.Module] = nn.ReLU,
  37. device=None,
  38. dtype=None,
  39. ):
  40. dd = {'device': device, 'dtype': dtype}
  41. super().__init__()
  42. self.out_chs = out_chs
  43. init_chs = math.ceil(out_chs / ratio)
  44. new_chs = init_chs * (ratio - 1)
  45. self.primary_conv = nn.Sequential(
  46. nn.Conv2d(in_chs, init_chs, kernel_size, stride, kernel_size // 2, bias=False, **dd),
  47. nn.BatchNorm2d(init_chs, **dd),
  48. act_layer(inplace=True),
  49. )
  50. self.cheap_operation = nn.Sequential(
  51. nn.Conv2d(init_chs, new_chs, dw_size, 1, dw_size//2, groups=init_chs, bias=False, **dd),
  52. nn.BatchNorm2d(new_chs, **dd),
  53. act_layer(inplace=True),
  54. )
  55. def forward(self, x: torch.Tensor) -> torch.Tensor:
  56. x1 = self.primary_conv(x)
  57. x2 = self.cheap_operation(x1)
  58. out = torch.cat([x1, x2], dim=1)
  59. return out[:, :self.out_chs, :, :]
  60. class GhostModuleV2(nn.Module):
  61. def __init__(
  62. self,
  63. in_chs: int,
  64. out_chs: int,
  65. kernel_size: int = 1,
  66. ratio: int = 2,
  67. dw_size: int = 3,
  68. stride: int = 1,
  69. act_layer: Type[nn.Module] = nn.ReLU,
  70. device=None,
  71. dtype=None,
  72. ):
  73. dd = {'device': device, 'dtype': dtype}
  74. super().__init__()
  75. self.gate_fn = nn.Sigmoid()
  76. self.out_chs = out_chs
  77. init_chs = math.ceil(out_chs / ratio)
  78. new_chs = init_chs * (ratio - 1)
  79. self.primary_conv = nn.Sequential(
  80. nn.Conv2d(in_chs, init_chs, kernel_size, stride, kernel_size // 2, bias=False, **dd),
  81. nn.BatchNorm2d(init_chs, **dd),
  82. act_layer(inplace=True),
  83. )
  84. self.cheap_operation = nn.Sequential(
  85. nn.Conv2d(init_chs, new_chs, dw_size, 1, dw_size // 2, groups=init_chs, bias=False, **dd),
  86. nn.BatchNorm2d(new_chs, **dd),
  87. act_layer(inplace=True),
  88. )
  89. self.short_conv = nn.Sequential(
  90. nn.Conv2d(in_chs, out_chs, kernel_size, stride, kernel_size // 2, bias=False, **dd),
  91. nn.BatchNorm2d(out_chs, **dd),
  92. nn.Conv2d(out_chs, out_chs, kernel_size=(1, 5), stride=1, padding=(0, 2), groups=out_chs, bias=False, **dd),
  93. nn.BatchNorm2d(out_chs, **dd),
  94. nn.Conv2d(out_chs, out_chs, kernel_size=(5, 1), stride=1, padding=(2, 0), groups=out_chs, bias=False, **dd),
  95. nn.BatchNorm2d(out_chs, **dd),
  96. )
  97. def forward(self, x: torch.Tensor) -> torch.Tensor:
  98. res = self.short_conv(F.avg_pool2d(x, kernel_size=2, stride=2))
  99. x1 = self.primary_conv(x)
  100. x2 = self.cheap_operation(x1)
  101. out = torch.cat([x1, x2], dim=1)
  102. return out[:, :self.out_chs, :, :] * F.interpolate(
  103. self.gate_fn(res), size=(out.shape[-2], out.shape[-1]), mode='nearest')
  104. class GhostModuleV3(nn.Module):
  105. def __init__(
  106. self,
  107. in_chs: int,
  108. out_chs: int,
  109. kernel_size: int = 1,
  110. ratio: int = 2,
  111. dw_size: int = 3,
  112. stride: int = 1,
  113. act_layer: Type[nn.Module] = nn.ReLU,
  114. mode: str = 'original',
  115. device=None,
  116. dtype=None,
  117. ):
  118. dd = {'device': device, 'dtype': dtype}
  119. super().__init__()
  120. self.gate_fn = nn.Sigmoid()
  121. self.out_chs = out_chs
  122. init_chs = math.ceil(out_chs / ratio)
  123. new_chs = init_chs * (ratio - 1)
  124. self.mode = mode
  125. self.num_conv_branches = 3
  126. self.infer_mode = False
  127. if not self.infer_mode:
  128. self.primary_conv = nn.Identity()
  129. self.cheap_operation = nn.Identity()
  130. self.primary_rpr_skip = None
  131. self.primary_rpr_scale = None
  132. self.primary_rpr_conv = nn.ModuleList([
  133. ConvBnAct(
  134. in_chs,
  135. init_chs,
  136. kernel_size,
  137. stride,
  138. pad_type=kernel_size // 2,
  139. act_layer=None,
  140. **dd,
  141. ) for _ in range(self.num_conv_branches)
  142. ])
  143. # Re-parameterizable scale branch
  144. self.primary_activation = act_layer(inplace=True)
  145. self.cheap_rpr_skip = nn.BatchNorm2d(init_chs, **dd)
  146. self.cheap_rpr_conv = nn.ModuleList([
  147. ConvBnAct(
  148. init_chs,
  149. new_chs,
  150. dw_size,
  151. 1,
  152. pad_type=dw_size // 2,
  153. group_size=1,
  154. act_layer=None,
  155. **dd,
  156. ) for _ in range(self.num_conv_branches)
  157. ])
  158. # Re-parameterizable scale branch
  159. self.cheap_rpr_scale = ConvBnAct(init_chs, new_chs, 1, 1, pad_type=0, group_size=1, act_layer=None, **dd)
  160. self.cheap_activation = act_layer(inplace=True)
  161. self.short_conv = nn.Sequential(
  162. nn.Conv2d(in_chs, out_chs, kernel_size, stride, kernel_size//2, bias=False, **dd),
  163. nn.BatchNorm2d(out_chs, **dd),
  164. nn.Conv2d(out_chs, out_chs, kernel_size=(1,5), stride=1, padding=(0,2), groups=out_chs, bias=False, **dd),
  165. nn.BatchNorm2d(out_chs, **dd),
  166. nn.Conv2d(out_chs, out_chs, kernel_size=(5,1), stride=1, padding=(2,0), groups=out_chs, bias=False, **dd),
  167. nn.BatchNorm2d(out_chs, **dd),
  168. ) if self.mode in ['shortcut'] else nn.Identity()
  169. self.in_channels = init_chs
  170. self.groups = init_chs
  171. self.kernel_size = dw_size
  172. def forward(self, x):
  173. if self.infer_mode:
  174. x1 = self.primary_conv(x)
  175. x2 = self.cheap_operation(x1)
  176. else:
  177. x1 = 0
  178. for primary_rpr_conv in self.primary_rpr_conv:
  179. x1 += primary_rpr_conv(x)
  180. x1 = self.primary_activation(x1)
  181. x2 = self.cheap_rpr_scale(x1) + self.cheap_rpr_skip(x1)
  182. for cheap_rpr_conv in self.cheap_rpr_conv:
  183. x2 += cheap_rpr_conv(x1)
  184. x2 = self.cheap_activation(x2)
  185. out = torch.cat([x1,x2], dim=1)
  186. if self.mode not in ['shortcut']:
  187. return out
  188. else:
  189. res = self.short_conv(F.avg_pool2d(x, kernel_size=2, stride=2))
  190. return out[:,:self.out_chs,:,:] * F.interpolate(
  191. self.gate_fn(res), size=(out.shape[-2], out.shape[-1]), mode='nearest')
  192. def _get_kernel_bias_primary(self):
  193. kernel_scale = 0
  194. bias_scale = 0
  195. if self.primary_rpr_scale is not None:
  196. kernel_scale, bias_scale = self._fuse_bn_tensor(self.primary_rpr_scale)
  197. pad = self.kernel_size // 2
  198. kernel_scale = F.pad(kernel_scale, [pad, pad, pad, pad])
  199. kernel_identity = 0
  200. bias_identity = 0
  201. if self.primary_rpr_skip is not None:
  202. kernel_identity, bias_identity = self._fuse_bn_tensor(self.primary_rpr_skip)
  203. kernel_conv = 0
  204. bias_conv = 0
  205. for ix in range(self.num_conv_branches):
  206. _kernel, _bias = self._fuse_bn_tensor(self.primary_rpr_conv[ix])
  207. kernel_conv += _kernel
  208. bias_conv += _bias
  209. kernel_final = kernel_conv + kernel_scale + kernel_identity
  210. bias_final = bias_conv + bias_scale + bias_identity
  211. return kernel_final, bias_final
  212. def _get_kernel_bias_cheap(self):
  213. kernel_scale = 0
  214. bias_scale = 0
  215. if self.cheap_rpr_scale is not None:
  216. kernel_scale, bias_scale = self._fuse_bn_tensor(self.cheap_rpr_scale)
  217. pad = self.kernel_size // 2
  218. kernel_scale = F.pad(kernel_scale, [pad, pad, pad, pad])
  219. kernel_identity = 0
  220. bias_identity = 0
  221. if self.cheap_rpr_skip is not None:
  222. kernel_identity, bias_identity = self._fuse_bn_tensor(self.cheap_rpr_skip)
  223. kernel_conv = 0
  224. bias_conv = 0
  225. for ix in range(self.num_conv_branches):
  226. _kernel, _bias = self._fuse_bn_tensor(self.cheap_rpr_conv[ix])
  227. kernel_conv += _kernel
  228. bias_conv += _bias
  229. kernel_final = kernel_conv + kernel_scale + kernel_identity
  230. bias_final = bias_conv + bias_scale + bias_identity
  231. return kernel_final, bias_final
  232. def _fuse_bn_tensor(self, branch):
  233. if isinstance(branch, ConvBnAct):
  234. kernel = branch.conv.weight
  235. running_mean = branch.bn1.running_mean
  236. running_var = branch.bn1.running_var
  237. gamma = branch.bn1.weight
  238. beta = branch.bn1.bias
  239. eps = branch.bn1.eps
  240. else:
  241. assert isinstance(branch, nn.BatchNorm2d)
  242. if not hasattr(self, 'id_tensor'):
  243. input_dim = self.in_channels // self.groups
  244. kernel_value = torch.zeros(
  245. (self.in_channels, input_dim, self.kernel_size, self.kernel_size),
  246. dtype=branch.weight.dtype,
  247. device=branch.weight.device
  248. )
  249. for i in range(self.in_channels):
  250. kernel_value[i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2] = 1
  251. self.id_tensor = kernel_value
  252. kernel = self.id_tensor
  253. running_mean = branch.running_mean
  254. running_var = branch.running_var
  255. gamma = branch.weight
  256. beta = branch.bias
  257. eps = branch.eps
  258. std = (running_var + eps).sqrt()
  259. t = (gamma / std).reshape(-1, 1, 1, 1)
  260. return kernel * t, beta - running_mean * gamma / std
  261. def switch_to_deploy(self):
  262. if self.infer_mode:
  263. return
  264. primary_kernel, primary_bias = self._get_kernel_bias_primary()
  265. self.primary_conv = nn.Conv2d(
  266. in_channels=self.primary_rpr_conv[0].conv.in_channels,
  267. out_channels=self.primary_rpr_conv[0].conv.out_channels,
  268. kernel_size=self.primary_rpr_conv[0].conv.kernel_size,
  269. stride=self.primary_rpr_conv[0].conv.stride,
  270. padding=self.primary_rpr_conv[0].conv.padding,
  271. dilation=self.primary_rpr_conv[0].conv.dilation,
  272. groups=self.primary_rpr_conv[0].conv.groups,
  273. bias=True
  274. )
  275. self.primary_conv.weight.data = primary_kernel
  276. self.primary_conv.bias.data = primary_bias
  277. self.primary_conv = nn.Sequential(
  278. self.primary_conv,
  279. self.primary_activation if self.primary_activation is not None else nn.Sequential()
  280. )
  281. cheap_kernel, cheap_bias = self._get_kernel_bias_cheap()
  282. self.cheap_operation = nn.Conv2d(
  283. in_channels=self.cheap_rpr_conv[0].conv.in_channels,
  284. out_channels=self.cheap_rpr_conv[0].conv.out_channels,
  285. kernel_size=self.cheap_rpr_conv[0].conv.kernel_size,
  286. stride=self.cheap_rpr_conv[0].conv.stride,
  287. padding=self.cheap_rpr_conv[0].conv.padding,
  288. dilation=self.cheap_rpr_conv[0].conv.dilation,
  289. groups=self.cheap_rpr_conv[0].conv.groups,
  290. bias=True
  291. )
  292. self.cheap_operation.weight.data = cheap_kernel
  293. self.cheap_operation.bias.data = cheap_bias
  294. self.cheap_operation = nn.Sequential(
  295. self.cheap_operation,
  296. self.cheap_activation if self.cheap_activation is not None else nn.Sequential()
  297. )
  298. # Delete un-used branches
  299. for para in self.parameters():
  300. para.detach_()
  301. if hasattr(self, 'primary_rpr_conv'):
  302. self.__delattr__('primary_rpr_conv')
  303. if hasattr(self, 'primary_rpr_scale'):
  304. self.__delattr__('primary_rpr_scale')
  305. if hasattr(self, 'primary_rpr_skip'):
  306. self.__delattr__('primary_rpr_skip')
  307. if hasattr(self, 'cheap_rpr_conv'):
  308. self.__delattr__('cheap_rpr_conv')
  309. if hasattr(self, 'cheap_rpr_scale'):
  310. self.__delattr__('cheap_rpr_scale')
  311. if hasattr(self, 'cheap_rpr_skip'):
  312. self.__delattr__('cheap_rpr_skip')
  313. self.infer_mode = True
  314. def reparameterize(self):
  315. self.switch_to_deploy()
  316. class GhostBottleneck(nn.Module):
  317. """ GhostV1/V2 bottleneck w/ optional SE"""
  318. def __init__(
  319. self,
  320. in_chs: int,
  321. mid_chs: int,
  322. out_chs: int,
  323. dw_kernel_size: int = 3,
  324. stride: int = 1,
  325. act_layer: Type[nn.Module] = nn.ReLU,
  326. se_ratio: float = 0.,
  327. mode: str = 'original',
  328. device=None,
  329. dtype=None,
  330. ):
  331. dd = {'device': device, 'dtype': dtype}
  332. super().__init__()
  333. has_se = se_ratio is not None and se_ratio > 0.
  334. self.stride = stride
  335. # Point-wise expansion
  336. if mode == 'original':
  337. self.ghost1 = GhostModule(in_chs, mid_chs, act_layer=act_layer, **dd)
  338. else:
  339. self.ghost1 = GhostModuleV2(in_chs, mid_chs, act_layer=act_layer, **dd)
  340. # Depth-wise convolution
  341. if self.stride > 1:
  342. self.conv_dw = nn.Conv2d(
  343. mid_chs,
  344. mid_chs,
  345. dw_kernel_size,
  346. stride=stride,
  347. padding=(dw_kernel_size-1)//2,
  348. groups=mid_chs,
  349. bias=False,
  350. **dd,
  351. )
  352. self.bn_dw = nn.BatchNorm2d(mid_chs, **dd)
  353. else:
  354. self.conv_dw = None
  355. self.bn_dw = None
  356. # Squeeze-and-excitation
  357. self.se = _SE_LAYER(mid_chs, rd_ratio=se_ratio, **dd) if has_se else None
  358. # Point-wise linear projection
  359. self.ghost2 = GhostModule(mid_chs, out_chs, act_layer=nn.Identity, **dd)
  360. # shortcut
  361. if in_chs == out_chs and self.stride == 1:
  362. self.shortcut = nn.Sequential()
  363. else:
  364. self.shortcut = nn.Sequential(
  365. nn.Conv2d(
  366. in_chs,
  367. in_chs,
  368. dw_kernel_size,
  369. stride=stride,
  370. padding=(dw_kernel_size-1)//2,
  371. groups=in_chs,
  372. bias=False,
  373. **dd,
  374. ),
  375. nn.BatchNorm2d(in_chs, **dd),
  376. nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False, **dd),
  377. nn.BatchNorm2d(out_chs, **dd),
  378. )
  379. def forward(self, x: torch.Tensor) -> torch.Tensor:
  380. shortcut = x
  381. # 1st ghost bottleneck
  382. x = self.ghost1(x)
  383. # Depth-wise convolution
  384. if self.conv_dw is not None:
  385. x = self.conv_dw(x)
  386. x = self.bn_dw(x)
  387. # Squeeze-and-excitation
  388. if self.se is not None:
  389. x = self.se(x)
  390. # 2nd ghost bottleneck
  391. x = self.ghost2(x)
  392. x += self.shortcut(shortcut)
  393. return x
  394. class GhostBottleneckV3(nn.Module):
  395. """ GhostV3 bottleneck w/ optional SE"""
  396. def __init__(
  397. self,
  398. in_chs: int,
  399. mid_chs: int,
  400. out_chs: int,
  401. dw_kernel_size: int = 3,
  402. stride: int = 1,
  403. act_layer: Type[nn.Module] = nn.ReLU,
  404. se_ratio: float = 0.,
  405. mode: str = 'original',
  406. device=None,
  407. dtype=None,
  408. ):
  409. dd = {'device': device, 'dtype': dtype}
  410. super().__init__()
  411. has_se = se_ratio is not None and se_ratio > 0.
  412. self.stride = stride
  413. self.num_conv_branches = 3
  414. self.infer_mode = False
  415. if not self.infer_mode:
  416. self.conv_dw = nn.Identity()
  417. self.bn_dw = nn.Identity()
  418. # Point-wise expansion
  419. self.ghost1 = GhostModuleV3(in_chs, mid_chs, act_layer=act_layer, mode=mode, **dd)
  420. # Depth-wise convolution
  421. if self.stride > 1:
  422. self.dw_rpr_conv = nn.ModuleList([ConvBnAct(
  423. mid_chs,
  424. mid_chs,
  425. dw_kernel_size,
  426. stride,
  427. pad_type=(dw_kernel_size - 1) // 2,
  428. group_size=1,
  429. act_layer=None,
  430. **dd,
  431. ) for _ in range(self.num_conv_branches)
  432. ])
  433. # Re-parameterizable scale branch
  434. self.dw_rpr_scale = ConvBnAct(mid_chs, mid_chs, 1, 2, pad_type=0, group_size=1, act_layer=None, **dd)
  435. self.kernel_size = dw_kernel_size
  436. self.in_channels = mid_chs
  437. else:
  438. self.dw_rpr_conv = nn.ModuleList()
  439. self.dw_rpr_scale = nn.Identity()
  440. self.dw_rpr_skip = None
  441. # Squeeze-and-excitation
  442. self.se = _SE_LAYER(mid_chs, rd_ratio=se_ratio, **dd) if has_se else nn.Identity()
  443. # Point-wise linear projection
  444. self.ghost2 = GhostModuleV3(mid_chs, out_chs, act_layer=nn.Identity, mode='original', **dd)
  445. # shortcut
  446. if in_chs == out_chs and self.stride == 1:
  447. self.shortcut = nn.Identity()
  448. else:
  449. self.shortcut = nn.Sequential(
  450. nn.Conv2d(
  451. in_chs,
  452. in_chs,
  453. dw_kernel_size,
  454. stride=stride,
  455. padding=(dw_kernel_size-1)//2,
  456. groups=in_chs,
  457. bias=False,
  458. **dd,
  459. ),
  460. nn.BatchNorm2d(in_chs, **dd),
  461. nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False, **dd),
  462. nn.BatchNorm2d(out_chs, **dd),
  463. )
  464. def forward(self, x: torch.Tensor) -> torch.Tensor:
  465. shortcut = x
  466. # 1st ghost bottleneck
  467. x = self.ghost1(x)
  468. # Depth-wise convolution
  469. if self.stride > 1:
  470. if self.infer_mode:
  471. x = self.conv_dw(x)
  472. x = self.bn_dw(x)
  473. else:
  474. x1 = self.dw_rpr_scale(x)
  475. for dw_rpr_conv in self.dw_rpr_conv:
  476. x1 += dw_rpr_conv(x)
  477. x = x1
  478. # Squeeze-and-excitation
  479. x = self.se(x)
  480. # 2nd ghost bottleneck
  481. x = self.ghost2(x)
  482. x += self.shortcut(shortcut)
  483. return x
  484. def _get_kernel_bias_dw(self):
  485. kernel_scale = 0
  486. bias_scale = 0
  487. if self.dw_rpr_scale is not None:
  488. kernel_scale, bias_scale = self._fuse_bn_tensor(self.dw_rpr_scale)
  489. pad = self.kernel_size // 2
  490. kernel_scale = F.pad(kernel_scale, [pad, pad, pad, pad])
  491. kernel_identity = 0
  492. bias_identity = 0
  493. if self.dw_rpr_skip is not None:
  494. kernel_identity, bias_identity = self._fuse_bn_tensor(self.dw_rpr_skip)
  495. kernel_conv = 0
  496. bias_conv = 0
  497. for ix in range(self.num_conv_branches):
  498. _kernel, _bias = self._fuse_bn_tensor(self.dw_rpr_conv[ix])
  499. kernel_conv += _kernel
  500. bias_conv += _bias
  501. kernel_final = kernel_conv + kernel_scale + kernel_identity
  502. bias_final = bias_conv + bias_scale + bias_identity
  503. return kernel_final, bias_final
  504. def _fuse_bn_tensor(self, branch):
  505. if isinstance(branch, ConvBnAct):
  506. kernel = branch.conv.weight
  507. running_mean = branch.bn1.running_mean
  508. running_var = branch.bn1.running_var
  509. gamma = branch.bn1.weight
  510. beta = branch.bn1.bias
  511. eps = branch.bn1.eps
  512. else:
  513. assert isinstance(branch, nn.BatchNorm2d)
  514. if not hasattr(self, 'id_tensor'):
  515. input_dim = self.in_channels // self.groups
  516. kernel_value = torch.zeros(
  517. (self.in_channels, input_dim, self.kernel_size, self.kernel_size),
  518. dtype=branch.weight.dtype,
  519. device=branch.weight.device
  520. )
  521. for i in range(self.in_channels):
  522. kernel_value[i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2] = 1
  523. self.id_tensor = kernel_value
  524. kernel = self.id_tensor
  525. running_mean = branch.running_mean
  526. running_var = branch.running_var
  527. gamma = branch.weight
  528. beta = branch.bias
  529. eps = branch.eps
  530. std = (running_var + eps).sqrt()
  531. t = (gamma / std).reshape(-1, 1, 1, 1)
  532. return kernel * t, beta - running_mean * gamma / std
  533. def switch_to_deploy(self):
  534. if self.infer_mode or self.stride == 1:
  535. return
  536. dw_kernel, dw_bias = self._get_kernel_bias_dw()
  537. self.conv_dw = nn.Conv2d(
  538. in_channels=self.dw_rpr_conv[0].conv.in_channels,
  539. out_channels=self.dw_rpr_conv[0].conv.out_channels,
  540. kernel_size=self.dw_rpr_conv[0].conv.kernel_size,
  541. stride=self.dw_rpr_conv[0].conv.stride,
  542. padding=self.dw_rpr_conv[0].conv.padding,
  543. dilation=self.dw_rpr_conv[0].conv.dilation,
  544. groups=self.dw_rpr_conv[0].conv.groups,
  545. bias=True
  546. )
  547. self.conv_dw.weight.data = dw_kernel
  548. self.conv_dw.bias.data = dw_bias
  549. self.bn_dw = nn.Identity()
  550. # Delete un-used branches
  551. for para in self.parameters():
  552. para.detach_()
  553. if hasattr(self, 'dw_rpr_conv'):
  554. self.__delattr__('dw_rpr_conv')
  555. if hasattr(self, 'dw_rpr_scale'):
  556. self.__delattr__('dw_rpr_scale')
  557. if hasattr(self, 'dw_rpr_skip'):
  558. self.__delattr__('dw_rpr_skip')
  559. self.infer_mode = True
  560. def reparameterize(self):
  561. self.switch_to_deploy()
  562. class GhostNet(nn.Module):
  563. def __init__(
  564. self,
  565. cfgs: List[List[List[Union[int, float]]]],
  566. num_classes: int = 1000,
  567. width: float = 1.0,
  568. in_chans: int = 3,
  569. output_stride: int = 32,
  570. global_pool: str = 'avg',
  571. drop_rate: float = 0.2,
  572. version: str = 'v1',
  573. device=None,
  574. dtype=None,
  575. ):
  576. super().__init__()
  577. dd = {'device': device, 'dtype': dtype}
  578. # setting of inverted residual blocks
  579. assert output_stride == 32, 'only output_stride==32 is valid, dilation not supported'
  580. self.cfgs = cfgs
  581. self.num_classes = num_classes
  582. self.drop_rate = drop_rate
  583. self.grad_checkpointing = False
  584. self.feature_info = []
  585. Bottleneck = GhostBottleneckV3 if version == 'v3' else GhostBottleneck
  586. # building first layer
  587. stem_chs = make_divisible(16 * width, 4)
  588. self.conv_stem = nn.Conv2d(in_chans, stem_chs, 3, 2, 1, bias=False, **dd)
  589. self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=f'conv_stem'))
  590. self.bn1 = nn.BatchNorm2d(stem_chs, **dd)
  591. self.act1 = nn.ReLU(inplace=True)
  592. prev_chs = stem_chs
  593. # building inverted residual blocks
  594. stages = nn.ModuleList([])
  595. stage_idx = 0
  596. layer_idx = 0
  597. net_stride = 2
  598. for cfg in self.cfgs:
  599. layers = []
  600. s = 1
  601. for k, exp_size, c, se_ratio, s in cfg:
  602. out_chs = make_divisible(c * width, 4)
  603. mid_chs = make_divisible(exp_size * width, 4)
  604. layer_kwargs = dict(**dd)
  605. if version == 'v2' and layer_idx > 1:
  606. layer_kwargs['mode'] = 'attn'
  607. if version == 'v3' and layer_idx > 1:
  608. layer_kwargs['mode'] = 'shortcut'
  609. layers.append(Bottleneck(prev_chs, mid_chs, out_chs, k, s, se_ratio=se_ratio, **layer_kwargs))
  610. prev_chs = out_chs
  611. layer_idx += 1
  612. if s > 1:
  613. net_stride *= 2
  614. self.feature_info.append(dict(
  615. num_chs=prev_chs, reduction=net_stride, module=f'blocks.{stage_idx}'))
  616. stages.append(nn.Sequential(*layers))
  617. stage_idx += 1
  618. out_chs = make_divisible(exp_size * width, 4)
  619. stages.append(nn.Sequential(ConvBnAct(prev_chs, out_chs, 1, **dd)))
  620. self.pool_dim = prev_chs = out_chs
  621. self.blocks = nn.Sequential(*stages)
  622. # building last several layers
  623. self.num_features = prev_chs
  624. self.head_hidden_size = out_chs = 1280
  625. self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
  626. self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True, **dd)
  627. self.act2 = nn.ReLU(inplace=True)
  628. self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
  629. self.classifier = Linear(out_chs, num_classes, **dd) if num_classes > 0 else nn.Identity()
  630. # FIXME init
  631. @torch.jit.ignore
  632. def no_weight_decay(self) -> Set:
  633. return set()
  634. @torch.jit.ignore
  635. def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
  636. matcher = dict(
  637. stem=r'^conv_stem|bn1',
  638. blocks=[
  639. (r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)', None),
  640. (r'conv_head', (99999,))
  641. ]
  642. )
  643. return matcher
  644. @torch.jit.ignore
  645. def set_grad_checkpointing(self, enable: bool = True):
  646. self.grad_checkpointing = enable
  647. @torch.jit.ignore
  648. def get_classifier(self) -> nn.Module:
  649. return self.classifier
  650. def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
  651. self.num_classes = num_classes
  652. # cannot meaningfully change pooling of efficient head after creation
  653. self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
  654. self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
  655. self.classifier = Linear(
  656. self.head_hidden_size, num_classes,
  657. device=self.conv_head.weight.device, dtype=self.conv_head.weight.dtype
  658. ) if num_classes > 0 else nn.Identity()
  659. def forward_intermediates(
  660. self,
  661. x: torch.Tensor,
  662. indices: Optional[Union[int, List[int]]] = None,
  663. norm: bool = False,
  664. stop_early: bool = False,
  665. output_fmt: str = 'NCHW',
  666. intermediates_only: bool = False,
  667. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  668. """ Forward features that returns intermediates.
  669. Args:
  670. x: Input image tensor
  671. indices: Take last n blocks if int, all if None, select matching indices if sequence
  672. norm: Apply norm layer to compatible intermediates
  673. stop_early: Stop iterating over blocks when last desired intermediate hit
  674. output_fmt: Shape of intermediate feature outputs
  675. intermediates_only: Only return intermediate features
  676. Returns:
  677. """
  678. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  679. intermediates = []
  680. stage_ends = [-1] + [int(info['module'].split('.')[-1]) for info in self.feature_info[1:]]
  681. take_indices, max_index = feature_take_indices(len(stage_ends), indices)
  682. take_indices = [stage_ends[i]+1 for i in take_indices]
  683. max_index = stage_ends[max_index]
  684. # forward pass
  685. feat_idx = 0
  686. x = self.conv_stem(x)
  687. if feat_idx in take_indices:
  688. intermediates.append(x)
  689. x = self.bn1(x)
  690. x = self.act1(x)
  691. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  692. stages = self.blocks
  693. else:
  694. stages = self.blocks[:max_index + 1]
  695. for feat_idx, stage in enumerate(stages, start=1):
  696. if self.grad_checkpointing and not torch.jit.is_scripting():
  697. x = checkpoint_seq(stage, x)
  698. else:
  699. x = stage(x)
  700. if feat_idx in take_indices:
  701. intermediates.append(x)
  702. if intermediates_only:
  703. return intermediates
  704. return x, intermediates
  705. def prune_intermediate_layers(
  706. self,
  707. indices: Union[int, List[int]] = 1,
  708. prune_norm: bool = False,
  709. prune_head: bool = True,
  710. ):
  711. """ Prune layers not required for specified intermediates.
  712. """
  713. stage_ends = [-1] + [int(info['module'].split('.')[-1]) for info in self.feature_info[1:]]
  714. take_indices, max_index = feature_take_indices(len(stage_ends), indices)
  715. max_index = stage_ends[max_index]
  716. self.blocks = self.blocks[:max_index + 1] # truncate blocks w/ stem as idx 0
  717. if prune_head:
  718. self.reset_classifier(0, '')
  719. return take_indices
  720. def forward_features(self, x: torch.Tensor) -> torch.Tensor:
  721. x = self.conv_stem(x)
  722. x = self.bn1(x)
  723. x = self.act1(x)
  724. if self.grad_checkpointing and not torch.jit.is_scripting():
  725. x = checkpoint_seq(self.blocks, x, flatten=True)
  726. else:
  727. x = self.blocks(x)
  728. return x
  729. def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
  730. x = self.global_pool(x)
  731. x = self.conv_head(x)
  732. x = self.act2(x)
  733. x = self.flatten(x)
  734. if self.drop_rate > 0.:
  735. x = F.dropout(x, p=self.drop_rate, training=self.training)
  736. return x if pre_logits else self.classifier(x)
  737. def forward(self, x: torch.Tensor) -> torch.Tensor:
  738. x = self.forward_features(x)
  739. x = self.forward_head(x)
  740. return x
  741. def convert_to_deploy(self):
  742. reparameterize_model(self, inplace=False)
  def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module) -> Dict[str, torch.Tensor]:
      """Remap checkpoint keys before loading.

      * A nested ``'state_dict'`` entry, if present, is unwrapped first.
      * Keys containing ``'total'`` are dropped.
      * In keys that contain ``'bn.'`` together with ``'.ghost'`` or
        ``'.dw_rpr_'``, every ``'bn.'`` is rewritten to ``'bn1.'``.

      Args:
          state_dict: Checkpoint mapping (optionally wrapped under ``'state_dict'``).
          model: Not used by this filter.

      Returns:
          A new dict with the remapped keys, in the original iteration order.
      """
      weights = state_dict.get('state_dict', state_dict)
      remapped = {}
      for name, tensor in weights.items():
          if 'total' in name:
              continue
          if 'bn.' in name and ('.ghost' in name or '.dw_rpr_' in name):
              name = name.replace('bn.', 'bn1.')
          remapped[name] = tensor
      return remapped
  def _create_ghostnet(variant: str, width: float = 1.0, pretrained: bool = False, **kwargs: Any) -> GhostNet:
      """Build a GhostNet through ``build_model_with_cfg`` with the shared block table.

      Args:
          variant: Variant name forwarded to ``build_model_with_cfg``.
          width: Forwarded to the model as ``width``.
          pretrained: Forwarded to ``build_model_with_cfg``.
          **kwargs: Extra model keyword arguments, merged with ``cfgs`` and ``width``.

      Returns:
          The result of ``build_model_with_cfg``.
      """
      # Row layout: k, t, c, SE, s
      # (presumably kernel, expansion size, out channels, SE ratio, stride --
      # confirm against the GhostNet constructor).
      stage1 = [
          [[3, 16, 16, 0, 1]],
      ]
      stage2 = [
          [[3, 48, 24, 0, 2]],
          [[3, 72, 24, 0, 1]],
      ]
      stage3 = [
          [[5, 72, 40, 0.25, 2]],
          [[5, 120, 40, 0.25, 1]],
      ]
      stage4 = [
          [[3, 240, 80, 0, 2]],
          [
              [3, 200, 80, 0, 1],
              [3, 184, 80, 0, 1],
              [3, 184, 80, 0, 1],
              [3, 480, 112, 0.25, 1],
              [3, 672, 112, 0.25, 1],
          ],
      ]
      stage5 = [
          [[5, 672, 160, 0.25, 2]],
          [
              [5, 960, 160, 0, 1],
              [5, 960, 160, 0.25, 1],
              [5, 960, 160, 0, 1],
              [5, 960, 160, 0.25, 1],
          ],
      ]
      cfgs = stage1 + stage2 + stage3 + stage4 + stage5

      # dict(...) is kept (rather than a literal) so a duplicate 'cfgs' in
      # kwargs still raises instead of silently overriding the table.
      ghostnet_kwargs = dict(cfgs=cfgs, width=width, **kwargs)
      return build_model_with_cfg(
          GhostNet,
          variant,
          pretrained,
          pretrained_filter_fn=checkpoint_filter_fn,
          feature_cfg={'flatten_sequential': True},
          **ghostnet_kwargs,
      )
  def _cfg(url='', **kwargs):
      """Return the default pretrained-config dict for this file's models.

      Any keyword argument overrides the matching default or is added as a
      new entry.
      """
      cfg = dict(
          url=url,
          num_classes=1000,
          input_size=(3, 224, 224),
          pool_size=(7, 7),
          crop_pct=0.875,
          interpolation='bicubic',
          mean=IMAGENET_DEFAULT_MEAN,
          std=IMAGENET_DEFAULT_STD,
          first_conv='conv_stem',
          classifier='classifier',
          license='apache-2.0',
      )
      cfg.update(kwargs)
      return cfg
  # Pretrained configs keyed by '<model name>.<tag>'. '.untrained' entries use
  # the bare `_cfg()` defaults; '.in1k' entries set hf_hub_id='timm/'.
  # The commented URLs are retained for reference only and are not used.
  default_cfgs = generate_default_cfgs({
      'ghostnet_050.untrained': _cfg(),
      # url='https://github.com/huawei-noah/CV-backbones/releases/download/ghostnet_pth/ghostnet_1x.pth'
      'ghostnet_100.in1k': _cfg(hf_hub_id='timm/'),
      'ghostnet_130.untrained': _cfg(),

      # url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/GhostNetV2/ck_ghostnetv2_10.pth.tar'
      'ghostnetv2_100.in1k': _cfg(hf_hub_id='timm/'),
      # url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/GhostNetV2/ck_ghostnetv2_13.pth.tar'
      'ghostnetv2_130.in1k': _cfg(hf_hub_id='timm/'),
      # url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/GhostNetV2/ck_ghostnetv2_16.pth.tar'
      'ghostnetv2_160.in1k': _cfg(hf_hub_id='timm/'),

      'ghostnetv3_050.untrained': _cfg(),
      # url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/GhostNetV3/ghostnetv3-1.0.pth.tar'
      'ghostnetv3_100.in1k': _cfg(hf_hub_id='timm/'),
      'ghostnetv3_130.untrained': _cfg(),
      'ghostnetv3_160.untrained': _cfg(),
  })
  @register_model
  def ghostnet_050(pretrained: bool = False, **kwargs: Any) -> GhostNet:
      """GhostNet-0.5x.

      Args:
          pretrained: Forwarded to ``_create_ghostnet``.
          **kwargs: Extra keyword arguments forwarded to ``_create_ghostnet``.
      """
      return _create_ghostnet('ghostnet_050', width=0.5, pretrained=pretrained, **kwargs)
  @register_model
  def ghostnet_100(pretrained: bool = False, **kwargs: Any) -> GhostNet:
      """GhostNet-1.0x.

      Args:
          pretrained: Forwarded to ``_create_ghostnet``.
          **kwargs: Extra keyword arguments forwarded to ``_create_ghostnet``.
      """
      return _create_ghostnet('ghostnet_100', width=1.0, pretrained=pretrained, **kwargs)
  @register_model
  def ghostnet_130(pretrained: bool = False, **kwargs: Any) -> GhostNet:
      """GhostNet-1.3x.

      Args:
          pretrained: Forwarded to ``_create_ghostnet``.
          **kwargs: Extra keyword arguments forwarded to ``_create_ghostnet``.
      """
      return _create_ghostnet('ghostnet_130', width=1.3, pretrained=pretrained, **kwargs)
  @register_model
  def ghostnetv2_100(pretrained: bool = False, **kwargs: Any) -> GhostNet:
      """GhostNetV2-1.0x (passes ``version='v2'``).

      Args:
          pretrained: Forwarded to ``_create_ghostnet``.
          **kwargs: Extra keyword arguments forwarded to ``_create_ghostnet``.
      """
      return _create_ghostnet('ghostnetv2_100', width=1.0, pretrained=pretrained, version='v2', **kwargs)
  @register_model
  def ghostnetv2_130(pretrained: bool = False, **kwargs: Any) -> GhostNet:
      """GhostNetV2-1.3x (passes ``version='v2'``).

      Args:
          pretrained: Forwarded to ``_create_ghostnet``.
          **kwargs: Extra keyword arguments forwarded to ``_create_ghostnet``.
      """
      return _create_ghostnet('ghostnetv2_130', width=1.3, pretrained=pretrained, version='v2', **kwargs)
  @register_model
  def ghostnetv2_160(pretrained: bool = False, **kwargs: Any) -> GhostNet:
      """GhostNetV2-1.6x (passes ``version='v2'``).

      Args:
          pretrained: Forwarded to ``_create_ghostnet``.
          **kwargs: Extra keyword arguments forwarded to ``_create_ghostnet``.
      """
      return _create_ghostnet('ghostnetv2_160', width=1.6, pretrained=pretrained, version='v2', **kwargs)
  @register_model
  def ghostnetv3_050(pretrained: bool = False, **kwargs: Any) -> GhostNet:
      """GhostNetV3-0.5x: ``_create_ghostnet`` with ``width=0.5`` and ``version='v3'``."""
      return _create_ghostnet('ghostnetv3_050', width=0.5, pretrained=pretrained, version='v3', **kwargs)
  @register_model
  def ghostnetv3_100(pretrained: bool = False, **kwargs: Any) -> GhostNet:
      """GhostNetV3-1.0x: ``_create_ghostnet`` with ``width=1.0`` and ``version='v3'``."""
      return _create_ghostnet('ghostnetv3_100', width=1.0, pretrained=pretrained, version='v3', **kwargs)
  @register_model
  def ghostnetv3_130(pretrained: bool = False, **kwargs: Any) -> GhostNet:
      """GhostNetV3-1.3x: ``_create_ghostnet`` with ``width=1.3`` and ``version='v3'``."""
      return _create_ghostnet('ghostnetv3_130', width=1.3, pretrained=pretrained, version='v3', **kwargs)
  @register_model
  def ghostnetv3_160(pretrained: bool = False, **kwargs: Any) -> GhostNet:
      """GhostNetV3-1.6x: ``_create_ghostnet`` with ``width=1.6`` and ``version='v3'``."""
      return _create_ghostnet('ghostnetv3_160', width=1.6, pretrained=pretrained, version='v3', **kwargs)