sast_fpn.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import paddle
  18. from paddle import nn
  19. import paddle.nn.functional as F
  20. from paddle import ParamAttr
  class ConvBNLayer(nn.Layer):
      """Conv2D (without bias) followed by BatchNorm with an optional activation.

      Args:
          in_channels (int): Number of input channels.
          out_channels (int): Number of output channels.
          kernel_size (int): Convolution kernel size. Padding is
              ``(kernel_size - 1) // 2``, which keeps the spatial size for odd
              kernel sizes at stride 1.
          stride (int): Convolution stride.
          groups (int): Number of convolution groups. Defaults to 1.
          if_act (bool): Stored as ``self.if_act``; not otherwise used by this
              layer. Defaults to True.
          act (str|None): Activation passed to ``nn.BatchNorm`` (e.g.
              ``"relu"``); ``None`` means no activation.
          name (str|None): Prefix used to build explicit parameter names
              (``<name>_weights``, ``bn_<name>_scale``, ...). When ``None``,
              Paddle's auto-generated parameter names are used.
      """

      def __init__(
          self,
          in_channels,
          out_channels,
          kernel_size,
          stride,
          groups=1,
          if_act=True,
          act=None,
          name=None,
      ):
          super(ConvBNLayer, self).__init__()
          self.if_act = if_act
          self.act = act
          # ``name`` defaults to None; concatenating None with a str would raise
          # TypeError, so fall back to Paddle's default parameter attributes.
          if name is None:
              weight_attr = None
              scale_attr = None
              offset_attr = None
              mean_name = None
              variance_name = None
          else:
              weight_attr = ParamAttr(name=name + "_weights")
              scale_attr = ParamAttr(name="bn_" + name + "_scale")
              offset_attr = ParamAttr(name="bn_" + name + "_offset")
              mean_name = "bn_" + name + "_mean"
              variance_name = "bn_" + name + "_variance"
          self.conv = nn.Conv2D(
              in_channels=in_channels,
              out_channels=out_channels,
              kernel_size=kernel_size,
              stride=stride,
              padding=(kernel_size - 1) // 2,
              groups=groups,
              weight_attr=weight_attr,
              # No conv bias: the following BatchNorm has its own offset.
              bias_attr=False,
          )
          self.bn = nn.BatchNorm(
              num_channels=out_channels,
              act=act,
              param_attr=scale_attr,
              bias_attr=offset_attr,
              moving_mean_name=mean_name,
              moving_variance_name=variance_name,
          )

      def forward(self, x):
          """Apply the convolution, then batch norm (and its activation, if any)."""
          x = self.conv(x)
          x = self.bn(x)
          return x
  class DeConvBNLayer(nn.Layer):
      """Conv2DTranspose (without bias) followed by BatchNorm with optional activation.

      Args:
          in_channels (int): Number of input channels.
          out_channels (int): Number of output channels.
          kernel_size (int): Transposed-convolution kernel size. Padding is
              ``(kernel_size - 1) // 2``; e.g. ``kernel_size=4, stride=2``
              gives padding 1 and an output exactly twice the input size.
          stride (int): Transposed-convolution stride.
          groups (int): Number of convolution groups. Defaults to 1.
          if_act (bool): Stored as ``self.if_act``; not otherwise used by this
              layer. Defaults to True.
          act (str|None): Activation passed to ``nn.BatchNorm`` (e.g.
              ``"relu"``); ``None`` means no activation.
          name (str|None): Prefix used to build explicit parameter names
              (``<name>_weights``, ``bn_<name>_scale``, ...). When ``None``,
              Paddle's auto-generated parameter names are used.
      """

      def __init__(
          self,
          in_channels,
          out_channels,
          kernel_size,
          stride,
          groups=1,
          if_act=True,
          act=None,
          name=None,
      ):
          super(DeConvBNLayer, self).__init__()
          self.if_act = if_act
          self.act = act
          # ``name`` defaults to None; concatenating None with a str would raise
          # TypeError, so fall back to Paddle's default parameter attributes.
          if name is None:
              weight_attr = None
              scale_attr = None
              offset_attr = None
              mean_name = None
              variance_name = None
          else:
              weight_attr = ParamAttr(name=name + "_weights")
              scale_attr = ParamAttr(name="bn_" + name + "_scale")
              offset_attr = ParamAttr(name="bn_" + name + "_offset")
              mean_name = "bn_" + name + "_mean"
              variance_name = "bn_" + name + "_variance"
          self.deconv = nn.Conv2DTranspose(
              in_channels=in_channels,
              out_channels=out_channels,
              kernel_size=kernel_size,
              stride=stride,
              padding=(kernel_size - 1) // 2,
              groups=groups,
              weight_attr=weight_attr,
              # No deconv bias: the following BatchNorm has its own offset.
              bias_attr=False,
          )
          self.bn = nn.BatchNorm(
              num_channels=out_channels,
              act=act,
              param_attr=scale_attr,
              bias_attr=offset_attr,
              moving_mean_name=mean_name,
              moving_variance_name=variance_name,
          )

      def forward(self, x):
          """Apply the transposed convolution, then batch norm (and activation)."""
          x = self.deconv(x)
          x = self.bn(x)
          return x
  class FPN_Up_Fusion(nn.Layer):
      """Top-down fusion branch of the SAST FPN.

      The last five backbone maps are projected with 1x1 ConvBN layers. Starting
      from the last map, the running result is upsampled, added to the next
      projection, passed through ReLU and refined. The final stage (no
      upsampling) produces 128 channels.

      Args:
          in_channels (list[int]): Channel counts of the backbone outputs; only
              the last five entries are used (in reverse order).
      """

      def __init__(self, in_channels):
          super(FPN_Up_Fusion, self).__init__()
          # Reverse so index 0 refers to the last backbone output.
          lateral_in = in_channels[::-1]
          widths = [256, 256, 192, 192, 128]

          # 1x1 lateral projections, one per fused level.
          self.h0_conv = ConvBNLayer(
              lateral_in[0], widths[0], 1, 1, act=None, name="fpn_up_h0"
          )
          self.h1_conv = ConvBNLayer(
              lateral_in[1], widths[1], 1, 1, act=None, name="fpn_up_h1"
          )
          self.h2_conv = ConvBNLayer(
              lateral_in[2], widths[2], 1, 1, act=None, name="fpn_up_h2"
          )
          self.h3_conv = ConvBNLayer(
              lateral_in[3], widths[3], 1, 1, act=None, name="fpn_up_h3"
          )
          self.h4_conv = ConvBNLayer(
              lateral_in[4], widths[4], 1, 1, act=None, name="fpn_up_h4"
          )

          # First stage: only a 4x4 / stride-2 transposed ConvBN.
          self.g0_conv = DeConvBNLayer(
              widths[0], widths[1], 4, 2, act=None, name="fpn_up_g0"
          )
          # Middle stages: 3x3 ConvBN+ReLU, then 4x4 / stride-2 transposed ConvBN.
          self.g1_conv = self._conv_then_upsample(
              widths[1], widths[2], "fpn_up_g1_1", "fpn_up_g1_2"
          )
          self.g2_conv = self._conv_then_upsample(
              widths[2], widths[3], "fpn_up_g2_1", "fpn_up_g2_2"
          )
          self.g3_conv = self._conv_then_upsample(
              widths[3], widths[4], "fpn_up_g3_1", "fpn_up_g3_2"
          )
          # Last stage keeps the resolution: 3x3 ConvBN+ReLU, then 1x1 ConvBN.
          last = widths[4]
          self.g4_conv = nn.Sequential(
              ConvBNLayer(last, last, 3, 1, act="relu", name="fpn_up_fusion_1"),
              ConvBNLayer(last, last, 1, 1, act=None, name="fpn_up_fusion_2"),
          )

      @staticmethod
      def _conv_then_upsample(channels, out_channels, conv_name, deconv_name):
          """Build a 3x3 ConvBN+ReLU followed by a 4x4 / stride-2 DeConvBN."""
          return nn.Sequential(
              ConvBNLayer(channels, channels, 3, 1, act="relu", name=conv_name),
              DeConvBNLayer(
                  channels, out_channels, 4, 2, act=None, name=deconv_name
              ),
          )

      def _add_relu(self, x1, x2):
          """Return ``relu(x1 + x2)``."""
          return F.relu(paddle.add(x=x1, y=x2))

      def forward(self, x):
          """Fuse the maps in ``x[2:]`` top-down and return the ``g4_conv`` output."""
          # Last map first, matching the reversed ``in_channels`` in __init__.
          feats = x[2:][::-1]
          lat0 = self.h0_conv(feats[0])
          lat1 = self.h1_conv(feats[1])
          lat2 = self.h2_conv(feats[2])
          lat3 = self.h3_conv(feats[3])
          lat4 = self.h4_conv(feats[4])

          out = self.g0_conv(lat0)
          out = self.g1_conv(self._add_relu(out, lat1))
          out = self.g2_conv(self._add_relu(out, lat2))
          out = self.g3_conv(self._add_relu(out, lat3))
          return self.g4_conv(self._add_relu(out, lat4))
  class FPN_Down_Fusion(nn.Layer):
      """Bottom-up fusion branch of the SAST FPN.

      Uses the first three backbone maps. Each level gets a 3x3 ConvBN
      projection. The running result is reduced with stride-2 ConvBN layers,
      added to the next projection and passed through ReLU. The output has 128
      channels.

      Args:
          in_channels (list[int]): Channel counts of the backbone outputs; only
              the first three entries are used.
      """

      def __init__(self, in_channels):
          super(FPN_Down_Fusion, self).__init__()
          widths = [32, 64, 128]

          # 3x3 lateral projections of x[0], x[1] and x[2].
          self.h0_conv = ConvBNLayer(
              in_channels[0], widths[0], 3, 1, act=None, name="fpn_down_h0"
          )
          self.h1_conv = ConvBNLayer(
              in_channels[1], widths[1], 3, 1, act=None, name="fpn_down_h1"
          )
          self.h2_conv = ConvBNLayer(
              in_channels[2], widths[2], 3, 1, act=None, name="fpn_down_h2"
          )

          # Stride-2 reduction from level 0 to level 1.
          self.g0_conv = ConvBNLayer(
              widths[0], widths[1], 3, 2, act=None, name="fpn_down_g0"
          )
          # 3x3 ConvBN+ReLU refinement, then stride-2 reduction to level 2.
          self.g1_conv = nn.Sequential(
              ConvBNLayer(
                  widths[1], widths[1], 3, 1, act="relu", name="fpn_down_g1_1"
              ),
              ConvBNLayer(
                  widths[1], widths[2], 3, 2, act=None, name="fpn_down_g1_2"
              ),
          )
          # Final 3x3 ConvBN+ReLU and 1x1 ConvBN at level 2.
          self.g2_conv = nn.Sequential(
              ConvBNLayer(
                  widths[2], widths[2], 3, 1, act="relu", name="fpn_down_fusion_1"
              ),
              ConvBNLayer(
                  widths[2], widths[2], 1, 1, act=None, name="fpn_down_fusion_2"
              ),
          )

      def forward(self, x):
          """Fuse ``x[0]``, ``x[1]`` and ``x[2]`` bottom-up; return the ``g2_conv`` output."""
          feats = x[:3]
          lat0 = self.h0_conv(feats[0])
          lat1 = self.h1_conv(feats[1])
          lat2 = self.h2_conv(feats[2])

          merged = F.relu(paddle.add(x=self.g0_conv(lat0), y=lat1))
          merged = self.g1_conv(merged)
          merged = F.relu(paddle.add(x=merged, y=lat2))
          return self.g2_conv(merged)
  class Cross_Attention(nn.Layer):
      """Cross attention block (CAB) applied to a fused feature map.

      Scaled dot-product self-attention is computed separately along the width
      axis ("horizontal") and along the height axis ("vertical"). Each result
      passes through a 1x1 ConvBN, is added to a 1x1 ConvBN shortcut of the
      input and goes through ReLU. The two branches are concatenated and merged
      back to ``in_channels`` channels with a 1x1 ConvBN + ReLU.

      The input is assumed to be NCHW (the default layout of the Conv2D layers
      used here) — TODO confirm for callers using another layout.

      Args:
          in_channels (int): Channel count of the input and of the output.
      """

      def __init__(self, in_channels):
          super(Cross_Attention, self).__init__()
          # Query / key / value projections (1x1 ConvBN + ReLU).
          self.theta_conv = ConvBNLayer(
              in_channels, in_channels, 1, 1, act="relu", name="f_theta"
          )
          self.phi_conv = ConvBNLayer(
              in_channels, in_channels, 1, 1, act="relu", name="f_phi"
          )
          self.g_conv = ConvBNLayer(
              in_channels, in_channels, 1, 1, act="relu", name="f_g"
          )
          # Horizontal branch: output projection and shortcut.
          self.fh_weight_conv = ConvBNLayer(
              in_channels, in_channels, 1, 1, act=None, name="fh_weight"
          )
          self.fh_sc_conv = ConvBNLayer(
              in_channels, in_channels, 1, 1, act=None, name="fh_sc"
          )
          # Vertical branch: output projection and shortcut.
          self.fv_weight_conv = ConvBNLayer(
              in_channels, in_channels, 1, 1, act=None, name="fv_weight"
          )
          self.fv_sc_conv = ConvBNLayer(
              in_channels, in_channels, 1, 1, act=None, name="fv_sc"
          )
          # Merge of the concatenated horizontal + vertical results.
          self.f_attn_conv = ConvBNLayer(
              in_channels * 2, in_channels, 1, 1, act="relu", name="f_attn"
          )

      def _cal_fweight(self, f, shape):
          """Attention-weighted values along the last spatial axis.

          Args:
              f (list): ``[f_theta, f_phi, f_g]``, each a 4-D tensor laid out as
                  ``(N, C, A, B)``.
              shape (list[int]): ``[N, A, B]``.

          Returns:
              Tensor of shape ``(N, A, B, C)``. For each of the ``N * A`` rows,
              ``softmax(theta @ phi^T / sqrt(C))`` (softmax over the last axis)
              is applied to ``g``, mixing positions along ``B``.
          """
          f_theta, f_phi, f_g = f
          # Take the channel width from the input instead of hard-coding 128,
          # so the block works for any ``in_channels``.
          channels = f_theta.shape[1]
          rows = shape[0] * shape[1]
          # flatten (N, C, A, B) -> (N, A, B, C) -> (N*A, B, C)
          f_theta = paddle.transpose(f_theta, [0, 2, 3, 1])
          f_theta = paddle.reshape(f_theta, [rows, shape[2], channels])
          f_phi = paddle.transpose(f_phi, [0, 2, 3, 1])
          f_phi = paddle.reshape(f_phi, [rows, shape[2], channels])
          f_g = paddle.transpose(f_g, [0, 2, 3, 1])
          f_g = paddle.reshape(f_g, [rows, shape[2], channels])
          # correlation: (N*A, B, B)
          f_attn = paddle.matmul(f_theta, paddle.transpose(f_phi, [0, 2, 1]))
          # scale by sqrt of the channel width
          f_attn = f_attn / (channels**0.5)
          f_attn = F.softmax(f_attn)
          # weighted sum: (N*A, B, C)
          f_weight = paddle.matmul(f_attn, f_g)
          f_weight = paddle.reshape(
              f_weight, [shape[0], shape[1], shape[2], channels]
          )
          return f_weight

      def forward(self, f_common):
          """Return the CAB-enhanced features, same shape as ``f_common``."""
          f_shape = f_common.shape
          f_theta = self.theta_conv(f_common)
          f_phi = self.phi_conv(f_common)
          f_g = self.g_conv(f_common)
          ######## horizon ########
          # Attention along W for every (n, h) row.
          fh_weight = self._cal_fweight(
              [f_theta, f_phi, f_g], [f_shape[0], f_shape[2], f_shape[3]]
          )
          # (N, H, W, C) -> (N, C, H, W)
          fh_weight = paddle.transpose(fh_weight, [0, 3, 1, 2])
          fh_weight = self.fh_weight_conv(fh_weight)
          # short cut
          fh_sc = self.fh_sc_conv(f_common)
          f_h = F.relu(fh_weight + fh_sc)
          ######## vertical ########
          # Swap H and W so that attention runs along H for every (n, w) column.
          fv_theta = paddle.transpose(f_theta, [0, 1, 3, 2])
          fv_phi = paddle.transpose(f_phi, [0, 1, 3, 2])
          fv_g = paddle.transpose(f_g, [0, 1, 3, 2])
          fv_weight = self._cal_fweight(
              [fv_theta, fv_phi, fv_g], [f_shape[0], f_shape[3], f_shape[2]]
          )
          # (N, W, H, C) -> (N, C, H, W)
          fv_weight = paddle.transpose(fv_weight, [0, 3, 2, 1])
          fv_weight = self.fv_weight_conv(fv_weight)
          # short cut
          fv_sc = self.fv_sc_conv(f_common)
          f_v = F.relu(fv_weight + fv_sc)
          ######## merge ########
          f_attn = paddle.concat([f_h, f_v], axis=1)
          f_attn = self.f_attn_conv(f_attn)
          return f_attn
  class SASTFPN(nn.Layer):
      """SAST neck: sums the bottom-up and top-down FPN branches.

      The two branch outputs are added and passed through ReLU. When
      ``with_cab`` is True, the result is then refined by a cross attention
      block. The output has ``self.out_channels`` (128) channels.

      Args:
          in_channels (list[int]): Channel counts of the backbone outputs.
          with_cab (bool): Apply the cross attention block. Defaults to False.
          **kwargs: Accepted and ignored.
      """

      def __init__(self, in_channels, with_cab=False, **kwargs):
          super(SASTFPN, self).__init__()
          self.with_cab = with_cab
          self.in_channels = in_channels
          self.out_channels = 128
          # Sublayers are created in this order: down branch, up branch, CAB.
          self.FPN_Down_Fusion = FPN_Down_Fusion(in_channels)
          self.FPN_Up_Fusion = FPN_Up_Fusion(in_channels)
          # Built even when with_cab is False, so its parameters always exist.
          self.cross_attention = Cross_Attention(self.out_channels)

      def forward(self, x):
          """Return ``relu(down(x) + up(x))``, optionally enhanced by the CAB."""
          bottom_up = self.FPN_Down_Fusion(x)
          top_down = self.FPN_Up_Fusion(x)
          fused = F.relu(paddle.add(x=bottom_up, y=top_down))
          if not self.with_cab:
              return fused
          # Enhance the fused features with the cross attention block.
          return self.cross_attention(fused)