pg_fpn.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import paddle
  18. from paddle import nn
  19. import paddle.nn.functional as F
  20. from paddle import ParamAttr
  21. class ConvBNLayer(nn.Layer):
  22. def __init__(
  23. self,
  24. in_channels,
  25. out_channels,
  26. kernel_size,
  27. stride=1,
  28. groups=1,
  29. is_vd_mode=False,
  30. act=None,
  31. name=None,
  32. ):
  33. super(ConvBNLayer, self).__init__()
  34. self.is_vd_mode = is_vd_mode
  35. self._pool2d_avg = nn.AvgPool2D(
  36. kernel_size=2, stride=2, padding=0, ceil_mode=True
  37. )
  38. self._conv = nn.Conv2D(
  39. in_channels=in_channels,
  40. out_channels=out_channels,
  41. kernel_size=kernel_size,
  42. stride=stride,
  43. padding=(kernel_size - 1) // 2,
  44. groups=groups,
  45. weight_attr=ParamAttr(name=name + "_weights"),
  46. bias_attr=False,
  47. )
  48. if name == "conv1":
  49. bn_name = "bn_" + name
  50. else:
  51. bn_name = "bn" + name[3:]
  52. self._batch_norm = nn.BatchNorm(
  53. out_channels,
  54. act=act,
  55. param_attr=ParamAttr(name=bn_name + "_scale"),
  56. bias_attr=ParamAttr(bn_name + "_offset"),
  57. moving_mean_name=bn_name + "_mean",
  58. moving_variance_name=bn_name + "_variance",
  59. use_global_stats=False,
  60. )
  61. def forward(self, inputs):
  62. y = self._conv(inputs)
  63. y = self._batch_norm(y)
  64. return y
  65. class DeConvBNLayer(nn.Layer):
  66. def __init__(
  67. self,
  68. in_channels,
  69. out_channels,
  70. kernel_size=4,
  71. stride=2,
  72. padding=1,
  73. groups=1,
  74. if_act=True,
  75. act=None,
  76. name=None,
  77. ):
  78. super(DeConvBNLayer, self).__init__()
  79. self.if_act = if_act
  80. self.act = act
  81. self.deconv = nn.Conv2DTranspose(
  82. in_channels=in_channels,
  83. out_channels=out_channels,
  84. kernel_size=kernel_size,
  85. stride=stride,
  86. padding=padding,
  87. groups=groups,
  88. weight_attr=ParamAttr(name=name + "_weights"),
  89. bias_attr=False,
  90. )
  91. self.bn = nn.BatchNorm(
  92. num_channels=out_channels,
  93. act=act,
  94. param_attr=ParamAttr(name="bn_" + name + "_scale"),
  95. bias_attr=ParamAttr(name="bn_" + name + "_offset"),
  96. moving_mean_name="bn_" + name + "_mean",
  97. moving_variance_name="bn_" + name + "_variance",
  98. use_global_stats=False,
  99. )
  100. def forward(self, x):
  101. x = self.deconv(x)
  102. x = self.bn(x)
  103. return x
  104. class PGFPN(nn.Layer):
  105. def __init__(self, in_channels, **kwargs):
  106. super(PGFPN, self).__init__()
  107. num_inputs = [2048, 2048, 1024, 512, 256]
  108. num_outputs = [256, 256, 192, 192, 128]
  109. self.out_channels = 128
  110. self.conv_bn_layer_1 = ConvBNLayer(
  111. in_channels=3,
  112. out_channels=32,
  113. kernel_size=3,
  114. stride=1,
  115. act=None,
  116. name="FPN_d1",
  117. )
  118. self.conv_bn_layer_2 = ConvBNLayer(
  119. in_channels=64,
  120. out_channels=64,
  121. kernel_size=3,
  122. stride=1,
  123. act=None,
  124. name="FPN_d2",
  125. )
  126. self.conv_bn_layer_3 = ConvBNLayer(
  127. in_channels=256,
  128. out_channels=128,
  129. kernel_size=3,
  130. stride=1,
  131. act=None,
  132. name="FPN_d3",
  133. )
  134. self.conv_bn_layer_4 = ConvBNLayer(
  135. in_channels=32,
  136. out_channels=64,
  137. kernel_size=3,
  138. stride=2,
  139. act=None,
  140. name="FPN_d4",
  141. )
  142. self.conv_bn_layer_5 = ConvBNLayer(
  143. in_channels=64,
  144. out_channels=64,
  145. kernel_size=3,
  146. stride=1,
  147. act="relu",
  148. name="FPN_d5",
  149. )
  150. self.conv_bn_layer_6 = ConvBNLayer(
  151. in_channels=64,
  152. out_channels=128,
  153. kernel_size=3,
  154. stride=2,
  155. act=None,
  156. name="FPN_d6",
  157. )
  158. self.conv_bn_layer_7 = ConvBNLayer(
  159. in_channels=128,
  160. out_channels=128,
  161. kernel_size=3,
  162. stride=1,
  163. act="relu",
  164. name="FPN_d7",
  165. )
  166. self.conv_bn_layer_8 = ConvBNLayer(
  167. in_channels=128,
  168. out_channels=128,
  169. kernel_size=1,
  170. stride=1,
  171. act=None,
  172. name="FPN_d8",
  173. )
  174. self.conv_h0 = ConvBNLayer(
  175. in_channels=num_inputs[0],
  176. out_channels=num_outputs[0],
  177. kernel_size=1,
  178. stride=1,
  179. act=None,
  180. name="conv_h{}".format(0),
  181. )
  182. self.conv_h1 = ConvBNLayer(
  183. in_channels=num_inputs[1],
  184. out_channels=num_outputs[1],
  185. kernel_size=1,
  186. stride=1,
  187. act=None,
  188. name="conv_h{}".format(1),
  189. )
  190. self.conv_h2 = ConvBNLayer(
  191. in_channels=num_inputs[2],
  192. out_channels=num_outputs[2],
  193. kernel_size=1,
  194. stride=1,
  195. act=None,
  196. name="conv_h{}".format(2),
  197. )
  198. self.conv_h3 = ConvBNLayer(
  199. in_channels=num_inputs[3],
  200. out_channels=num_outputs[3],
  201. kernel_size=1,
  202. stride=1,
  203. act=None,
  204. name="conv_h{}".format(3),
  205. )
  206. self.conv_h4 = ConvBNLayer(
  207. in_channels=num_inputs[4],
  208. out_channels=num_outputs[4],
  209. kernel_size=1,
  210. stride=1,
  211. act=None,
  212. name="conv_h{}".format(4),
  213. )
  214. self.dconv0 = DeConvBNLayer(
  215. in_channels=num_outputs[0],
  216. out_channels=num_outputs[0 + 1],
  217. name="dconv_{}".format(0),
  218. )
  219. self.dconv1 = DeConvBNLayer(
  220. in_channels=num_outputs[1],
  221. out_channels=num_outputs[1 + 1],
  222. act=None,
  223. name="dconv_{}".format(1),
  224. )
  225. self.dconv2 = DeConvBNLayer(
  226. in_channels=num_outputs[2],
  227. out_channels=num_outputs[2 + 1],
  228. act=None,
  229. name="dconv_{}".format(2),
  230. )
  231. self.dconv3 = DeConvBNLayer(
  232. in_channels=num_outputs[3],
  233. out_channels=num_outputs[3 + 1],
  234. act=None,
  235. name="dconv_{}".format(3),
  236. )
  237. self.conv_g1 = ConvBNLayer(
  238. in_channels=num_outputs[1],
  239. out_channels=num_outputs[1],
  240. kernel_size=3,
  241. stride=1,
  242. act="relu",
  243. name="conv_g{}".format(1),
  244. )
  245. self.conv_g2 = ConvBNLayer(
  246. in_channels=num_outputs[2],
  247. out_channels=num_outputs[2],
  248. kernel_size=3,
  249. stride=1,
  250. act="relu",
  251. name="conv_g{}".format(2),
  252. )
  253. self.conv_g3 = ConvBNLayer(
  254. in_channels=num_outputs[3],
  255. out_channels=num_outputs[3],
  256. kernel_size=3,
  257. stride=1,
  258. act="relu",
  259. name="conv_g{}".format(3),
  260. )
  261. self.conv_g4 = ConvBNLayer(
  262. in_channels=num_outputs[4],
  263. out_channels=num_outputs[4],
  264. kernel_size=3,
  265. stride=1,
  266. act="relu",
  267. name="conv_g{}".format(4),
  268. )
  269. self.convf = ConvBNLayer(
  270. in_channels=num_outputs[4],
  271. out_channels=num_outputs[4],
  272. kernel_size=1,
  273. stride=1,
  274. act=None,
  275. name="conv_f{}".format(4),
  276. )
  277. def forward(self, x):
  278. c0, c1, c2, c3, c4, c5, c6 = x
  279. # FPN_Down_Fusion
  280. f = [c0, c1, c2]
  281. g = [None, None, None]
  282. h = [None, None, None]
  283. h[0] = self.conv_bn_layer_1(f[0])
  284. h[1] = self.conv_bn_layer_2(f[1])
  285. h[2] = self.conv_bn_layer_3(f[2])
  286. g[0] = self.conv_bn_layer_4(h[0])
  287. g[1] = paddle.add(g[0], h[1])
  288. g[1] = F.relu(g[1])
  289. g[1] = self.conv_bn_layer_5(g[1])
  290. g[1] = self.conv_bn_layer_6(g[1])
  291. g[2] = paddle.add(g[1], h[2])
  292. g[2] = F.relu(g[2])
  293. g[2] = self.conv_bn_layer_7(g[2])
  294. f_down = self.conv_bn_layer_8(g[2])
  295. # FPN UP Fusion
  296. f1 = [c6, c5, c4, c3, c2]
  297. g = [None, None, None, None, None]
  298. h = [None, None, None, None, None]
  299. h[0] = self.conv_h0(f1[0])
  300. h[1] = self.conv_h1(f1[1])
  301. h[2] = self.conv_h2(f1[2])
  302. h[3] = self.conv_h3(f1[3])
  303. h[4] = self.conv_h4(f1[4])
  304. g[0] = self.dconv0(h[0])
  305. g[1] = paddle.add(g[0], h[1])
  306. g[1] = F.relu(g[1])
  307. g[1] = self.conv_g1(g[1])
  308. g[1] = self.dconv1(g[1])
  309. g[2] = paddle.add(g[1], h[2])
  310. g[2] = F.relu(g[2])
  311. g[2] = self.conv_g2(g[2])
  312. g[2] = self.dconv2(g[2])
  313. g[3] = paddle.add(g[2], h[3])
  314. g[3] = F.relu(g[3])
  315. g[3] = self.conv_g3(g[3])
  316. g[3] = self.dconv3(g[3])
  317. g[4] = paddle.add(x=g[3], y=h[4])
  318. g[4] = F.relu(g[4])
  319. g[4] = self.conv_g4(g[4])
  320. f_up = self.convf(g[4])
  321. f_common = paddle.add(f_down, f_up)
  322. f_common = F.relu(f_common)
  323. return f_common