rec_resnet_rfl.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/hikopensource/DAVAR-Lab-OCR/blob/main/davarocr/davar_rcg/models/backbones/ResNetRFL.py
  17. """
  18. from __future__ import absolute_import
  19. from __future__ import division
  20. from __future__ import print_function
  21. import paddle
  22. import paddle.nn as nn
  23. from paddle.nn.initializer import TruncatedNormal, Constant, Normal, KaimingNormal
# Module-level initializer instances. None of them is referenced by the
# classes visible in this file; presumably kept for external importers or
# parity with sibling backbones -- TODO confirm before removing.
kaiming_init_ = KaimingNormal()  # Kaiming-normal initializer, default arguments
zeros_ = Constant(value=0.0)  # constant initializer filling with 0.0
ones_ = Constant(value=1.0)  # constant initializer filling with 1.0
  27. class BasicBlock(nn.Layer):
  28. """Res-net Basic Block"""
  29. expansion = 1
  30. def __init__(
  31. self, inplanes, planes, stride=1, downsample=None, norm_type="BN", **kwargs
  32. ):
  33. """
  34. Args:
  35. inplanes (int): input channel
  36. planes (int): channels of the middle feature
  37. stride (int): stride of the convolution
  38. downsample (int): type of the down_sample
  39. norm_type (str): type of the normalization
  40. **kwargs (None): backup parameter
  41. """
  42. super(BasicBlock, self).__init__()
  43. self.conv1 = self._conv3x3(inplanes, planes)
  44. self.bn1 = nn.BatchNorm(planes)
  45. self.conv2 = self._conv3x3(planes, planes)
  46. self.bn2 = nn.BatchNorm(planes)
  47. self.relu = nn.ReLU()
  48. self.downsample = downsample
  49. self.stride = stride
  50. def _conv3x3(self, in_planes, out_planes, stride=1):
  51. return nn.Conv2D(
  52. in_planes,
  53. out_planes,
  54. kernel_size=3,
  55. stride=stride,
  56. padding=1,
  57. bias_attr=False,
  58. )
  59. def forward(self, x):
  60. residual = x
  61. out = self.conv1(x)
  62. out = self.bn1(out)
  63. out = self.relu(out)
  64. out = self.conv2(out)
  65. out = self.bn2(out)
  66. if self.downsample is not None:
  67. residual = self.downsample(x)
  68. out += residual
  69. out = self.relu(out)
  70. return out
  71. class ResNetRFL(nn.Layer):
  72. def __init__(self, in_channels, out_channels=512, use_cnt=True, use_seq=True):
  73. """
  74. Args:
  75. in_channels (int): input channel
  76. out_channels (int): output channel
  77. """
  78. super(ResNetRFL, self).__init__()
  79. assert use_cnt or use_seq
  80. self.use_cnt, self.use_seq = use_cnt, use_seq
  81. self.backbone = RFLBase(in_channels)
  82. self.out_channels = out_channels
  83. self.out_channels_block = [
  84. int(self.out_channels / 4),
  85. int(self.out_channels / 2),
  86. self.out_channels,
  87. self.out_channels,
  88. ]
  89. block = BasicBlock
  90. layers = [1, 2, 5, 3]
  91. self.inplanes = int(self.out_channels // 2)
  92. self.relu = nn.ReLU()
  93. if self.use_seq:
  94. self.maxpool3 = nn.MaxPool2D(kernel_size=2, stride=(2, 1), padding=(0, 1))
  95. self.layer3 = self._make_layer(
  96. block, self.out_channels_block[2], layers[2], stride=1
  97. )
  98. self.conv3 = nn.Conv2D(
  99. self.out_channels_block[2],
  100. self.out_channels_block[2],
  101. kernel_size=3,
  102. stride=1,
  103. padding=1,
  104. bias_attr=False,
  105. )
  106. self.bn3 = nn.BatchNorm(self.out_channels_block[2])
  107. self.layer4 = self._make_layer(
  108. block, self.out_channels_block[3], layers[3], stride=1
  109. )
  110. self.conv4_1 = nn.Conv2D(
  111. self.out_channels_block[3],
  112. self.out_channels_block[3],
  113. kernel_size=2,
  114. stride=(2, 1),
  115. padding=(0, 1),
  116. bias_attr=False,
  117. )
  118. self.bn4_1 = nn.BatchNorm(self.out_channels_block[3])
  119. self.conv4_2 = nn.Conv2D(
  120. self.out_channels_block[3],
  121. self.out_channels_block[3],
  122. kernel_size=2,
  123. stride=1,
  124. padding=0,
  125. bias_attr=False,
  126. )
  127. self.bn4_2 = nn.BatchNorm(self.out_channels_block[3])
  128. if self.use_cnt:
  129. self.inplanes = int(self.out_channels // 2)
  130. self.v_maxpool3 = nn.MaxPool2D(kernel_size=2, stride=(2, 1), padding=(0, 1))
  131. self.v_layer3 = self._make_layer(
  132. block, self.out_channels_block[2], layers[2], stride=1
  133. )
  134. self.v_conv3 = nn.Conv2D(
  135. self.out_channels_block[2],
  136. self.out_channels_block[2],
  137. kernel_size=3,
  138. stride=1,
  139. padding=1,
  140. bias_attr=False,
  141. )
  142. self.v_bn3 = nn.BatchNorm(self.out_channels_block[2])
  143. self.v_layer4 = self._make_layer(
  144. block, self.out_channels_block[3], layers[3], stride=1
  145. )
  146. self.v_conv4_1 = nn.Conv2D(
  147. self.out_channels_block[3],
  148. self.out_channels_block[3],
  149. kernel_size=2,
  150. stride=(2, 1),
  151. padding=(0, 1),
  152. bias_attr=False,
  153. )
  154. self.v_bn4_1 = nn.BatchNorm(self.out_channels_block[3])
  155. self.v_conv4_2 = nn.Conv2D(
  156. self.out_channels_block[3],
  157. self.out_channels_block[3],
  158. kernel_size=2,
  159. stride=1,
  160. padding=0,
  161. bias_attr=False,
  162. )
  163. self.v_bn4_2 = nn.BatchNorm(self.out_channels_block[3])
  164. def _make_layer(self, block, planes, blocks, stride=1):
  165. downsample = None
  166. if stride != 1 or self.inplanes != planes * block.expansion:
  167. downsample = nn.Sequential(
  168. nn.Conv2D(
  169. self.inplanes,
  170. planes * block.expansion,
  171. kernel_size=1,
  172. stride=stride,
  173. bias_attr=False,
  174. ),
  175. nn.BatchNorm(planes * block.expansion),
  176. )
  177. layers = list()
  178. layers.append(block(self.inplanes, planes, stride, downsample))
  179. self.inplanes = planes * block.expansion
  180. for _ in range(1, blocks):
  181. layers.append(block(self.inplanes, planes))
  182. return nn.Sequential(*layers)
  183. def forward(self, inputs):
  184. x_1 = self.backbone(inputs)
  185. if self.use_cnt:
  186. v_x = self.v_maxpool3(x_1)
  187. v_x = self.v_layer3(v_x)
  188. v_x = self.v_conv3(v_x)
  189. v_x = self.v_bn3(v_x)
  190. visual_feature_2 = self.relu(v_x)
  191. v_x = self.v_layer4(visual_feature_2)
  192. v_x = self.v_conv4_1(v_x)
  193. v_x = self.v_bn4_1(v_x)
  194. v_x = self.relu(v_x)
  195. v_x = self.v_conv4_2(v_x)
  196. v_x = self.v_bn4_2(v_x)
  197. visual_feature_3 = self.relu(v_x)
  198. else:
  199. visual_feature_3 = None
  200. if self.use_seq:
  201. x = self.maxpool3(x_1)
  202. x = self.layer3(x)
  203. x = self.conv3(x)
  204. x = self.bn3(x)
  205. x_2 = self.relu(x)
  206. x = self.layer4(x_2)
  207. x = self.conv4_1(x)
  208. x = self.bn4_1(x)
  209. x = self.relu(x)
  210. x = self.conv4_2(x)
  211. x = self.bn4_2(x)
  212. x_3 = self.relu(x)
  213. else:
  214. x_3 = None
  215. return [visual_feature_3, x_3]
  216. class ResNetBase(nn.Layer):
  217. def __init__(self, in_channels, out_channels, block, layers):
  218. super(ResNetBase, self).__init__()
  219. self.out_channels_block = [
  220. int(out_channels / 4),
  221. int(out_channels / 2),
  222. out_channels,
  223. out_channels,
  224. ]
  225. self.inplanes = int(out_channels / 8)
  226. self.conv0_1 = nn.Conv2D(
  227. in_channels,
  228. int(out_channels / 16),
  229. kernel_size=3,
  230. stride=1,
  231. padding=1,
  232. bias_attr=False,
  233. )
  234. self.bn0_1 = nn.BatchNorm(int(out_channels / 16))
  235. self.conv0_2 = nn.Conv2D(
  236. int(out_channels / 16),
  237. self.inplanes,
  238. kernel_size=3,
  239. stride=1,
  240. padding=1,
  241. bias_attr=False,
  242. )
  243. self.bn0_2 = nn.BatchNorm(self.inplanes)
  244. self.relu = nn.ReLU()
  245. self.maxpool1 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
  246. self.layer1 = self._make_layer(block, self.out_channels_block[0], layers[0])
  247. self.conv1 = nn.Conv2D(
  248. self.out_channels_block[0],
  249. self.out_channels_block[0],
  250. kernel_size=3,
  251. stride=1,
  252. padding=1,
  253. bias_attr=False,
  254. )
  255. self.bn1 = nn.BatchNorm(self.out_channels_block[0])
  256. self.maxpool2 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
  257. self.layer2 = self._make_layer(
  258. block, self.out_channels_block[1], layers[1], stride=1
  259. )
  260. self.conv2 = nn.Conv2D(
  261. self.out_channels_block[1],
  262. self.out_channels_block[1],
  263. kernel_size=3,
  264. stride=1,
  265. padding=1,
  266. bias_attr=False,
  267. )
  268. self.bn2 = nn.BatchNorm(self.out_channels_block[1])
  269. def _make_layer(self, block, planes, blocks, stride=1):
  270. downsample = None
  271. if stride != 1 or self.inplanes != planes * block.expansion:
  272. downsample = nn.Sequential(
  273. nn.Conv2D(
  274. self.inplanes,
  275. planes * block.expansion,
  276. kernel_size=1,
  277. stride=stride,
  278. bias_attr=False,
  279. ),
  280. nn.BatchNorm(planes * block.expansion),
  281. )
  282. layers = list()
  283. layers.append(block(self.inplanes, planes, stride, downsample))
  284. self.inplanes = planes * block.expansion
  285. for _ in range(1, blocks):
  286. layers.append(block(self.inplanes, planes))
  287. return nn.Sequential(*layers)
  288. def forward(self, x):
  289. x = self.conv0_1(x)
  290. x = self.bn0_1(x)
  291. x = self.relu(x)
  292. x = self.conv0_2(x)
  293. x = self.bn0_2(x)
  294. x = self.relu(x)
  295. x = self.maxpool1(x)
  296. x = self.layer1(x)
  297. x = self.conv1(x)
  298. x = self.bn1(x)
  299. x = self.relu(x)
  300. x = self.maxpool2(x)
  301. x = self.layer2(x)
  302. x = self.conv2(x)
  303. x = self.bn2(x)
  304. x = self.relu(x)
  305. return x
  306. class RFLBase(nn.Layer):
  307. """Reciprocal feature learning share backbone network"""
  308. def __init__(self, in_channels, out_channels=512):
  309. super(RFLBase, self).__init__()
  310. self.ConvNet = ResNetBase(in_channels, out_channels, BasicBlock, [1, 2, 5, 3])
  311. def forward(self, inputs):
  312. return self.ConvNet(inputs)