rec_pphgnetv2.py 59 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/PaddlePaddle/PaddleClas/blob/2f36cab604e439b59d1a854df34ece3b10d888e3/ppcls/arch/backbone/legendary_models/pp_hgnet_v2.py
  17. """
  18. from __future__ import absolute_import, division, print_function
  19. import math
  20. import numpy as np
  21. import paddle
  22. import paddle.nn as nn
  23. import paddle.nn.functional as F
  24. from paddle import ParamAttr
  25. from paddle.nn import Conv2D, BatchNorm, Linear, BatchNorm2D, MaxPool2D, AvgPool2D
  26. from paddle.nn.initializer import Uniform
  27. from paddle.regularizer import L2Decay
  28. from typing import Tuple, List, Dict, Union, Callable, Any
  29. from ppocr.modeling.backbones.rec_donut_swin import DonutSwinModelOutput
  # --- IdentityBasedConv1x1 ----------------------------------------------------
class IdentityBasedConv1x1(nn.Conv2D):
    """1x1 convolution whose effective kernel is ``weight + identity``.

    ``weight`` is zero-initialised, so right after construction each output
    channel ``i`` simply copies input channel ``i % (channels // groups)`` of
    its group. ``get_actual_kernel`` exposes the effective kernel so it can be
    folded during re-parameterisation.

    Args:
        channels (int): number of input and output channels.
        groups (int): number of convolution groups; must divide ``channels``.
    """

    def __init__(self, channels, groups=1):
        super(IdentityBasedConv1x1, self).__init__(
            in_channels=channels,
            out_channels=channels,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=groups,
            bias_attr=False,
        )
        assert channels % groups == 0
        input_dim = channels // groups
        # Identity kernel: output channel i reads input channel (i % input_dim)
        # inside its own group.
        id_value = np.zeros((channels, input_dim, 1, 1))
        out_idx = np.arange(channels)
        id_value[out_idx, out_idx % input_dim, 0, 0] = 1
        # NumPy builds float64; cast to the weight's dtype so `weight + id_tensor`
        # neither fails on a dtype mismatch nor promotes the kernel to float64
        # (which one happens depends on the Paddle version).
        self.id_tensor = paddle.to_tensor(id_value, dtype=self.weight.dtype)
        self.weight.set_value(paddle.zeros_like(self.weight))

    def forward(self, input):
        """Convolve ``input`` with ``weight + identity`` (no bias)."""
        kernel = self.get_actual_kernel()
        result = F.conv2d(
            input,
            kernel,
            None,
            stride=1,
            padding=0,
            dilation=self._dilation,
            groups=self._groups,
        )
        return result

    def get_actual_kernel(self):
        """Return the effective kernel ``weight + id_tensor``."""
        return self.weight + self.id_tensor
  # --- BNAndPad ----------------------------------------------------------------
class BNAndPad(nn.Layer):
    """BatchNorm followed by constant padding of ``pad_pixels`` on each spatial side.

    The per-channel padding value is what the BN layer, using its running
    statistics, outputs for a constant input equal to ``last_conv_bias``
    (0 when ``last_conv_bias`` is None):
    ``bias + weight * (last_conv_bias - running_mean) / sqrt(running_var + eps)``.

    Args:
        pad_pixels (int): pixels added on every spatial side; <= 0 disables padding.
        num_features (int): channel count passed to ``bn``.
        epsilon (float): BN epsilon.
        momentum (float): BN momentum.
        last_conv_bias: optional per-channel bias folded into the pad value.
        bn: BatchNorm layer class to instantiate.
    """

    def __init__(
        self,
        pad_pixels,
        num_features,
        epsilon=1e-5,
        momentum=0.1,
        last_conv_bias=None,
        bn=nn.BatchNorm2D,
    ):
        super().__init__()
        self.bn = bn(num_features, momentum=momentum, epsilon=epsilon)
        self.pad_pixels = pad_pixels
        self.last_conv_bias = last_conv_bias

    def forward(self, input):
        normed = self.bn(input)
        pad = self.pad_pixels
        if pad <= 0:
            return normed

        shift = -self.bn._mean
        if self.last_conv_bias is not None:
            shift += self.last_conv_bias
        fill = self.bn.bias + self.bn.weight * (
            shift / paddle.sqrt(self.bn._variance + self.bn._epsilon)
        )

        # TODO: n,h,w,c format is not supported yet (NCHW layout assumed).
        n, _, h, w = normed.shape
        fill = fill.reshape([1, -1, 1, 1])
        # Rows first (top/bottom along H), then columns over the taller map.
        rows = fill.expand([n, -1, pad, w])
        padded = paddle.concat([rows, normed, rows], axis=2)
        cols = fill.expand([n, -1, h + 2 * pad, pad])
        return paddle.concat([cols, padded, cols], axis=3)

    @property
    def weight(self):
        """BN scale (gamma)."""
        return self.bn.weight

    @property
    def bias(self):
        """BN shift (beta)."""
        return self.bn.bias

    @property
    def _mean(self):
        """BN running mean."""
        return self.bn._mean

    @property
    def _variance(self):
        """BN running variance."""
        return self.bn._variance

    @property
    def _epsilon(self):
        """BN epsilon."""
        return self.bn._epsilon
  # --- conv_bn -----------------------------------------------------------------
def conv_bn(
    in_channels,
    out_channels,
    kernel_size,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
    padding_mode="zeros",
):
    """Build a bias-free ``Conv2D`` followed by ``BatchNorm2D``.

    Returns:
        nn.Sequential: with sub-layers named ``conv`` and ``bn``.
    """
    conv = nn.Conv2D(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        bias_attr=False,  # the following BN supplies the shift
        padding_mode=padding_mode,
    )
    norm = nn.BatchNorm2D(num_features=out_channels)
    # (name, layer) pairs give the sub-layers the names "conv" and "bn".
    return nn.Sequential(("conv", conv), ("bn", norm))
  # --- transI_fusebn -----------------------------------------------------------
def transI_fusebn(kernel, bn):
    """Fold BatchNorm ``bn`` into the preceding bias-free conv ``kernel``.

    Returns:
        tuple: ``(kernel * gamma / std, beta - mean * gamma / std)``, where
        ``std = sqrt(running_var + eps)`` and the scale is applied per output
        channel (axis 0 of ``kernel``).
    """
    std = (bn._variance + bn._epsilon).sqrt()
    per_channel_scale = (bn.weight / std).reshape([-1, 1, 1, 1])
    fused_kernel = kernel * per_channel_scale
    fused_bias = bn.bias - bn._mean * bn.weight / std
    return fused_kernel, fused_bias
  # --- transII_addbranch -------------------------------------------------------
def transII_addbranch(kernels, biases):
    """Merge parallel branches by summing their kernels and their biases.

    Entries may be plain ``0`` for an absent branch, since ``sum`` starts at 0.
    """
    merged_kernel = sum(kernels)
    merged_bias = sum(biases)
    return merged_kernel, merged_bias
  # --- transIII_1x1_kxk --------------------------------------------------------
def transIII_1x1_kxk(k1, b1, k2, b2, groups):
    """Merge a 1x1 conv (k1, b1) followed by a kxk conv (k2, b2) into one kxk conv.

    For grouped convolutions, each group is merged separately and the results
    are concatenated along the output-channel axis.

    Returns:
        tuple: ``(merged_kernel, merged_bias)``.
    """
    k1_t = k1.transpose([1, 0, 2, 3])

    if groups == 1:
        merged_kernel = F.conv2d(k2, k1_t)
        merged_bias = (k2 * b1.reshape([1, -1, 1, 1])).sum((1, 2, 3))
        return merged_kernel, merged_bias + b2

    width1 = k1.shape[0] // groups
    width2 = k2.shape[0] // groups
    kernel_parts = []
    bias_parts = []
    for g in range(groups):
        lo1, hi1 = g * width1, (g + 1) * width1
        k2_part = k2[g * width2 : (g + 1) * width2, :, :, :]
        kernel_parts.append(F.conv2d(k2_part, k1_t[:, lo1:hi1, :, :]))
        bias_parts.append(
            (k2_part * b1[lo1:hi1].reshape([1, -1, 1, 1])).sum((1, 2, 3))
        )
    merged_kernel, merged_bias = transIV_depthconcat(kernel_parts, bias_parts)
    return merged_kernel, merged_bias + b2
  # --- transIV_depthconcat -----------------------------------------------------
def transIV_depthconcat(kernels, biases):
    """Concatenate per-group kernels and biases along the output-channel axis (0).

    Uses ``paddle.concat``; the previous ``paddle.cat`` is the PyTorch spelling
    and does not exist in Paddle, so grouped merges crashed.
    """
    return paddle.concat(kernels, axis=0), paddle.concat(biases, axis=0)
  # --- transV_avg --------------------------------------------------------------
def transV_avg(channels, kernel_size, groups):
    """Return the kxk conv kernel equivalent to kxk average pooling (per group).

    Output channel ``c`` takes the mean of input channel ``c % (channels // groups)``
    in its group: those entries are ``1 / kernel_size**2``; all others are 0.

    Returns:
        paddle.Tensor: shape ``[channels, channels // groups, k, k]`` in Paddle's
        default dtype.
    """
    input_dim = channels // groups
    # Built in NumPy: assigning through paired integer-array indices is not
    # supported by Tensor.__setitem__ in every Paddle version.
    k = np.zeros((channels, input_dim, kernel_size, kernel_size), dtype="float64")
    k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = (
        1.0 / kernel_size**2
    )
    return paddle.to_tensor(k, dtype=paddle.get_default_dtype())
  # --- transVI_multiscale ------------------------------------------------------
def transVI_multiscale(kernel, target_kernel_size):
    """Zero-pad a conv kernel spatially so it becomes ``target x target``.

    Args:
        kernel: 4-D kernel ``[out, in, kH, kW]``.
        target_kernel_size (int): desired spatial size (>= kH, kW).
    """
    H_pixels_to_pad = (target_kernel_size - kernel.shape[2]) // 2
    W_pixels_to_pad = (target_kernel_size - kernel.shape[3]) // 2
    # For a 4-D tensor a 4-element pad list is [left, right, top, bottom]:
    # the first pair pads the last (W) axis, the second pair pads H.
    return F.pad(
        kernel, [W_pixels_to_pad, W_pixels_to_pad, H_pixels_to_pad, H_pixels_to_pad]
    )
  # --- DiverseBranchBlock ------------------------------------------------------
class DiverseBranchBlock(nn.Layer):
    """Diverse Branch Block: a kxk conv trained as a sum of parallel branches.

    Training-time branches (their outputs are summed in ``forward``):

    * ``dbb_origin``  -- kxk conv + BN;
    * ``dbb_1x1``     -- 1x1 conv + BN (only built when ``groups < num_filters``);
    * ``dbb_avg``     -- [1x1 conv + padded BN, only when ``groups < num_filters``]
      followed by kxk average pooling and BN (``avgbn``);
    * ``dbb_1x1_kxk`` -- 1x1 conv (identity-based when the internal width equals
      ``num_channels``) + padded BN + kxk conv + BN.

    ``re_parameterize`` folds all branches into the single biased conv
    ``dbb_reparam``; ``is_repped=True`` builds that form directly.

    Args:
        num_channels (int): input channels.
        num_filters (int): output channels.
        filter_size (int): kernel size; must be odd (asserted below).
        stride (int): convolution stride.
        groups (int): convolution groups.
        act: when not None a ReLU is applied to the output (the value is not inspected).
        is_repped (bool): build the already re-parameterised single-conv form.
        single_init (bool): set only ``dbb_origin``'s BN gamma to 1 and the other
            branches' final BN gammas to 0.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        num_channels,
        num_filters,
        filter_size,
        stride=1,
        groups=1,
        act=None,
        is_repped=False,
        single_init=False,
        **kwargs,
    ):
        super().__init__()
        padding = (filter_size - 1) // 2
        dilation = 1
        in_channels = num_channels
        out_channels = num_filters
        kernel_size = filter_size
        internal_channels_1x1_3x3 = None
        nonlinear = act
        self.is_repped = is_repped
        if nonlinear is None:
            self.nonlinear = nn.Identity()
        else:
            self.nonlinear = nn.ReLU()
        self.kernel_size = kernel_size
        self.out_channels = out_channels
        self.groups = groups
        # Holds only for odd kernel sizes ("same" padding).
        assert padding == kernel_size // 2

        if is_repped:
            self.dbb_reparam = nn.Conv2D(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias_attr=True,
            )
        else:
            self.dbb_origin = conv_bn(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            )

            self.dbb_avg = nn.Sequential()
            if groups < out_channels:
                self.dbb_avg.add_sublayer(
                    "conv",
                    nn.Conv2D(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        groups=groups,
                        bias_attr=False,
                    ),
                )
                # Padding happens after BN here, so the pooling itself uses padding=0.
                self.dbb_avg.add_sublayer(
                    "bn", BNAndPad(pad_pixels=padding, num_features=out_channels)
                )
                self.dbb_avg.add_sublayer(
                    "avg",
                    nn.AvgPool2D(kernel_size=kernel_size, stride=stride, padding=0),
                )
                self.dbb_1x1 = conv_bn(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=1,
                    stride=stride,
                    padding=0,
                    groups=groups,
                )
            else:
                self.dbb_avg.add_sublayer(
                    "avg",
                    nn.AvgPool2D(
                        kernel_size=kernel_size, stride=stride, padding=padding
                    ),
                )
            # Added in both layouts: init_gamma and get_equivalent_kernel_bias
            # always read dbb_avg.avgbn.
            self.dbb_avg.add_sublayer("avgbn", nn.BatchNorm2D(out_channels))

            if internal_channels_1x1_3x3 is None:
                internal_channels_1x1_3x3 = (
                    in_channels if groups < out_channels else 2 * in_channels
                )  # For mobilenet, it is better to have 2X internal channels

            self.dbb_1x1_kxk = nn.Sequential()
            if internal_channels_1x1_3x3 == in_channels:
                self.dbb_1x1_kxk.add_sublayer(
                    "idconv1", IdentityBasedConv1x1(channels=in_channels, groups=groups)
                )
            else:
                self.dbb_1x1_kxk.add_sublayer(
                    "conv1",
                    nn.Conv2D(
                        in_channels=in_channels,
                        out_channels=internal_channels_1x1_3x3,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        groups=groups,
                        bias_attr=False,
                    ),
                )
            self.dbb_1x1_kxk.add_sublayer(
                "bn1",
                BNAndPad(pad_pixels=padding, num_features=internal_channels_1x1_3x3),
            )
            self.dbb_1x1_kxk.add_sublayer(
                "conv2",
                nn.Conv2D(
                    in_channels=internal_channels_1x1_3x3,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=0,
                    groups=groups,
                    bias_attr=False,
                ),
            )
            self.dbb_1x1_kxk.add_sublayer("bn2", nn.BatchNorm2D(out_channels))

        # The experiments reported in the paper used the default initialization
        # of bn.weight (all ones). Changing the initialization may help in some cases.
        if single_init:
            # Initialize the bn.weight of dbb_origin as 1 and the others as 0.
            # This is not the default setting.
            self.single_init()

    def forward(self, inputs):
        """Sum the branch outputs (or run the fused conv) and apply the activation."""
        if self.is_repped:
            return self.nonlinear(self.dbb_reparam(inputs))
        out = self.dbb_origin(inputs)
        if hasattr(self, "dbb_1x1"):
            out += self.dbb_1x1(inputs)
        out += self.dbb_avg(inputs)
        out += self.dbb_1x1_kxk(inputs)
        return self.nonlinear(out)

    @staticmethod
    def _set_constant(param, value):
        """Overwrite every element of ``param`` with ``value`` (keeping its dtype)."""
        # Paddle has no `paddle.nn.init.constant_` (that is the PyTorch API).
        param.set_value(paddle.full_like(param, value))

    def init_gamma(self, gamma_value):
        """Set the final BN gamma of every existing branch to ``gamma_value``."""
        if hasattr(self, "dbb_origin"):
            self._set_constant(self.dbb_origin.bn.weight, gamma_value)
        if hasattr(self, "dbb_1x1"):
            self._set_constant(self.dbb_1x1.bn.weight, gamma_value)
        if hasattr(self, "dbb_avg"):
            self._set_constant(self.dbb_avg.avgbn.weight, gamma_value)
        if hasattr(self, "dbb_1x1_kxk"):
            self._set_constant(self.dbb_1x1_kxk.bn2.weight, gamma_value)

    def single_init(self):
        """Zero every branch's final BN gamma, then set ``dbb_origin``'s back to 1."""
        self.init_gamma(0.0)
        if hasattr(self, "dbb_origin"):
            self._set_constant(self.dbb_origin.bn.weight, 1.0)

    def get_equivalent_kernel_bias(self):
        """Fold all training-time branches into one kxk ``(kernel, bias)`` pair."""
        k_origin, b_origin = transI_fusebn(
            self.dbb_origin.conv.weight, self.dbb_origin.bn
        )

        if hasattr(self, "dbb_1x1"):
            k_1x1, b_1x1 = transI_fusebn(self.dbb_1x1.conv.weight, self.dbb_1x1.bn)
            # Zero-pad the 1x1 kernel to kxk so it can be summed with the others.
            k_1x1 = transVI_multiscale(k_1x1, self.kernel_size)
        else:
            k_1x1, b_1x1 = 0, 0

        if hasattr(self.dbb_1x1_kxk, "idconv1"):
            k_1x1_kxk_first = self.dbb_1x1_kxk.idconv1.get_actual_kernel()
        else:
            k_1x1_kxk_first = self.dbb_1x1_kxk.conv1.weight
        k_1x1_kxk_first, b_1x1_kxk_first = transI_fusebn(
            k_1x1_kxk_first, self.dbb_1x1_kxk.bn1
        )
        k_1x1_kxk_second, b_1x1_kxk_second = transI_fusebn(
            self.dbb_1x1_kxk.conv2.weight, self.dbb_1x1_kxk.bn2
        )
        k_1x1_kxk_merged, b_1x1_kxk_merged = transIII_1x1_kxk(
            k_1x1_kxk_first,
            b_1x1_kxk_first,
            k_1x1_kxk_second,
            b_1x1_kxk_second,
            groups=self.groups,
        )

        # The average pool is expressed as a constant kxk conv, then BN-fused.
        k_avg = transV_avg(self.out_channels, self.kernel_size, self.groups)
        k_1x1_avg_second, b_1x1_avg_second = transI_fusebn(k_avg, self.dbb_avg.avgbn)
        if hasattr(self.dbb_avg, "conv"):
            k_1x1_avg_first, b_1x1_avg_first = transI_fusebn(
                self.dbb_avg.conv.weight, self.dbb_avg.bn
            )
            k_1x1_avg_merged, b_1x1_avg_merged = transIII_1x1_kxk(
                k_1x1_avg_first,
                b_1x1_avg_first,
                k_1x1_avg_second,
                b_1x1_avg_second,
                groups=self.groups,
            )
        else:
            k_1x1_avg_merged, b_1x1_avg_merged = k_1x1_avg_second, b_1x1_avg_second

        return transII_addbranch(
            (k_origin, k_1x1, k_1x1_kxk_merged, k_1x1_avg_merged),
            (b_origin, b_1x1, b_1x1_kxk_merged, b_1x1_avg_merged),
        )

    def re_parameterize(self):
        """Replace the training branches with a single fused conv (no-op if already done)."""
        if self.is_repped:
            return
        kernel, bias = self.get_equivalent_kernel_bias()
        self.dbb_reparam = nn.Conv2D(
            in_channels=self.dbb_origin.conv._in_channels,
            out_channels=self.dbb_origin.conv._out_channels,
            kernel_size=self.dbb_origin.conv._kernel_size,
            stride=self.dbb_origin.conv._stride,
            padding=self.dbb_origin.conv._padding,
            dilation=self.dbb_origin.conv._dilation,
            groups=self.dbb_origin.conv._groups,
            bias_attr=True,
        )
        self.dbb_reparam.weight.set_value(kernel)
        self.dbb_reparam.bias.set_value(bias)
        self.__delattr__("dbb_origin")
        self.__delattr__("dbb_avg")
        if hasattr(self, "dbb_1x1"):
            self.__delattr__("dbb_1x1")
        self.__delattr__("dbb_1x1_kxk")
        self.is_repped = True
  # --- Identity ----------------------------------------------------------------
class Identity(nn.Layer):
    """Pass-through layer; ``set_identity`` swaps it in to cut a network short."""

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        # Nothing to compute: hand the input straight back.
        return inputs
  # --- TheseusLayer ------------------------------------------------------------
class TheseusLayer(nn.Layer):
    """Base layer that can expose intermediate outputs and edit its sub-layers.

    Sub-layers are addressed with string patterns such as
    ``"blocks[11].depthwise_conv.conv"`` (parsed by ``parse_pattern_str``).
    Once return patterns are configured, the forward output becomes a dict
    ``{"logits": output, <pattern>: <that sub-layer's output>, ...}``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.res_dict = {}
        self.res_name = self.full_name()
        self.pruner = None
        self.quanter = None
        self.init_net(*args, **kwargs)

    def _return_dict_hook(self, layer, input, output):
        """Forward post-hook: return ``{"logits": output}`` plus the collected sub-layer results."""
        res_dict = {"logits": output}
        # 'list' is needed to avoid error raised by popping self.res_dict
        for res_key in list(self.res_dict):
            # clear the res_dict because the forward process may change according to input
            res_dict[res_key] = self.res_dict.pop(res_key)
        return res_dict

    @staticmethod
    def _resolve_return_patterns(stages_pattern, return_patterns, return_stages):
        """Combine ``return_patterns`` / ``return_stages`` into the patterns to return.

        ``return_stages`` may be ``True`` (all of ``stages_pattern``), an int, or a
        list of ints indexing ``stages_pattern``; out-of-range indices are dropped
        with a warning. When both options are given, ``return_patterns`` is kept
        and ``return_stages`` is ignored.

        Returns:
            The selected patterns, or ``return_patterns`` unchanged when
            ``return_stages`` selects nothing.
        """
        import warnings

        # ``type(...) is int`` (not isinstance) on purpose: bools are ints in Python,
        # but ``True`` means "all stages" and ``False`` means "none".
        if type(return_stages) is int:
            return_stages = [return_stages]
        stages_given = return_stages is True or (
            isinstance(return_stages, list) and len(return_stages) > 0
        )
        if not stages_given:
            return return_patterns
        if return_patterns:
            warnings.warn(
                "The 'return_stages' would be ignored when 'return_patterns' is set."
            )
            return return_patterns
        if return_stages is True:
            return stages_pattern
        # Valid indices are 0 .. len(stages_pattern) - 1.
        valid_stages = [i for i in return_stages if 0 <= i < len(stages_pattern)]
        if len(valid_stages) != len(return_stages):
            warnings.warn(
                f"The 'return_stages' set error. Illegal value(s) have been ignored. The stages' pattern list is {stages_pattern}."
            )
        return [stages_pattern[i] for i in valid_stages]

    def init_net(
        self,
        stages_pattern=None,
        return_patterns=None,
        return_stages=None,
        freeze_befor=None,
        stop_after=None,
        *args,
        **kwargs,
    ):
        """Configure returned outputs, gradient freezing and truncation.

        Args:
            stages_pattern: list of patterns, one per stage.
            return_patterns: pattern(s) whose outputs are returned.
            return_stages: ``True``, an int or a list of ints indexing ``stages_pattern``.
            freeze_befor: pattern before which gradients are stopped.
            stop_after: pattern after which sub-layers become ``Identity``.
        """
        return_patterns = self._resolve_return_patterns(
            stages_pattern, return_patterns, return_stages
        )
        if return_patterns:
            # Run update_res at forward time, i.e. after __init__ has finished
            # building the layer, so every target sub-layer already exists.
            def update_res_hook(layer, input):
                self.update_res(return_patterns)

            self.register_forward_pre_hook(update_res_hook)

        # freeze subnet
        if freeze_befor is not None:
            self.freeze_befor(freeze_befor)

        # set subnet to Identity
        if stop_after is not None:
            self.stop_after(stop_after)

    def init_res(self, stages_pattern, return_patterns=None, return_stages=None):
        """Immediately register the outputs to return (same option rules as ``init_net``)."""
        return_patterns = self._resolve_return_patterns(
            stages_pattern, return_patterns, return_stages
        )
        if return_patterns:
            self.update_res(return_patterns)

    def replace_sub(self, *args, **kwargs) -> None:
        msg = "The function 'replace_sub()' is deprecated, please use 'upgrade_sublayer()' instead."
        raise DeprecationWarning(msg)

    def upgrade_sublayer(
        self,
        layer_name_pattern: Union[str, List[str]],
        handle_func: Callable[[nn.Layer, str], nn.Layer],
    ) -> List[str]:
        """Use 'handle_func' to modify the sub-layer(s) specified by 'layer_name_pattern'.

        Args:
            layer_name_pattern (Union[str, List[str]]): The name of the layer(s) to be modified by 'handle_func'.
            handle_func (Callable[[nn.Layer, str], nn.Layer]): Called as ``handle_func(layer, pattern)``;
                its return value replaces the matched layer.

        Returns:
            List[str]: The patterns that matched a layer and were replaced.

        Examples:
            from paddle import nn
            import paddleclas

            def rep_func(layer: nn.Layer, pattern: str):
                new_layer = nn.Conv2D(
                    in_channels=layer._in_channels,
                    out_channels=layer._out_channels,
                    kernel_size=5,
                    padding=2
                )
                return new_layer

            net = paddleclas.MobileNetV1()
            res = net.upgrade_sublayer(layer_name_pattern=["blocks[11].depthwise_conv.conv", "blocks[12].depthwise_conv.conv"], handle_func=rep_func)
            print(res)
            # ['blocks[11].depthwise_conv.conv', 'blocks[12].depthwise_conv.conv']
        """
        if not isinstance(layer_name_pattern, list):
            layer_name_pattern = [layer_name_pattern]

        hit_layer_pattern_list = []
        for pattern in layer_name_pattern:
            # parse pattern to find target layer and its parent
            layer_list = parse_pattern_str(pattern=pattern, parent_layer=self)
            if not layer_list:
                continue

            sub_layer_parent = layer_list[-2]["layer"] if len(layer_list) > 1 else self
            sub_layer = layer_list[-1]["layer"]
            sub_layer_name = layer_list[-1]["name"]
            sub_layer_index_list = layer_list[-1]["index_list"]

            new_sub_layer = handle_func(sub_layer, pattern)

            if sub_layer_index_list:
                if len(sub_layer_index_list) > 1:
                    sub_layer_parent = getattr(sub_layer_parent, sub_layer_name)[
                        sub_layer_index_list[0]
                    ]
                    for sub_layer_index in sub_layer_index_list[1:-1]:
                        sub_layer_parent = sub_layer_parent[sub_layer_index]
                    sub_layer_parent[sub_layer_index_list[-1]] = new_sub_layer
                else:
                    getattr(sub_layer_parent, sub_layer_name)[
                        sub_layer_index_list[0]
                    ] = new_sub_layer
            else:
                setattr(sub_layer_parent, sub_layer_name, new_sub_layer)

            hit_layer_pattern_list.append(pattern)
        return hit_layer_pattern_list

    def stop_after(self, stop_layer_name: str) -> bool:
        """Stop forward and backward after 'stop_layer_name'.

        Args:
            stop_layer_name (str): The name of the layer after which forward and backward stop.

        Returns:
            bool: 'True' if successful, 'False' otherwise.
        """
        import warnings

        layer_list = parse_pattern_str(stop_layer_name, self)
        if not layer_list:
            return False

        parent_layer = self
        for layer_dict in layer_list:
            name, index_list = layer_dict["name"], layer_dict["index_list"]
            if not set_identity(parent_layer, name, index_list):
                warnings.warn(
                    f"Failed to set the layers that after stop_layer_name('{stop_layer_name}') to IdentityLayer. The error layer's name is '{name}'."
                )
                return False
            parent_layer = layer_dict["layer"]

        return True

    def freeze_befor(self, layer_name: str) -> bool:
        """Freeze the layer named layer_name and its previous layers.

        The matched layer is wrapped so that its output has ``stop_gradient=True``.

        Args:
            layer_name (str): The name of the layer to be frozen.

        Returns:
            bool: 'True' if successful, 'False' otherwise.
        """
        import warnings

        def stop_grad(layer, pattern):
            class StopGradLayer(nn.Layer):
                def __init__(self):
                    super().__init__()
                    self.layer = layer

                def forward(self, x):
                    x = self.layer(x)
                    x.stop_gradient = True
                    return x

            new_layer = StopGradLayer()
            return new_layer

        res = self.upgrade_sublayer(layer_name, stop_grad)
        if len(res) == 0:
            # The 'f' prefix was missing before, so the name was never interpolated.
            warnings.warn(
                f"Failed to stop the gradient before the layer named '{layer_name}'"
            )
            return False
        return True

    def update_res(self, return_patterns: Union[str, List[str]]) -> List[str]:
        """Update the result(s) to be returned.

        Args:
            return_patterns (Union[str, List[str]]): The name(s) of the layer(s) whose output is returned.

        Returns:
            List[str]: The patterns that have been set successfully.
        """
        # clear res_dict that could have been set
        self.res_dict = {}

        class Handler(object):
            def __init__(self, res_dict):
                # res_dict is a reference
                self.res_dict = res_dict

            def __call__(self, layer, pattern):
                layer.res_dict = self.res_dict
                layer.res_name = pattern
                if hasattr(layer, "hook_remove_helper"):
                    layer.hook_remove_helper.remove()
                layer.hook_remove_helper = layer.register_forward_post_hook(
                    save_sub_res_hook
                )
                return layer

        handle_func = Handler(self.res_dict)

        hit_layer_pattern_list = self.upgrade_sublayer(
            return_patterns, handle_func=handle_func
        )

        if hasattr(self, "hook_remove_helper"):
            self.hook_remove_helper.remove()

        self.hook_remove_helper = self.register_forward_post_hook(
            self._return_dict_hook
        )

        return hit_layer_pattern_list
  # --- save_sub_res_hook -------------------------------------------------------
def save_sub_res_hook(layer, input, output):
    """Forward post-hook: store ``output`` in ``layer.res_dict`` under ``layer.res_name``."""
    results = layer.res_dict
    results[layer.res_name] = output
  # --- set_identity ------------------------------------------------------------
def set_identity(
    parent_layer: nn.Layer,
    layer_name: str,
    layer_index_list: Union[List[str], None] = None,
) -> bool:
    """Replace every sub-layer that comes after the target with ``Identity``.

    First, the siblings registered after ``layer_name`` in ``parent_layer`` are
    replaced. If ``layer_index_list`` is given (e.g. ``["2", "3"]`` for
    ``blocks[2][3]``), the function then descends level by level. At each
    level, the entries after the indexed one are replaced inside that same
    nested container.

    Args:
        parent_layer (nn.Layer): The parent of the target layer.
        layer_name (str): The name of the target layer in ``parent_layer``.
        layer_index_list (List[str], optional): String indices into nested containers. Defaults to None.

    Returns:
        bool: True if the target (and every index level) was found, False otherwise.
    """
    stop_after = False
    for sub_layer_name in parent_layer._sub_layers:
        if stop_after:
            parent_layer._sub_layers[sub_layer_name] = Identity()
            continue
        if sub_layer_name == layer_name:
            stop_after = True

    if layer_index_list and stop_after:
        layer_container = parent_layer._sub_layers[layer_name]
        for depth, layer_index in enumerate(layer_index_list):
            if depth > 0:
                # Descend one level, into the entry selected at the previous depth.
                layer_container = layer_container._sub_layers[
                    layer_index_list[depth - 1]
                ]
            stop_after = False
            for sub_layer_index in layer_container._sub_layers:
                if stop_after:
                    # Replace inside the current (possibly nested) container,
                    # not the top-level one.
                    layer_container._sub_layers[sub_layer_index] = Identity()
                    continue
                if layer_index == sub_layer_index:
                    stop_after = True
            if not stop_after:
                # Index not found at this level: cannot descend further.
                break
    return stop_after
  # --- parse_pattern_str -------------------------------------------------------
def parse_pattern_str(
    pattern: str, parent_layer: nn.Layer
) -> Union[None, List[Dict[str, Union[nn.Layer, str, None]]]]:
    """Parse a dotted layer pattern such as ``"blocks[2][0].conv"``.

    Args:
        pattern (str): The pattern describing the layer.
        parent_layer (nn.Layer): The root layer the pattern is relative to.

    Returns:
        Union[None, List[Dict[str, Union[nn.Layer, str, None]]]]: None (with a
        warning) if any part cannot be resolved. Otherwise, one dict per dotted
        part, in order:
            [
                {"layer": first layer, "name": first layer's name parsed, "index_list": first layer's indices (strings) or None},
                {"layer": second layer, "name": second layer's name parsed, "index_list": ...},
                ...
            ]
    """
    import warnings

    layer_list = []
    for part in pattern.split("."):
        if "[" in part:
            target_layer_name = part.split("[")[0]
            target_layer_index_list = [
                index.split("]")[0] for index in part.split("[")[1:]
            ]
        else:
            target_layer_name = part
            target_layer_index_list = None

        target_layer = getattr(parent_layer, target_layer_name, None)
        if target_layer is None:
            warnings.warn(
                f"Not found layer named('{target_layer_name}') specified in pattern('{pattern}')."
            )
            return None

        if target_layer_index_list:
            for target_layer_index in target_layer_index_list:
                try:
                    position = int(target_layer_index)
                except ValueError:
                    warnings.warn(
                        f"Illegal index('{target_layer_index}') specified in pattern('{pattern}'). The index should be an integer."
                    )
                    return None
                if position < 0 or position >= len(target_layer):
                    warnings.warn(
                        f"Not found layer by index('{target_layer_index}') specified in pattern('{pattern}'). The index should < {len(target_layer)} and >= 0."
                    )
                    return None
                # Indexed by the string key, matching how sub-layers are registered.
                target_layer = target_layer[target_layer_index]

        layer_list.append(
            {
                "layer": target_layer,
                "name": target_layer_name,
                "index_list": target_layer_index_list,
            }
        )
        parent_layer = target_layer
    return layer_list
  699. class AdaptiveAvgPool2D(nn.AdaptiveAvgPool2D):
  700. def __init__(self, *args, **kwargs):
  701. super().__init__(*args, **kwargs)
  702. if paddle.device.get_device().startswith("npu"):
  703. self.device = "npu"
  704. else:
  705. self.device = None
  706. if isinstance(self._output_size, int) and self._output_size == 1:
  707. self._gap = True
  708. elif (
  709. isinstance(self._output_size, tuple)
  710. and self._output_size[0] == 1
  711. and self._output_size[1] == 1
  712. ):
  713. self._gap = True
  714. else:
  715. self._gap = False
  716. def forward(self, x):
  717. if self.device == "npu" and self._gap:
  718. # Global Average Pooling
  719. N, C, _, _ = x.shape
  720. x_mean = paddle.mean(x, axis=[2, 3])
  721. x_mean = paddle.reshape(x_mean, [N, C, 1, 1])
  722. return x_mean
  723. else:
  724. return F.adaptive_avg_pool2d(
  725. x,
  726. output_size=self._output_size,
  727. data_format=self._data_format,
  728. name=self._name,
  729. )
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
  731. #
  732. # Licensed under the Apache License, Version 2.0 (the "License");
  733. # you may not use this file except in compliance with the License.
  734. # You may obtain a copy of the License at
  735. #
  736. # http://www.apache.org/licenses/LICENSE-2.0
  737. #
  738. # Unless required by applicable law or agreed to in writing, software
  739. # distributed under the License is distributed on an "AS IS" BASIS,
  740. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  741. # See the License for the specific language governing permissions and
  742. # limitations under the License.
  743. import paddle
  744. import paddle.nn as nn
  745. import paddle.nn.functional as F
  746. from paddle.nn.initializer import KaimingNormal, Constant
  747. from paddle.nn import Conv2D, BatchNorm2D, ReLU, AdaptiveAvgPool2D, MaxPool2D
  748. from paddle.regularizer import L2Decay
  749. from paddle import ParamAttr
  750. MODEL_URLS = {
  751. "PPHGNetV2_B0": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNetV2_B0_ssld_pretrained.pdparams",
  752. "PPHGNetV2_B1": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNetV2_B1_ssld_pretrained.pdparams",
  753. "PPHGNetV2_B2": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNetV2_B2_ssld_pretrained.pdparams",
  754. "PPHGNetV2_B3": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNetV2_B3_ssld_pretrained.pdparams",
  755. "PPHGNetV2_B4": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNetV2_B4_ssld_pretrained.pdparams",
  756. "PPHGNetV2_B5": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNetV2_B5_ssld_pretrained.pdparams",
  757. "PPHGNetV2_B6": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNetV2_B6_ssld_pretrained.pdparams",
  758. }
  759. __all__ = list(MODEL_URLS.keys())
  760. kaiming_normal_ = KaimingNormal()
  761. zeros_ = Constant(value=0.0)
  762. ones_ = Constant(value=1.0)
  763. class LearnableAffineBlock(TheseusLayer):
  764. """
  765. Create a learnable affine block module. This module can significantly improve accuracy on smaller models.
  766. Args:
  767. scale_value (float): The initial value of the scale parameter, default is 1.0.
  768. bias_value (float): The initial value of the bias parameter, default is 0.0.
  769. lr_mult (float): The learning rate multiplier, default is 1.0.
  770. lab_lr (float): The learning rate, default is 0.01.
  771. """
  772. def __init__(self, scale_value=1.0, bias_value=0.0, lr_mult=1.0, lab_lr=0.01):
  773. super().__init__()
  774. self.scale = self.create_parameter(
  775. shape=[
  776. 1,
  777. ],
  778. default_initializer=Constant(value=scale_value),
  779. attr=ParamAttr(learning_rate=lr_mult * lab_lr),
  780. )
  781. self.add_parameter("scale", self.scale)
  782. self.bias = self.create_parameter(
  783. shape=[
  784. 1,
  785. ],
  786. default_initializer=Constant(value=bias_value),
  787. attr=ParamAttr(learning_rate=lr_mult * lab_lr),
  788. )
  789. self.add_parameter("bias", self.bias)
  790. def forward(self, x):
  791. return self.scale * x + self.bias
  792. class ConvBNAct(TheseusLayer):
  793. """
  794. ConvBNAct is a combination of convolution and batchnorm layers.
  795. Args:
  796. in_channels (int): Number of input channels.
  797. out_channels (int): Number of output channels.
  798. kernel_size (int): Size of the convolution kernel. Defaults to 3.
  799. stride (int): Stride of the convolution. Defaults to 1.
  800. padding (int/str): Padding or padding type for the convolution. Defaults to 1.
  801. groups (int): Number of groups for the convolution. Defaults to 1.
  802. use_act: (bool): Whether to use activation function. Defaults to True.
  803. use_lab (bool): Whether to use the LAB operation. Defaults to False.
  804. lr_mult (float): Learning rate multiplier for the layer. Defaults to 1.0.
  805. """
  806. def __init__(
  807. self,
  808. in_channels,
  809. out_channels,
  810. kernel_size=3,
  811. stride=1,
  812. padding=1,
  813. groups=1,
  814. use_act=True,
  815. use_lab=False,
  816. lr_mult=1.0,
  817. ):
  818. super().__init__()
  819. self.use_act = use_act
  820. self.use_lab = use_lab
  821. self.conv = Conv2D(
  822. in_channels,
  823. out_channels,
  824. kernel_size,
  825. stride,
  826. padding=padding if isinstance(padding, str) else (kernel_size - 1) // 2,
  827. groups=groups,
  828. weight_attr=ParamAttr(learning_rate=lr_mult),
  829. bias_attr=False,
  830. )
  831. self.bn = BatchNorm2D(
  832. out_channels,
  833. weight_attr=ParamAttr(regularizer=L2Decay(0.0), learning_rate=lr_mult),
  834. bias_attr=ParamAttr(regularizer=L2Decay(0.0), learning_rate=lr_mult),
  835. )
  836. if self.use_act:
  837. self.act = ReLU()
  838. if self.use_lab:
  839. self.lab = LearnableAffineBlock(lr_mult=lr_mult)
  840. def forward(self, x):
  841. x = self.conv(x)
  842. x = self.bn(x)
  843. if self.use_act:
  844. x = self.act(x)
  845. if self.use_lab:
  846. x = self.lab(x)
  847. return x
  848. class LightConvBNAct(TheseusLayer):
  849. """
  850. LightConvBNAct is a combination of pw and dw layers.
  851. Args:
  852. in_channels (int): Number of input channels.
  853. out_channels (int): Number of output channels.
  854. kernel_size (int): Size of the depth-wise convolution kernel.
  855. use_lab (bool): Whether to use the LAB operation. Defaults to False.
  856. lr_mult (float): Learning rate multiplier for the layer. Defaults to 1.0.
  857. """
  858. def __init__(
  859. self,
  860. in_channels,
  861. out_channels,
  862. kernel_size,
  863. use_lab=False,
  864. lr_mult=1.0,
  865. **kwargs,
  866. ):
  867. super().__init__()
  868. self.conv1 = ConvBNAct(
  869. in_channels=in_channels,
  870. out_channels=out_channels,
  871. kernel_size=1,
  872. use_act=False,
  873. use_lab=use_lab,
  874. lr_mult=lr_mult,
  875. )
  876. self.conv2 = ConvBNAct(
  877. in_channels=out_channels,
  878. out_channels=out_channels,
  879. kernel_size=kernel_size,
  880. groups=out_channels,
  881. use_act=True,
  882. use_lab=use_lab,
  883. lr_mult=lr_mult,
  884. )
  885. def forward(self, x):
  886. x = self.conv1(x)
  887. x = self.conv2(x)
  888. return x
  889. class StemBlock(TheseusLayer):
  890. """
  891. StemBlock for PP-HGNetV2.
  892. Args:
  893. in_channels (int): Number of input channels.
  894. mid_channels (int): Number of middle channels.
  895. out_channels (int): Number of output channels.
  896. use_lab (bool): Whether to use the LAB operation. Defaults to False.
  897. lr_mult (float): Learning rate multiplier for the layer. Defaults to 1.0.
  898. """
  899. def __init__(
  900. self,
  901. in_channels,
  902. mid_channels,
  903. out_channels,
  904. use_lab=False,
  905. lr_mult=1.0,
  906. text_rec=False,
  907. ):
  908. super().__init__()
  909. self.stem1 = ConvBNAct(
  910. in_channels=in_channels,
  911. out_channels=mid_channels,
  912. kernel_size=3,
  913. stride=2,
  914. use_lab=use_lab,
  915. lr_mult=lr_mult,
  916. )
  917. self.stem2a = ConvBNAct(
  918. in_channels=mid_channels,
  919. out_channels=mid_channels // 2,
  920. kernel_size=2,
  921. stride=1,
  922. padding="SAME",
  923. use_lab=use_lab,
  924. lr_mult=lr_mult,
  925. )
  926. self.stem2b = ConvBNAct(
  927. in_channels=mid_channels // 2,
  928. out_channels=mid_channels,
  929. kernel_size=2,
  930. stride=1,
  931. padding="SAME",
  932. use_lab=use_lab,
  933. lr_mult=lr_mult,
  934. )
  935. self.stem3 = ConvBNAct(
  936. in_channels=mid_channels * 2,
  937. out_channels=mid_channels,
  938. kernel_size=3,
  939. stride=1 if text_rec else 2,
  940. use_lab=use_lab,
  941. lr_mult=lr_mult,
  942. )
  943. self.stem4 = ConvBNAct(
  944. in_channels=mid_channels,
  945. out_channels=out_channels,
  946. kernel_size=1,
  947. stride=1,
  948. use_lab=use_lab,
  949. lr_mult=lr_mult,
  950. )
  951. self.pool = nn.MaxPool2D(
  952. kernel_size=2, stride=1, ceil_mode=True, padding="SAME"
  953. )
  954. def forward(self, x):
  955. x = self.stem1(x)
  956. x2 = self.stem2a(x)
  957. x2 = self.stem2b(x2)
  958. x1 = self.pool(x)
  959. x = paddle.concat([x1, x2], 1)
  960. x = self.stem3(x)
  961. x = self.stem4(x)
  962. return x
  963. class HGV2_Block(TheseusLayer):
  964. """
  965. HGV2_Block, the basic unit that constitutes the HGV2_Stage.
  966. Args:
  967. in_channels (int): Number of input channels.
  968. mid_channels (int): Number of middle channels.
  969. out_channels (int): Number of output channels.
  970. kernel_size (int): Size of the convolution kernel. Defaults to 3.
  971. layer_num (int): Number of layers in the HGV2 block. Defaults to 6.
  972. stride (int): Stride of the convolution. Defaults to 1.
  973. padding (int/str): Padding or padding type for the convolution. Defaults to 1.
  974. groups (int): Number of groups for the convolution. Defaults to 1.
  975. use_act (bool): Whether to use activation function. Defaults to True.
  976. use_lab (bool): Whether to use the LAB operation. Defaults to False.
  977. lr_mult (float): Learning rate multiplier for the layer. Defaults to 1.0.
  978. """
  979. def __init__(
  980. self,
  981. in_channels,
  982. mid_channels,
  983. out_channels,
  984. kernel_size=3,
  985. layer_num=6,
  986. identity=False,
  987. light_block=True,
  988. use_lab=False,
  989. lr_mult=1.0,
  990. ):
  991. super().__init__()
  992. self.identity = identity
  993. self.layers = nn.LayerList()
  994. block_type = "LightConvBNAct" if light_block else "ConvBNAct"
  995. for i in range(layer_num):
  996. self.layers.append(
  997. eval(block_type)(
  998. in_channels=in_channels if i == 0 else mid_channels,
  999. out_channels=mid_channels,
  1000. stride=1,
  1001. kernel_size=kernel_size,
  1002. use_lab=use_lab,
  1003. lr_mult=lr_mult,
  1004. )
  1005. )
  1006. # feature aggregation
  1007. total_channels = in_channels + layer_num * mid_channels
  1008. self.aggregation_squeeze_conv = ConvBNAct(
  1009. in_channels=total_channels,
  1010. out_channels=out_channels // 2,
  1011. kernel_size=1,
  1012. stride=1,
  1013. use_lab=use_lab,
  1014. lr_mult=lr_mult,
  1015. )
  1016. self.aggregation_excitation_conv = ConvBNAct(
  1017. in_channels=out_channels // 2,
  1018. out_channels=out_channels,
  1019. kernel_size=1,
  1020. stride=1,
  1021. use_lab=use_lab,
  1022. lr_mult=lr_mult,
  1023. )
  1024. def forward(self, x):
  1025. identity = x
  1026. output = []
  1027. output.append(x)
  1028. for layer in self.layers:
  1029. x = layer(x)
  1030. output.append(x)
  1031. x = paddle.concat(output, axis=1)
  1032. x = self.aggregation_squeeze_conv(x)
  1033. x = self.aggregation_excitation_conv(x)
  1034. if self.identity:
  1035. x += identity
  1036. return x
  1037. class HGV2_Stage(TheseusLayer):
  1038. """
  1039. HGV2_Stage, the basic unit that constitutes the PPHGNetV2.
  1040. Args:
  1041. in_channels (int): Number of input channels.
  1042. mid_channels (int): Number of middle channels.
  1043. out_channels (int): Number of output channels.
  1044. block_num (int): Number of blocks in the HGV2 stage.
  1045. layer_num (int): Number of layers in the HGV2 block. Defaults to 6.
  1046. is_downsample (bool): Whether to use downsampling operation. Defaults to False.
  1047. light_block (bool): Whether to use light block. Defaults to True.
  1048. kernel_size (int): Size of the convolution kernel. Defaults to 3.
  1049. use_lab (bool, optional): Whether to use the LAB operation. Defaults to False.
  1050. lr_mult (float, optional): Learning rate multiplier for the layer. Defaults to 1.0.
  1051. """
  1052. def __init__(
  1053. self,
  1054. in_channels,
  1055. mid_channels,
  1056. out_channels,
  1057. block_num,
  1058. layer_num=6,
  1059. is_downsample=True,
  1060. light_block=True,
  1061. kernel_size=3,
  1062. use_lab=False,
  1063. stride=2,
  1064. lr_mult=1.0,
  1065. ):
  1066. super().__init__()
  1067. self.is_downsample = is_downsample
  1068. if self.is_downsample:
  1069. self.downsample = ConvBNAct(
  1070. in_channels=in_channels,
  1071. out_channels=in_channels,
  1072. kernel_size=3,
  1073. stride=stride,
  1074. groups=in_channels,
  1075. use_act=False,
  1076. use_lab=use_lab,
  1077. lr_mult=lr_mult,
  1078. )
  1079. blocks_list = []
  1080. for i in range(block_num):
  1081. blocks_list.append(
  1082. HGV2_Block(
  1083. in_channels=in_channels if i == 0 else out_channels,
  1084. mid_channels=mid_channels,
  1085. out_channels=out_channels,
  1086. kernel_size=kernel_size,
  1087. layer_num=layer_num,
  1088. identity=False if i == 0 else True,
  1089. light_block=light_block,
  1090. use_lab=use_lab,
  1091. lr_mult=lr_mult,
  1092. )
  1093. )
  1094. self.blocks = nn.Sequential(*blocks_list)
  1095. def forward(self, x):
  1096. if self.is_downsample:
  1097. x = self.downsample(x)
  1098. x = self.blocks(x)
  1099. return x
  1100. class PPHGNetV2(TheseusLayer):
  1101. """
  1102. PPHGNetV2
  1103. Args:
  1104. stage_config (dict): Config for PPHGNetV2 stages. such as the number of channels, stride, etc.
  1105. stem_channels: (list): Number of channels of the stem of the PPHGNetV2.
  1106. use_lab (bool): Whether to use the LAB operation. Defaults to False.
  1107. use_last_conv (bool): Whether to use the last conv layer as the output channel. Defaults to True.
  1108. class_expand (int): Number of channels for the last 1x1 convolutional layer.
  1109. drop_prob (float): Dropout probability for the last 1x1 convolutional layer. Defaults to 0.0.
  1110. class_num (int): The number of classes for the classification layer. Defaults to 1000.
  1111. lr_mult_list (list): Learning rate multiplier for the stages. Defaults to [1.0, 1.0, 1.0, 1.0, 1.0].
  1112. Returns:
  1113. model: nn.Layer. Specific PPHGNetV2 model depends on args.
  1114. """
  1115. def __init__(
  1116. self,
  1117. stage_config,
  1118. stem_channels=[3, 32, 64],
  1119. use_lab=False,
  1120. use_last_conv=True,
  1121. class_expand=2048,
  1122. dropout_prob=0.0,
  1123. class_num=1000,
  1124. lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
  1125. det=False,
  1126. text_rec=False,
  1127. out_indices=None,
  1128. **kwargs,
  1129. ):
  1130. super().__init__()
  1131. self.det = det
  1132. self.text_rec = text_rec
  1133. self.use_lab = use_lab
  1134. self.use_last_conv = use_last_conv
  1135. self.class_expand = class_expand
  1136. self.class_num = class_num
  1137. self.out_indices = out_indices if out_indices is not None else [0, 1, 2, 3]
  1138. self.out_channels = []
  1139. # stem
  1140. self.stem = StemBlock(
  1141. in_channels=stem_channels[0],
  1142. mid_channels=stem_channels[1],
  1143. out_channels=stem_channels[2],
  1144. use_lab=use_lab,
  1145. lr_mult=lr_mult_list[0],
  1146. text_rec=text_rec,
  1147. )
  1148. # stages
  1149. self.stages = nn.LayerList()
  1150. for i, k in enumerate(stage_config):
  1151. (
  1152. in_channels,
  1153. mid_channels,
  1154. out_channels,
  1155. block_num,
  1156. is_downsample,
  1157. light_block,
  1158. kernel_size,
  1159. layer_num,
  1160. stride,
  1161. ) = stage_config[k]
  1162. self.stages.append(
  1163. HGV2_Stage(
  1164. in_channels,
  1165. mid_channels,
  1166. out_channels,
  1167. block_num,
  1168. layer_num,
  1169. is_downsample,
  1170. light_block,
  1171. kernel_size,
  1172. use_lab,
  1173. stride,
  1174. lr_mult=lr_mult_list[i + 1],
  1175. )
  1176. )
  1177. if i in self.out_indices:
  1178. self.out_channels.append(out_channels)
  1179. if not self.det:
  1180. self.out_channels = stage_config["stage4"][2]
  1181. self.avg_pool = AdaptiveAvgPool2D(1)
  1182. if self.use_last_conv:
  1183. self.last_conv = Conv2D(
  1184. in_channels=out_channels,
  1185. out_channels=self.class_expand,
  1186. kernel_size=1,
  1187. stride=1,
  1188. padding=0,
  1189. bias_attr=False,
  1190. )
  1191. self.act = ReLU()
  1192. if self.use_lab:
  1193. self.lab = LearnableAffineBlock()
  1194. self.dropout = nn.Dropout(p=dropout_prob, mode="downscale_in_infer")
  1195. self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)
  1196. if not self.det:
  1197. self.fc = nn.Linear(
  1198. self.class_expand if self.use_last_conv else out_channels,
  1199. self.class_num,
  1200. )
  1201. self._init_weights()
  1202. def _init_weights(self):
  1203. for m in self.sublayers():
  1204. if isinstance(m, nn.Conv2D):
  1205. kaiming_normal_(m.weight)
  1206. elif isinstance(m, (nn.BatchNorm2D)):
  1207. ones_(m.weight)
  1208. zeros_(m.bias)
  1209. elif isinstance(m, nn.Linear):
  1210. zeros_(m.bias)
  1211. def forward(self, x):
  1212. x = self.stem(x)
  1213. out = []
  1214. for i, stage in enumerate(self.stages):
  1215. x = stage(x)
  1216. if self.det and i in self.out_indices:
  1217. out.append(x)
  1218. if self.det:
  1219. return out
  1220. if self.text_rec:
  1221. if self.training:
  1222. x = F.adaptive_avg_pool2d(x, [1, 40])
  1223. else:
  1224. x = F.avg_pool2d(x, [3, 2])
  1225. return x
  1226. def PPHGNetV2_B0(pretrained=False, use_ssld=False, **kwargs):
  1227. """
  1228. PPHGNetV2_B0
  1229. Args:
  1230. pretrained (bool/str): If `True` load pretrained parameters, `False` otherwise.
  1231. If str, means the path of the pretrained model.
  1232. use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
  1233. Returns:
  1234. model: nn.Layer. Specific `PPHGNetV2_B0` model depends on args.
  1235. """
  1236. stage_config = {
  1237. # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
  1238. "stage1": [16, 16, 64, 1, False, False, 3, 3],
  1239. "stage2": [64, 32, 256, 1, True, False, 3, 3],
  1240. "stage3": [256, 64, 512, 2, True, True, 5, 3],
  1241. "stage4": [512, 128, 1024, 1, True, True, 5, 3],
  1242. }
  1243. model = PPHGNetV2(
  1244. stem_channels=[3, 16, 16], stage_config=stage_config, use_lab=True, **kwargs
  1245. )
  1246. return model
  1247. def PPHGNetV2_B1(pretrained=False, use_ssld=False, **kwargs):
  1248. """
  1249. PPHGNetV2_B1
  1250. Args:
  1251. pretrained (bool/str): If `True` load pretrained parameters, `False` otherwise.
  1252. If str, means the path of the pretrained model.
  1253. use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
  1254. Returns:
  1255. model: nn.Layer. Specific `PPHGNetV2_B1` model depends on args.
  1256. """
  1257. stage_config = {
  1258. # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
  1259. "stage1": [32, 32, 64, 1, False, False, 3, 3],
  1260. "stage2": [64, 48, 256, 1, True, False, 3, 3],
  1261. "stage3": [256, 96, 512, 2, True, True, 5, 3],
  1262. "stage4": [512, 192, 1024, 1, True, True, 5, 3],
  1263. }
  1264. model = PPHGNetV2(
  1265. stem_channels=[3, 24, 32], stage_config=stage_config, use_lab=True, **kwargs
  1266. )
  1267. return model
  1268. def PPHGNetV2_B2(pretrained=False, use_ssld=False, **kwargs):
  1269. """
  1270. PPHGNetV2_B2
  1271. Args:
  1272. pretrained (bool/str): If `True` load pretrained parameters, `False` otherwise.
  1273. If str, means the path of the pretrained model.
  1274. use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
  1275. Returns:
  1276. model: nn.Layer. Specific `PPHGNetV2_B2` model depends on args.
  1277. """
  1278. stage_config = {
  1279. # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
  1280. "stage1": [32, 32, 96, 1, False, False, 3, 4],
  1281. "stage2": [96, 64, 384, 1, True, False, 3, 4],
  1282. "stage3": [384, 128, 768, 3, True, True, 5, 4],
  1283. "stage4": [768, 256, 1536, 1, True, True, 5, 4],
  1284. }
  1285. model = PPHGNetV2(
  1286. stem_channels=[3, 24, 32], stage_config=stage_config, use_lab=True, **kwargs
  1287. )
  1288. return model
  1289. def PPHGNetV2_B3(pretrained=False, use_ssld=False, **kwargs):
  1290. """
  1291. PPHGNetV2_B3
  1292. Args:
  1293. pretrained (bool/str): If `True` load pretrained parameters, `False` otherwise.
  1294. If str, means the path of the pretrained model.
  1295. use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
  1296. Returns:
  1297. model: nn.Layer. Specific `PPHGNetV2_B3` model depends on args.
  1298. """
  1299. stage_config = {
  1300. # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
  1301. "stage1": [32, 32, 128, 1, False, False, 3, 5],
  1302. "stage2": [128, 64, 512, 1, True, False, 3, 5],
  1303. "stage3": [512, 128, 1024, 3, True, True, 5, 5],
  1304. "stage4": [1024, 256, 2048, 1, True, True, 5, 5],
  1305. }
  1306. model = PPHGNetV2(
  1307. stem_channels=[3, 24, 32], stage_config=stage_config, use_lab=True, **kwargs
  1308. )
  1309. return model
  1310. def PPHGNetV2_B4(pretrained=False, use_ssld=False, det=False, text_rec=False, **kwargs):
  1311. """
  1312. PPHGNetV2_B4
  1313. Args:
  1314. pretrained (bool/str): If `True` load pretrained parameters, `False` otherwise.
  1315. If str, means the path of the pretrained model.
  1316. use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
  1317. Returns:
  1318. model: nn.Layer. Specific `PPHGNetV2_B4` model depends on args.
  1319. """
  1320. stage_config_rec = {
  1321. # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num, stride
  1322. "stage1": [48, 48, 128, 1, True, False, 3, 6, [2, 1]],
  1323. "stage2": [128, 96, 512, 1, True, False, 3, 6, [1, 2]],
  1324. "stage3": [512, 192, 1024, 3, True, True, 5, 6, [2, 1]],
  1325. "stage4": [1024, 384, 2048, 1, True, True, 5, 6, [2, 1]],
  1326. }
  1327. stage_config_det = {
  1328. # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
  1329. "stage1": [48, 48, 128, 1, False, False, 3, 6, 2],
  1330. "stage2": [128, 96, 512, 1, True, False, 3, 6, 2],
  1331. "stage3": [512, 192, 1024, 3, True, True, 5, 6, 2],
  1332. "stage4": [1024, 384, 2048, 1, True, True, 5, 6, 2],
  1333. }
  1334. model = PPHGNetV2(
  1335. stem_channels=[3, 32, 48],
  1336. stage_config=stage_config_det if det else stage_config_rec,
  1337. use_lab=False,
  1338. det=det,
  1339. text_rec=text_rec,
  1340. **kwargs,
  1341. )
  1342. return model
  1343. def PPHGNetV2_B5(pretrained=False, use_ssld=False, **kwargs):
  1344. """
  1345. PPHGNetV2_B5
  1346. Args:
  1347. pretrained (bool/str): If `True` load pretrained parameters, `False` otherwise.
  1348. If str, means the path of the pretrained model.
  1349. use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
  1350. Returns:
  1351. model: nn.Layer. Specific `PPHGNetV2_B5` model depends on args.
  1352. """
  1353. stage_config = {
  1354. # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
  1355. "stage1": [64, 64, 128, 1, False, False, 3, 6],
  1356. "stage2": [128, 128, 512, 2, True, False, 3, 6],
  1357. "stage3": [512, 256, 1024, 5, True, True, 5, 6],
  1358. "stage4": [1024, 512, 2048, 2, True, True, 5, 6],
  1359. }
  1360. model = PPHGNetV2(
  1361. stem_channels=[3, 32, 64], stage_config=stage_config, use_lab=False, **kwargs
  1362. )
  1363. return model
  1364. def PPHGNetV2_B6(pretrained=False, use_ssld=False, **kwargs):
  1365. """
  1366. PPHGNetV2_B6
  1367. Args:
  1368. pretrained (bool/str): If `True` load pretrained parameters, `False` otherwise.
  1369. If str, means the path of the pretrained model.
  1370. use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
  1371. Returns:
  1372. model: nn.Layer. Specific `PPHGNetV2_B6` model depends on args.
  1373. """
  1374. stage_config = {
  1375. # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
  1376. "stage1": [96, 96, 192, 2, False, False, 3, 6],
  1377. "stage2": [192, 192, 512, 3, True, False, 3, 6],
  1378. "stage3": [512, 384, 1024, 6, True, True, 5, 6],
  1379. "stage4": [1024, 768, 2048, 3, True, True, 5, 6],
  1380. }
  1381. model = PPHGNetV2(
  1382. stem_channels=[3, 48, 96], stage_config=stage_config, use_lab=False, **kwargs
  1383. )
  1384. return model
  1385. class PPHGNetV2_B4_Formula(nn.Layer):
  1386. """
  1387. PPHGNetV2_B4_Formula
  1388. Args:
  1389. in_channels (int): Number of input channels. Default is 3 (for RGB images).
  1390. class_num (int): Number of classes for classification. Default is 1000.
  1391. Returns:
  1392. model: nn.Layer. Specific `PPHGNetV2_B4` model with defined architecture.
  1393. """
  1394. def __init__(self, in_channels=3, class_num=1000):
  1395. super().__init__()
  1396. self.in_channels = in_channels
  1397. self.out_channels = 2048
  1398. stage_config = {
  1399. # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
  1400. "stage1": [48, 48, 128, 1, False, False, 3, 6, 2],
  1401. "stage2": [128, 96, 512, 1, True, False, 3, 6, 2],
  1402. "stage3": [512, 192, 1024, 3, True, True, 5, 6, 2],
  1403. "stage4": [1024, 384, 2048, 1, True, True, 5, 6, 2],
  1404. }
  1405. self.pphgnet_b4 = PPHGNetV2(
  1406. stem_channels=[3, 32, 48],
  1407. stage_config=stage_config,
  1408. class_num=class_num,
  1409. use_lab=False,
  1410. )
  1411. def forward(self, input_data):
  1412. if self.training:
  1413. pixel_values, label, attention_mask = input_data
  1414. else:
  1415. if isinstance(input_data, list):
  1416. pixel_values = input_data[0]
  1417. else:
  1418. pixel_values = input_data
  1419. num_channels = pixel_values.shape[1]
  1420. if num_channels == 1:
  1421. pixel_values = paddle.repeat_interleave(pixel_values, repeats=3, axis=1)
  1422. pphgnet_b4_output = self.pphgnet_b4(pixel_values)
  1423. b, c, h, w = pphgnet_b4_output.shape
  1424. pphgnet_b4_output = pphgnet_b4_output.reshape([b, c, h * w]).transpose(
  1425. [0, 2, 1]
  1426. )
  1427. pphgnet_b4_output = DonutSwinModelOutput(
  1428. last_hidden_state=pphgnet_b4_output,
  1429. pooler_output=None,
  1430. hidden_states=None,
  1431. attentions=False,
  1432. reshaped_hidden_states=None,
  1433. )
  1434. if self.training:
  1435. return pphgnet_b4_output, label, attention_mask
  1436. else:
  1437. return pphgnet_b4_output
  1438. class PPHGNetV2_B6_Formula(nn.Layer):
  1439. """
  1440. PPHGNetV2_B6_Formula
  1441. Args:
  1442. in_channels (int): Number of input channels. Default is 3 (for RGB images).
  1443. class_num (int): Number of classes for classification. Default is 1000.
  1444. Returns:
  1445. model: nn.Layer. Specific `PPHGNetV2_B6` model with defined architecture.
  1446. """
  1447. def __init__(self, in_channels=3, class_num=1000):
  1448. super().__init__()
  1449. self.in_channels = in_channels
  1450. self.out_channels = 2048
  1451. stage_config = {
  1452. # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
  1453. "stage1": [96, 96, 192, 2, False, False, 3, 6, 2],
  1454. "stage2": [192, 192, 512, 3, True, False, 3, 6, 2],
  1455. "stage3": [512, 384, 1024, 6, True, True, 5, 6, 2],
  1456. "stage4": [1024, 768, 2048, 3, True, True, 5, 6, 2],
  1457. }
  1458. self.pphgnet_b6 = PPHGNetV2(
  1459. stem_channels=[3, 48, 96],
  1460. class_num=class_num,
  1461. stage_config=stage_config,
  1462. use_lab=False,
  1463. )
  1464. def forward(self, input_data):
  1465. if self.training:
  1466. pixel_values, label, attention_mask = input_data
  1467. else:
  1468. if isinstance(input_data, list):
  1469. pixel_values = input_data[0]
  1470. else:
  1471. pixel_values = input_data
  1472. num_channels = pixel_values.shape[1]
  1473. if num_channels == 1:
  1474. pixel_values = paddle.repeat_interleave(pixel_values, repeats=3, axis=1)
  1475. pphgnet_b6_output = self.pphgnet_b6(pixel_values)
  1476. b, c, h, w = pphgnet_b6_output.shape
  1477. pphgnet_b6_output = pphgnet_b6_output.reshape([b, c, h * w]).transpose(
  1478. [0, 2, 1]
  1479. )
  1480. pphgnet_b6_output = DonutSwinModelOutput(
  1481. last_hidden_state=pphgnet_b6_output,
  1482. pooler_output=None,
  1483. hidden_states=None,
  1484. attentions=False,
  1485. reshaped_hidden_states=None,
  1486. )
  1487. if self.training:
  1488. return pphgnet_b6_output, label, attention_mask
  1489. else:
  1490. return pphgnet_b6_output