rec_efficientb3_pren.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Code is adapted from:
  16. https://github.com/RuijieJ/pren/blob/main/Nets/EfficientNet.py
  17. """
  18. from __future__ import absolute_import
  19. from __future__ import division
  20. from __future__ import print_function
  21. import math
  22. import re
  23. import collections
  24. import paddle
  25. import paddle.nn as nn
  26. import paddle.nn.functional as F
  # Public API: `from <this module> import *` exposes only the backbone class.
  27. __all__ = ["EfficientNetb3_PREN"]
  # Network-wide hyper-parameters (compound-scaling coefficients, rounding
  # divisor, regularisation rates, ...) shared by every block of the backbone.
  GlobalParams = collections.namedtuple(
      typename="GlobalParams",
      field_names=(
          "batch_norm_momentum, batch_norm_epsilon, dropout_rate, num_classes, "
          "width_coefficient, depth_coefficient, depth_divisor, min_depth, "
          "drop_connect_rate, image_size"
      ),
  )
  # Per-stage MBConv configuration. Instances are produced by
  # BlockDecoder._decode_block_string; `stride` is stored there as a
  # one-element list, while later repeats of a stage use a plain int.
  BlockArgs = collections.namedtuple(
      typename="BlockArgs",
      field_names=(
          "kernel_size, num_repeat, input_filters, output_filters, "
          "expand_ratio, id_skip, stride, se_ratio"
      ),
  )
  class BlockDecoder:
      """Decode EfficientNet block strings into ``BlockArgs`` tuples.

      A block string is a ``_``-separated list of tokens such as
      ``"r1_k3_s11_e1_i32_o16_se0.25"``. Each token is a key followed by a
      value that starts at its first digit:

      - ``r``: repeats
      - ``k``: kernel size
      - ``s``: stride (one digit, or two equal digits)
      - ``e``: expansion ratio
      - ``i`` / ``o``: input / output filters
      - ``se``: optional squeeze-and-excitation ratio

      A ``noskip`` token disables the identity skip.
      """

      @staticmethod
      def _decode_block_string(block_string):
          """Parse a single block string.

          Returns:
              A ``BlockArgs`` whose ``stride`` is a one-element list.

          Raises:
              TypeError: if ``block_string`` is not a ``str``.
              ValueError: if a required key is missing or the stride token is
                  malformed.
          """
          if not isinstance(block_string, str):
              raise TypeError(
                  f"block_string must be a str, got {type(block_string).__name__}"
              )
          options = {}
          for op in block_string.split("_"):
              # "k3" -> ["k", "3", ""]. Tokens without a digit (e.g. "noskip")
              # yield a single element and carry no key/value pair.
              splits = re.split(r"(\d.*)", op)
              if len(splits) >= 2:
                  key, value = splits[:2]
                  options[key] = value
          # Explicit checks (not `assert`) so validation survives `python -O`
          # and a missing key is reported as a clear ValueError.
          missing = [key for key in ("r", "k", "s", "e", "i", "o") if key not in options]
          if missing:
              raise ValueError(
                  f"block string {block_string!r} is missing keys: {missing}"
              )
          stride = options["s"]
          if not (len(stride) == 1 or (len(stride) == 2 and stride[0] == stride[1])):
              raise ValueError(
                  f"block string {block_string!r} has invalid stride {stride!r}"
              )
          return BlockArgs(
              kernel_size=int(options["k"]),
              num_repeat=int(options["r"]),
              input_filters=int(options["i"]),
              output_filters=int(options["o"]),
              expand_ratio=int(options["e"]),
              id_skip=("noskip" not in block_string),
              se_ratio=float(options["se"]) if "se" in options else None,
              stride=[int(stride[0])],
          )

      @staticmethod
      def decode(string_list):
          """Decode a list (or tuple) of block strings into a list of ``BlockArgs``.

          Raises:
              TypeError: if ``string_list`` is not a list or tuple.
          """
          if not isinstance(string_list, (list, tuple)):
              raise TypeError(
                  "string_list must be a list or tuple, got "
                  f"{type(string_list).__name__}"
              )
          return [BlockDecoder._decode_block_string(s) for s in string_list]
  def efficientnet(
      width_coefficient=None,
      depth_coefficient=None,
      dropout_rate=0.2,
      drop_connect_rate=0.2,
      image_size=None,
      num_classes=1000,
  ):
      """Return the unscaled stage specifications and the global parameters.

      Each stage string is decoded by ``BlockDecoder``: ``r`` repeats, ``k``
      kernel size, ``s`` stride, ``e`` expansion ratio, ``i``/``o`` input and
      output filters, ``se`` squeeze-and-excitation ratio. No width/depth
      scaling is applied here. The coefficients are only stored in the
      returned ``GlobalParams`` for the caller to apply.

      Returns:
          Tuple ``(blocks_args, global_params)``: a list of ``BlockArgs`` (one
          per stage) and a ``GlobalParams`` instance.
      """
      stage_specs = [
          "r1_k3_s11_e1_i32_o16_se0.25",
          "r2_k3_s22_e6_i16_o24_se0.25",
          "r2_k5_s22_e6_i24_o40_se0.25",
          "r3_k3_s22_e6_i40_o80_se0.25",
          "r3_k5_s11_e6_i80_o112_se0.25",
          "r4_k5_s22_e6_i112_o192_se0.25",
          "r1_k3_s11_e6_i192_o320_se0.25",
      ]
      params = GlobalParams(
          width_coefficient=width_coefficient,
          depth_coefficient=depth_coefficient,
          image_size=image_size,
          dropout_rate=dropout_rate,
          drop_connect_rate=drop_connect_rate,
          num_classes=num_classes,
          batch_norm_momentum=0.99,
          batch_norm_epsilon=1e-3,
          depth_divisor=8,
          min_depth=None,
      )
      return BlockDecoder.decode(stage_specs), params
  class EffUtils:
      """Static helpers that apply compound-scaling coefficients to block sizes."""

      @staticmethod
      def round_filters(filters, global_params):
          """Scale a channel count by the width coefficient and round it.

          Steps applied to the scaled count:

          1. Round half-up to a multiple of ``global_params.depth_divisor``.
          2. Clamp from below by ``global_params.min_depth``, or by the
             divisor when ``min_depth`` is unset.
          3. Add one divisor if rounding dropped it below 90% of the scaled
             value.

          A falsy ``width_coefficient`` disables scaling and returns
          ``filters`` unchanged.
          """
          width_mult = global_params.width_coefficient
          if not width_mult:
              return filters
          step = global_params.depth_divisor
          lower_bound = global_params.min_depth or step
          scaled = filters * width_mult
          snapped = int(scaled + step / 2) // step * step
          snapped = max(lower_bound, snapped)
          # never let rounding shrink the layer by more than 10%
          if snapped < 0.9 * scaled:
              snapped += step
          return int(snapped)

      @staticmethod
      def round_repeats(repeats, global_params):
          """Scale a stage's repeat count by the depth coefficient, rounding up.

          A falsy ``depth_coefficient`` disables scaling and returns
          ``repeats`` unchanged.
          """
          depth_mult = global_params.depth_coefficient
          return int(math.ceil(depth_mult * repeats)) if depth_mult else repeats
  class MbConvBlock(nn.Layer):
      """Mobile inverted-bottleneck conv block with optional squeeze-and-excitation.

      Stages:

      - optional 1x1 expansion conv (+BN, swish) when ``expand_ratio`` != 1;
      - depthwise ``kernel_size`` conv with ``stride`` (+BN, swish);
      - optional SE gate when ``0 < se_ratio <= 1``;
      - 1x1 projection conv (+BN, no activation).

      An identity shortcut (with optional drop-connect) is added when
      ``id_skip`` is set, the stride is 1 and the input and output channel
      counts match.

      Args:
          block_args: ``BlockArgs`` describing the block. ``stride`` may be an
              int or a one-element list (as produced by ``BlockDecoder``).
              Only the first element of a list is used.
      """

      def __init__(self, block_args):
          super(MbConvBlock, self).__init__()
          self._block_args = block_args
          self.has_se = (self._block_args.se_ratio is not None) and (
              0 < self._block_args.se_ratio <= 1
          )
          self.id_skip = block_args.id_skip
          # expansion phase
          self.inp = self._block_args.input_filters
          oup = self._block_args.input_filters * self._block_args.expand_ratio
          if self._block_args.expand_ratio != 1:
              self._expand_conv = nn.Conv2D(self.inp, oup, 1, bias_attr=False)
              self._bn0 = nn.BatchNorm(oup)
          # depthwise conv phase
          k = self._block_args.kernel_size
          s = self._block_args.stride
          if isinstance(s, (list, tuple)):
              s = s[0]
          # Scalar stride, reused by forward() for the skip-connection test.
          # Comparing the raw block_args.stride (a list for the first block of
          # each stage) against 1 would always be False.
          self._dw_stride = s
          self._depthwise_conv = nn.Conv2D(
              oup,
              oup,
              groups=oup,
              kernel_size=k,
              stride=s,
              padding="same",
              bias_attr=False,
          )
          self._bn1 = nn.BatchNorm(oup)
          # squeeze and excitation layer, if desired
          if self.has_se:
              num_squeezed_channels = max(
                  1, int(self._block_args.input_filters * self._block_args.se_ratio)
              )
              self._se_reduce = nn.Conv2D(oup, num_squeezed_channels, 1)
              self._se_expand = nn.Conv2D(num_squeezed_channels, oup, 1)
          # output (projection) phase
          self.final_oup = self._block_args.output_filters
          self._project_conv = nn.Conv2D(oup, self.final_oup, 1, bias_attr=False)
          self._bn2 = nn.BatchNorm(self.final_oup)
          self._swish = nn.Swish()

      def _drop_connect(self, inputs, p, training):
          """Zero whole samples of ``inputs`` with probability ``p`` while training.

          Surviving samples are rescaled by ``1 / (1 - p)``. Outside training
          the input is returned unchanged. The mask shape
          ``[batch, 1, 1, 1]`` assumes a 4-D input (TODO confirm layout).
          """
          if not training:
              return inputs
          batch_size = inputs.shape[0]
          keep_prob = 1 - p
          # keep_prob + U[0, 1) floors to 1 with probability keep_prob, else 0.
          random_tensor = keep_prob + paddle.rand(
              [batch_size, 1, 1, 1], dtype=inputs.dtype
          )
          random_tensor = paddle.to_tensor(random_tensor, place=inputs.place)
          binary_tensor = paddle.floor(random_tensor)
          return inputs / keep_prob * binary_tensor

      def forward(self, inputs, drop_connect_rate=None):
          """Apply the block to ``inputs``.

          Args:
              inputs: input feature map with ``input_filters`` channels.
              drop_connect_rate: drop-connect probability for the residual
                  branch. Falsy disables drop-connect.

          Returns:
              Feature map with ``output_filters`` channels.
          """
          # expansion and depthwise conv
          x = inputs
          if self._block_args.expand_ratio != 1:
              x = self._swish(self._bn0(self._expand_conv(inputs)))
          x = self._swish(self._bn1(self._depthwise_conv(x)))
          # squeeze and excitation: per-channel sigmoid gate
          if self.has_se:
              x_squeezed = F.adaptive_avg_pool2d(x, 1)
              x_squeezed = self._se_expand(self._swish(self._se_reduce(x_squeezed)))
              x = F.sigmoid(x_squeezed) * x
          x = self._bn2(self._project_conv(x))
          # skip connection and drop connect
          if self.id_skip and self._dw_stride == 1 and self.inp == self.final_oup:
              if drop_connect_rate:
                  x = self._drop_connect(x, p=drop_connect_rate, training=self.training)
              x = x + inputs
          return x
  class EfficientNetb3_PREN(nn.Layer):
      """EfficientNet-B3 style backbone that returns three intermediate feature maps.

      Args:
          in_channels: number of channels of the input tensor.

      Attributes:
          out_channels: channel counts of the feature maps returned by
              ``forward``, in the same order.
      """

      def __init__(self, in_channels):
          super(EfficientNetb3_PREN, self).__init__()
          # EfficientNet-B3 width, depth, resolution and dropout coefficients.
          # For text recognition the resolution is reduced from 300 to 64.
          w, d, s, p = 1.2, 1.4, 64, 0.3
          self._blocks_args, self._global_params = efficientnet(
              width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s
          )
          # stem
          out_channels = EffUtils.round_filters(32, self._global_params)
          self._conv_stem = nn.Conv2D(
              in_channels, out_channels, 3, 2, padding="same", bias_attr=False
          )
          self._bn0 = nn.BatchNorm(out_channels)
          # Build the MBConv blocks. Sublayer names "{stage}-{repeat}" are kept
          # as before so parameter names (and weights saved under them) do not
          # change.
          self._blocks = []
          for i, block_args in enumerate(self._blocks_args):
              block_args = self._scale_block_args(block_args)
              self._blocks.append(self.add_sublayer(f"{i}-0", MbConvBlock(block_args)))
              # Later repeats of a stage keep the stage's output width and use
              # stride 1. With num_repeat == 1 the loop below does not run.
              block_args = block_args._replace(
                  input_filters=block_args.output_filters, stride=1
              )
              for j in range(1, block_args.num_repeat):
                  self._blocks.append(
                      self.add_sublayer(f"{i}-{j}", MbConvBlock(block_args))
                  )
          # 0-based positions in self._blocks of the blocks whose outputs are
          # returned (three feature maps for an FPN). forward() and
          # out_channels use the same indices, so the reported channel counts
          # always describe the returned maps.
          self._concerned_block_idxes = [7, 17, 25]
          self.out_channels = [
              self._blocks[idx].final_oup for idx in self._concerned_block_idxes
          ]
          self._swish = nn.Swish()

      def _scale_block_args(self, block_args):
          """Apply width (filters) and depth (repeats) scaling to one stage spec."""
          return block_args._replace(
              input_filters=EffUtils.round_filters(
                  block_args.input_filters, self._global_params
              ),
              output_filters=EffUtils.round_filters(
                  block_args.output_filters, self._global_params
              ),
              num_repeat=EffUtils.round_repeats(
                  block_args.num_repeat, self._global_params
              ),
          )

      def forward(self, inputs):
          """Run the stem and all blocks; return the selected feature maps.

          Returns:
              List of tensors, one per index in ``_concerned_block_idxes``,
              with channel counts ``self.out_channels``.
          """
          outs = []
          x = self._swish(self._bn0(self._conv_stem(inputs)))
          num_blocks = len(self._blocks)
          for idx, block in enumerate(self._blocks):
              drop_connect_rate = self._global_params.drop_connect_rate
              if drop_connect_rate:
                  # drop-connect rate grows linearly with block depth
                  drop_connect_rate *= float(idx) / num_blocks
              x = block(x, drop_connect_rate=drop_connect_rate)
              if idx in self._concerned_block_idxes:
                  outs.append(x)
          return outs