east_fpn.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import paddle
  18. from paddle import nn
  19. import paddle.nn.functional as F
  20. from paddle import ParamAttr
  21. class ConvBNLayer(nn.Layer):
  22. def __init__(
  23. self,
  24. in_channels,
  25. out_channels,
  26. kernel_size,
  27. stride,
  28. padding,
  29. groups=1,
  30. if_act=True,
  31. act=None,
  32. name=None,
  33. ):
  34. super(ConvBNLayer, self).__init__()
  35. self.if_act = if_act
  36. self.act = act
  37. self.conv = nn.Conv2D(
  38. in_channels=in_channels,
  39. out_channels=out_channels,
  40. kernel_size=kernel_size,
  41. stride=stride,
  42. padding=padding,
  43. groups=groups,
  44. weight_attr=ParamAttr(name=name + "_weights"),
  45. bias_attr=False,
  46. )
  47. self.bn = nn.BatchNorm(
  48. num_channels=out_channels,
  49. act=act,
  50. param_attr=ParamAttr(name="bn_" + name + "_scale"),
  51. bias_attr=ParamAttr(name="bn_" + name + "_offset"),
  52. moving_mean_name="bn_" + name + "_mean",
  53. moving_variance_name="bn_" + name + "_variance",
  54. )
  55. def forward(self, x):
  56. x = self.conv(x)
  57. x = self.bn(x)
  58. return x
  59. class DeConvBNLayer(nn.Layer):
  60. def __init__(
  61. self,
  62. in_channels,
  63. out_channels,
  64. kernel_size,
  65. stride,
  66. padding,
  67. groups=1,
  68. if_act=True,
  69. act=None,
  70. name=None,
  71. ):
  72. super(DeConvBNLayer, self).__init__()
  73. self.if_act = if_act
  74. self.act = act
  75. self.deconv = nn.Conv2DTranspose(
  76. in_channels=in_channels,
  77. out_channels=out_channels,
  78. kernel_size=kernel_size,
  79. stride=stride,
  80. padding=padding,
  81. groups=groups,
  82. weight_attr=ParamAttr(name=name + "_weights"),
  83. bias_attr=False,
  84. )
  85. self.bn = nn.BatchNorm(
  86. num_channels=out_channels,
  87. act=act,
  88. param_attr=ParamAttr(name="bn_" + name + "_scale"),
  89. bias_attr=ParamAttr(name="bn_" + name + "_offset"),
  90. moving_mean_name="bn_" + name + "_mean",
  91. moving_variance_name="bn_" + name + "_variance",
  92. )
  93. def forward(self, x):
  94. x = self.deconv(x)
  95. x = self.bn(x)
  96. return x
  97. class EASTFPN(nn.Layer):
  98. def __init__(self, in_channels, model_name, **kwargs):
  99. super(EASTFPN, self).__init__()
  100. self.model_name = model_name
  101. if self.model_name == "large":
  102. self.out_channels = 128
  103. else:
  104. self.out_channels = 64
  105. self.in_channels = in_channels[::-1]
  106. self.h1_conv = ConvBNLayer(
  107. in_channels=self.out_channels + self.in_channels[1],
  108. out_channels=self.out_channels,
  109. kernel_size=3,
  110. stride=1,
  111. padding=1,
  112. if_act=True,
  113. act="relu",
  114. name="unet_h_1",
  115. )
  116. self.h2_conv = ConvBNLayer(
  117. in_channels=self.out_channels + self.in_channels[2],
  118. out_channels=self.out_channels,
  119. kernel_size=3,
  120. stride=1,
  121. padding=1,
  122. if_act=True,
  123. act="relu",
  124. name="unet_h_2",
  125. )
  126. self.h3_conv = ConvBNLayer(
  127. in_channels=self.out_channels + self.in_channels[3],
  128. out_channels=self.out_channels,
  129. kernel_size=3,
  130. stride=1,
  131. padding=1,
  132. if_act=True,
  133. act="relu",
  134. name="unet_h_3",
  135. )
  136. self.g0_deconv = DeConvBNLayer(
  137. in_channels=self.in_channels[0],
  138. out_channels=self.out_channels,
  139. kernel_size=4,
  140. stride=2,
  141. padding=1,
  142. if_act=True,
  143. act="relu",
  144. name="unet_g_0",
  145. )
  146. self.g1_deconv = DeConvBNLayer(
  147. in_channels=self.out_channels,
  148. out_channels=self.out_channels,
  149. kernel_size=4,
  150. stride=2,
  151. padding=1,
  152. if_act=True,
  153. act="relu",
  154. name="unet_g_1",
  155. )
  156. self.g2_deconv = DeConvBNLayer(
  157. in_channels=self.out_channels,
  158. out_channels=self.out_channels,
  159. kernel_size=4,
  160. stride=2,
  161. padding=1,
  162. if_act=True,
  163. act="relu",
  164. name="unet_g_2",
  165. )
  166. self.g3_conv = ConvBNLayer(
  167. in_channels=self.out_channels,
  168. out_channels=self.out_channels,
  169. kernel_size=3,
  170. stride=1,
  171. padding=1,
  172. if_act=True,
  173. act="relu",
  174. name="unet_g_3",
  175. )
  176. def forward(self, x):
  177. f = x[::-1]
  178. h = f[0]
  179. g = self.g0_deconv(h)
  180. h = paddle.concat([g, f[1]], axis=1)
  181. h = self.h1_conv(h)
  182. g = self.g1_deconv(h)
  183. h = paddle.concat([g, f[2]], axis=1)
  184. h = self.h2_conv(h)
  185. g = self.g2_deconv(h)
  186. h = paddle.concat([g, f[3]], axis=1)
  187. h = self.h3_conv(h)
  188. g = self.g3_conv(h)
  189. return g