rec_lcnetv3.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558
  1. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import paddle
  18. import paddle.nn as nn
  19. import paddle.nn.functional as F
  20. from paddle import ParamAttr
  21. from paddle.nn.initializer import Constant, KaimingNormal
  22. from paddle.nn import (
  23. AdaptiveAvgPool2D,
  24. BatchNorm2D,
  25. Conv2D,
  26. Dropout,
  27. Hardsigmoid,
  28. Hardswish,
  29. Identity,
  30. Linear,
  31. ReLU,
  32. )
  33. from paddle.regularizer import L2Decay
  34. from ppocr.modeling.backbones.rec_hgnet import MeanPool2D
  35. NET_CONFIG_det = {
  36. "blocks2":
  37. # k, in_c, out_c, s, use_se
  38. [[3, 16, 32, 1, False]],
  39. "blocks3": [[3, 32, 64, 2, False], [3, 64, 64, 1, False]],
  40. "blocks4": [[3, 64, 128, 2, False], [3, 128, 128, 1, False]],
  41. "blocks5": [
  42. [3, 128, 256, 2, False],
  43. [5, 256, 256, 1, False],
  44. [5, 256, 256, 1, False],
  45. [5, 256, 256, 1, False],
  46. [5, 256, 256, 1, False],
  47. ],
  48. "blocks6": [
  49. [5, 256, 512, 2, True],
  50. [5, 512, 512, 1, True],
  51. [5, 512, 512, 1, False],
  52. [5, 512, 512, 1, False],
  53. ],
  54. }
  55. NET_CONFIG_rec = {
  56. "blocks2":
  57. # k, in_c, out_c, s, use_se
  58. [[3, 16, 32, 1, False]],
  59. "blocks3": [[3, 32, 64, 1, False], [3, 64, 64, 1, False]],
  60. "blocks4": [[3, 64, 128, (2, 1), False], [3, 128, 128, 1, False]],
  61. "blocks5": [
  62. [3, 128, 256, (1, 2), False],
  63. [5, 256, 256, 1, False],
  64. [5, 256, 256, 1, False],
  65. [5, 256, 256, 1, False],
  66. [5, 256, 256, 1, False],
  67. ],
  68. "blocks6": [
  69. [5, 256, 512, (2, 1), True],
  70. [5, 512, 512, 1, True],
  71. [5, 512, 512, (2, 1), False],
  72. [5, 512, 512, 1, False],
  73. ],
  74. }
  75. def make_divisible(v, divisor=16, min_value=None):
  76. if min_value is None:
  77. min_value = divisor
  78. new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
  79. if new_v < 0.9 * v:
  80. new_v += divisor
  81. return new_v
  82. class LearnableAffineBlock(nn.Layer):
  83. def __init__(self, scale_value=1.0, bias_value=0.0, lr_mult=1.0, lab_lr=0.1):
  84. super().__init__()
  85. self.scale = self.create_parameter(
  86. shape=[
  87. 1,
  88. ],
  89. default_initializer=Constant(value=scale_value),
  90. attr=ParamAttr(learning_rate=lr_mult * lab_lr),
  91. )
  92. self.add_parameter("scale", self.scale)
  93. self.bias = self.create_parameter(
  94. shape=[
  95. 1,
  96. ],
  97. default_initializer=Constant(value=bias_value),
  98. attr=ParamAttr(learning_rate=lr_mult * lab_lr),
  99. )
  100. self.add_parameter("bias", self.bias)
  101. def forward(self, x):
  102. return self.scale * x + self.bias
  103. class ConvBNLayer(nn.Layer):
  104. def __init__(
  105. self, in_channels, out_channels, kernel_size, stride, groups=1, lr_mult=1.0
  106. ):
  107. super().__init__()
  108. self.conv = Conv2D(
  109. in_channels=in_channels,
  110. out_channels=out_channels,
  111. kernel_size=kernel_size,
  112. stride=stride,
  113. padding=(kernel_size - 1) // 2,
  114. groups=groups,
  115. weight_attr=ParamAttr(initializer=KaimingNormal(), learning_rate=lr_mult),
  116. bias_attr=False,
  117. )
  118. self.bn = BatchNorm2D(
  119. out_channels,
  120. weight_attr=ParamAttr(regularizer=L2Decay(0.0), learning_rate=lr_mult),
  121. bias_attr=ParamAttr(regularizer=L2Decay(0.0), learning_rate=lr_mult),
  122. )
  123. def forward(self, x):
  124. x = self.conv(x)
  125. x = self.bn(x)
  126. return x
  127. class Act(nn.Layer):
  128. def __init__(self, act="hswish", lr_mult=1.0, lab_lr=0.1):
  129. super().__init__()
  130. if act == "hswish":
  131. self.act = Hardswish()
  132. else:
  133. assert act == "relu"
  134. self.act = ReLU()
  135. self.lab = LearnableAffineBlock(lr_mult=lr_mult, lab_lr=lab_lr)
  136. def forward(self, x):
  137. return self.lab(self.act(x))
  138. class LearnableRepLayer(nn.Layer):
  139. def __init__(
  140. self,
  141. in_channels,
  142. out_channels,
  143. kernel_size,
  144. stride=1,
  145. groups=1,
  146. num_conv_branches=1,
  147. lr_mult=1.0,
  148. lab_lr=0.1,
  149. ):
  150. super().__init__()
  151. self.is_repped = False
  152. self.groups = groups
  153. self.stride = stride
  154. self.kernel_size = kernel_size
  155. self.in_channels = in_channels
  156. self.out_channels = out_channels
  157. self.num_conv_branches = num_conv_branches
  158. self.padding = (kernel_size - 1) // 2
  159. self.identity = (
  160. BatchNorm2D(
  161. num_features=in_channels,
  162. weight_attr=ParamAttr(learning_rate=lr_mult),
  163. bias_attr=ParamAttr(learning_rate=lr_mult),
  164. )
  165. if out_channels == in_channels and stride == 1
  166. else None
  167. )
  168. self.conv_kxk = nn.LayerList(
  169. [
  170. ConvBNLayer(
  171. in_channels,
  172. out_channels,
  173. kernel_size,
  174. stride,
  175. groups=groups,
  176. lr_mult=lr_mult,
  177. )
  178. for _ in range(self.num_conv_branches)
  179. ]
  180. )
  181. self.conv_1x1 = (
  182. ConvBNLayer(
  183. in_channels, out_channels, 1, stride, groups=groups, lr_mult=lr_mult
  184. )
  185. if kernel_size > 1
  186. else None
  187. )
  188. self.lab = LearnableAffineBlock(lr_mult=lr_mult, lab_lr=lab_lr)
  189. self.act = Act(lr_mult=lr_mult, lab_lr=lab_lr)
  190. def forward(self, x):
  191. # for export
  192. if self.is_repped:
  193. out = self.lab(self.reparam_conv(x))
  194. if self.stride != 2:
  195. out = self.act(out)
  196. return out
  197. out = 0
  198. if self.identity is not None:
  199. out += self.identity(x)
  200. if self.conv_1x1 is not None:
  201. out += self.conv_1x1(x)
  202. for conv in self.conv_kxk:
  203. out += conv(x)
  204. out = self.lab(out)
  205. if self.stride != 2:
  206. out = self.act(out)
  207. return out
  208. def rep(self):
  209. if self.is_repped:
  210. return
  211. kernel, bias = self._get_kernel_bias()
  212. self.reparam_conv = Conv2D(
  213. in_channels=self.in_channels,
  214. out_channels=self.out_channels,
  215. kernel_size=self.kernel_size,
  216. stride=self.stride,
  217. padding=self.padding,
  218. groups=self.groups,
  219. )
  220. self.reparam_conv.weight.set_value(kernel)
  221. self.reparam_conv.bias.set_value(bias)
  222. self.is_repped = True
  223. def _pad_kernel_1x1_to_kxk(self, kernel1x1, pad):
  224. if not isinstance(kernel1x1, paddle.Tensor):
  225. return 0
  226. else:
  227. return nn.functional.pad(kernel1x1, [pad, pad, pad, pad])
  228. def _get_kernel_bias(self):
  229. kernel_conv_1x1, bias_conv_1x1 = self._fuse_bn_tensor(self.conv_1x1)
  230. kernel_conv_1x1 = self._pad_kernel_1x1_to_kxk(
  231. kernel_conv_1x1, self.kernel_size // 2
  232. )
  233. kernel_identity, bias_identity = self._fuse_bn_tensor(self.identity)
  234. kernel_conv_kxk = 0
  235. bias_conv_kxk = 0
  236. for conv in self.conv_kxk:
  237. kernel, bias = self._fuse_bn_tensor(conv)
  238. kernel_conv_kxk += kernel
  239. bias_conv_kxk += bias
  240. kernel_reparam = kernel_conv_kxk + kernel_conv_1x1 + kernel_identity
  241. bias_reparam = bias_conv_kxk + bias_conv_1x1 + bias_identity
  242. return kernel_reparam, bias_reparam
  243. def _fuse_bn_tensor(self, branch):
  244. if not branch:
  245. return 0, 0
  246. elif isinstance(branch, ConvBNLayer):
  247. kernel = branch.conv.weight
  248. running_mean = branch.bn._mean
  249. running_var = branch.bn._variance
  250. gamma = branch.bn.weight
  251. beta = branch.bn.bias
  252. eps = branch.bn._epsilon
  253. else:
  254. assert isinstance(branch, BatchNorm2D)
  255. if not hasattr(self, "id_tensor"):
  256. input_dim = self.in_channels // self.groups
  257. kernel_value = paddle.zeros(
  258. (self.in_channels, input_dim, self.kernel_size, self.kernel_size),
  259. dtype=branch.weight.dtype,
  260. )
  261. for i in range(self.in_channels):
  262. kernel_value[
  263. i, i % input_dim, self.kernel_size // 2, self.kernel_size // 2
  264. ] = 1
  265. self.id_tensor = kernel_value
  266. kernel = self.id_tensor
  267. running_mean = branch._mean
  268. running_var = branch._variance
  269. gamma = branch.weight
  270. beta = branch.bias
  271. eps = branch._epsilon
  272. std = (running_var + eps).sqrt()
  273. t = (gamma / std).reshape((-1, 1, 1, 1))
  274. return kernel * t, beta - running_mean * gamma / std
  275. class SELayer(nn.Layer):
  276. def __init__(self, channel, reduction=4, lr_mult=1.0):
  277. super().__init__()
  278. if "npu" in paddle.device.get_device():
  279. self.avg_pool = MeanPool2D(1, 1)
  280. else:
  281. self.avg_pool = AdaptiveAvgPool2D(1)
  282. self.conv1 = Conv2D(
  283. in_channels=channel,
  284. out_channels=channel // reduction,
  285. kernel_size=1,
  286. stride=1,
  287. padding=0,
  288. weight_attr=ParamAttr(learning_rate=lr_mult),
  289. bias_attr=ParamAttr(learning_rate=lr_mult),
  290. )
  291. self.relu = ReLU()
  292. self.conv2 = Conv2D(
  293. in_channels=channel // reduction,
  294. out_channels=channel,
  295. kernel_size=1,
  296. stride=1,
  297. padding=0,
  298. weight_attr=ParamAttr(learning_rate=lr_mult),
  299. bias_attr=ParamAttr(learning_rate=lr_mult),
  300. )
  301. self.hardsigmoid = Hardsigmoid()
  302. def forward(self, x):
  303. identity = x
  304. x = self.avg_pool(x)
  305. x = self.conv1(x)
  306. x = self.relu(x)
  307. x = self.conv2(x)
  308. x = self.hardsigmoid(x)
  309. x = paddle.multiply(x=identity, y=x)
  310. return x
  311. class LCNetV3Block(nn.Layer):
  312. def __init__(
  313. self,
  314. in_channels,
  315. out_channels,
  316. stride,
  317. dw_size,
  318. use_se=False,
  319. conv_kxk_num=4,
  320. lr_mult=1.0,
  321. lab_lr=0.1,
  322. ):
  323. super().__init__()
  324. self.use_se = use_se
  325. self.dw_conv = LearnableRepLayer(
  326. in_channels=in_channels,
  327. out_channels=in_channels,
  328. kernel_size=dw_size,
  329. stride=stride,
  330. groups=in_channels,
  331. num_conv_branches=conv_kxk_num,
  332. lr_mult=lr_mult,
  333. lab_lr=lab_lr,
  334. )
  335. if use_se:
  336. self.se = SELayer(in_channels, lr_mult=lr_mult)
  337. self.pw_conv = LearnableRepLayer(
  338. in_channels=in_channels,
  339. out_channels=out_channels,
  340. kernel_size=1,
  341. stride=1,
  342. num_conv_branches=conv_kxk_num,
  343. lr_mult=lr_mult,
  344. lab_lr=lab_lr,
  345. )
  346. def forward(self, x):
  347. x = self.dw_conv(x)
  348. if self.use_se:
  349. x = self.se(x)
  350. x = self.pw_conv(x)
  351. return x
  352. class PPLCNetV3(nn.Layer):
  353. def __init__(
  354. self,
  355. scale=1.0,
  356. conv_kxk_num=4,
  357. lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
  358. lab_lr=0.1,
  359. det=False,
  360. **kwargs,
  361. ):
  362. super().__init__()
  363. self.scale = scale
  364. self.lr_mult_list = lr_mult_list
  365. self.det = det
  366. self.net_config = NET_CONFIG_det if self.det else NET_CONFIG_rec
  367. assert isinstance(
  368. self.lr_mult_list, (list, tuple)
  369. ), "lr_mult_list should be in (list, tuple) but got {}".format(
  370. type(self.lr_mult_list)
  371. )
  372. assert (
  373. len(self.lr_mult_list) == 6
  374. ), "lr_mult_list length should be 6 but got {}".format(len(self.lr_mult_list))
  375. self.conv1 = ConvBNLayer(
  376. in_channels=3,
  377. out_channels=make_divisible(16 * scale),
  378. kernel_size=3,
  379. stride=2,
  380. lr_mult=self.lr_mult_list[0],
  381. )
  382. self.blocks2 = nn.Sequential(
  383. *[
  384. LCNetV3Block(
  385. in_channels=make_divisible(in_c * scale),
  386. out_channels=make_divisible(out_c * scale),
  387. dw_size=k,
  388. stride=s,
  389. use_se=se,
  390. conv_kxk_num=conv_kxk_num,
  391. lr_mult=self.lr_mult_list[1],
  392. lab_lr=lab_lr,
  393. )
  394. for i, (k, in_c, out_c, s, se) in enumerate(self.net_config["blocks2"])
  395. ]
  396. )
  397. self.blocks3 = nn.Sequential(
  398. *[
  399. LCNetV3Block(
  400. in_channels=make_divisible(in_c * scale),
  401. out_channels=make_divisible(out_c * scale),
  402. dw_size=k,
  403. stride=s,
  404. use_se=se,
  405. conv_kxk_num=conv_kxk_num,
  406. lr_mult=self.lr_mult_list[2],
  407. lab_lr=lab_lr,
  408. )
  409. for i, (k, in_c, out_c, s, se) in enumerate(self.net_config["blocks3"])
  410. ]
  411. )
  412. self.blocks4 = nn.Sequential(
  413. *[
  414. LCNetV3Block(
  415. in_channels=make_divisible(in_c * scale),
  416. out_channels=make_divisible(out_c * scale),
  417. dw_size=k,
  418. stride=s,
  419. use_se=se,
  420. conv_kxk_num=conv_kxk_num,
  421. lr_mult=self.lr_mult_list[3],
  422. lab_lr=lab_lr,
  423. )
  424. for i, (k, in_c, out_c, s, se) in enumerate(self.net_config["blocks4"])
  425. ]
  426. )
  427. self.blocks5 = nn.Sequential(
  428. *[
  429. LCNetV3Block(
  430. in_channels=make_divisible(in_c * scale),
  431. out_channels=make_divisible(out_c * scale),
  432. dw_size=k,
  433. stride=s,
  434. use_se=se,
  435. conv_kxk_num=conv_kxk_num,
  436. lr_mult=self.lr_mult_list[4],
  437. lab_lr=lab_lr,
  438. )
  439. for i, (k, in_c, out_c, s, se) in enumerate(self.net_config["blocks5"])
  440. ]
  441. )
  442. self.blocks6 = nn.Sequential(
  443. *[
  444. LCNetV3Block(
  445. in_channels=make_divisible(in_c * scale),
  446. out_channels=make_divisible(out_c * scale),
  447. dw_size=k,
  448. stride=s,
  449. use_se=se,
  450. conv_kxk_num=conv_kxk_num,
  451. lr_mult=self.lr_mult_list[5],
  452. lab_lr=lab_lr,
  453. )
  454. for i, (k, in_c, out_c, s, se) in enumerate(self.net_config["blocks6"])
  455. ]
  456. )
  457. self.out_channels = make_divisible(512 * scale)
  458. if self.det:
  459. mv_c = [16, 24, 56, 480]
  460. self.out_channels = [
  461. make_divisible(self.net_config["blocks3"][-1][2] * scale),
  462. make_divisible(self.net_config["blocks4"][-1][2] * scale),
  463. make_divisible(self.net_config["blocks5"][-1][2] * scale),
  464. make_divisible(self.net_config["blocks6"][-1][2] * scale),
  465. ]
  466. self.layer_list = nn.LayerList(
  467. [
  468. nn.Conv2D(self.out_channels[0], int(mv_c[0] * scale), 1, 1, 0),
  469. nn.Conv2D(self.out_channels[1], int(mv_c[1] * scale), 1, 1, 0),
  470. nn.Conv2D(self.out_channels[2], int(mv_c[2] * scale), 1, 1, 0),
  471. nn.Conv2D(self.out_channels[3], int(mv_c[3] * scale), 1, 1, 0),
  472. ]
  473. )
  474. self.out_channels = [
  475. int(mv_c[0] * scale),
  476. int(mv_c[1] * scale),
  477. int(mv_c[2] * scale),
  478. int(mv_c[3] * scale),
  479. ]
  480. def forward(self, x):
  481. out_list = []
  482. x = self.conv1(x)
  483. x = self.blocks2(x)
  484. x = self.blocks3(x)
  485. out_list.append(x)
  486. x = self.blocks4(x)
  487. out_list.append(x)
  488. x = self.blocks5(x)
  489. out_list.append(x)
  490. x = self.blocks6(x)
  491. out_list.append(x)
  492. if self.det:
  493. out_list[0] = self.layer_list[0](out_list[0])
  494. out_list[1] = self.layer_list[1](out_list[1])
  495. out_list[2] = self.layer_list[2](out_list[2])
  496. out_list[3] = self.layer_list[3](out_list[3])
  497. return out_list
  498. if self.training:
  499. x = F.adaptive_avg_pool2d(x, [1, 40])
  500. else:
  501. x = F.avg_pool2d(x, [3, 2])
  502. return x