pren_fpn.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Code is adapted from:
  16. https://github.com/RuijieJ/pren/blob/main/Nets/Aggregation.py
  17. """
  18. from __future__ import absolute_import
  19. from __future__ import division
  20. from __future__ import print_function
  21. import paddle
  22. from paddle import nn
  23. import paddle.nn.functional as F
  24. class PoolAggregate(nn.Layer):
  25. def __init__(self, n_r, d_in, d_middle=None, d_out=None):
  26. super(PoolAggregate, self).__init__()
  27. if not d_middle:
  28. d_middle = d_in
  29. if not d_out:
  30. d_out = d_in
  31. self.d_in = d_in
  32. self.d_middle = d_middle
  33. self.d_out = d_out
  34. self.act = nn.Swish()
  35. self.n_r = n_r
  36. self.aggs = self._build_aggs()
  37. def _build_aggs(self):
  38. aggs = []
  39. for i in range(self.n_r):
  40. aggs.append(
  41. self.add_sublayer(
  42. "{}".format(i),
  43. nn.Sequential(
  44. (
  45. "conv1",
  46. nn.Conv2D(
  47. self.d_in, self.d_middle, 3, 2, 1, bias_attr=False
  48. ),
  49. ),
  50. ("bn1", nn.BatchNorm(self.d_middle)),
  51. ("act", self.act),
  52. (
  53. "conv2",
  54. nn.Conv2D(
  55. self.d_middle, self.d_out, 3, 2, 1, bias_attr=False
  56. ),
  57. ),
  58. ("bn2", nn.BatchNorm(self.d_out)),
  59. ),
  60. )
  61. )
  62. return aggs
  63. def forward(self, x):
  64. b = x.shape[0]
  65. outs = []
  66. for agg in self.aggs:
  67. y = agg(x)
  68. p = F.adaptive_avg_pool2d(y, 1)
  69. outs.append(p.reshape((b, 1, self.d_out)))
  70. out = paddle.concat(outs, 1)
  71. return out
  72. class WeightAggregate(nn.Layer):
  73. def __init__(self, n_r, d_in, d_middle=None, d_out=None):
  74. super(WeightAggregate, self).__init__()
  75. if not d_middle:
  76. d_middle = d_in
  77. if not d_out:
  78. d_out = d_in
  79. self.n_r = n_r
  80. self.d_out = d_out
  81. self.act = nn.Swish()
  82. self.conv_n = nn.Sequential(
  83. ("conv1", nn.Conv2D(d_in, d_in, 3, 1, 1, bias_attr=False)),
  84. ("bn1", nn.BatchNorm(d_in)),
  85. ("act1", self.act),
  86. ("conv2", nn.Conv2D(d_in, n_r, 1, bias_attr=False)),
  87. ("bn2", nn.BatchNorm(n_r)),
  88. ("act2", nn.Sigmoid()),
  89. )
  90. self.conv_d = nn.Sequential(
  91. ("conv1", nn.Conv2D(d_in, d_middle, 3, 1, 1, bias_attr=False)),
  92. ("bn1", nn.BatchNorm(d_middle)),
  93. ("act1", self.act),
  94. ("conv2", nn.Conv2D(d_middle, d_out, 1, bias_attr=False)),
  95. ("bn2", nn.BatchNorm(d_out)),
  96. )
  97. def forward(self, x):
  98. b, _, h, w = x.shape
  99. hmaps = self.conv_n(x)
  100. fmaps = self.conv_d(x)
  101. r = paddle.bmm(
  102. hmaps.reshape((b, self.n_r, h * w)),
  103. fmaps.reshape((b, self.d_out, h * w)).transpose((0, 2, 1)),
  104. )
  105. return r
  106. class GCN(nn.Layer):
  107. def __init__(self, d_in, n_in, d_out=None, n_out=None, dropout=0.1):
  108. super(GCN, self).__init__()
  109. if not d_out:
  110. d_out = d_in
  111. if not n_out:
  112. n_out = d_in
  113. self.conv_n = nn.Conv1D(n_in, n_out, 1)
  114. self.linear = nn.Linear(d_in, d_out)
  115. self.dropout = nn.Dropout(dropout)
  116. self.act = nn.Swish()
  117. def forward(self, x):
  118. x = self.conv_n(x)
  119. x = self.dropout(self.linear(x))
  120. return self.act(x)
  121. class PRENFPN(nn.Layer):
  122. def __init__(self, in_channels, n_r, d_model, max_len, dropout):
  123. super(PRENFPN, self).__init__()
  124. assert len(in_channels) == 3, "in_channels' length must be 3."
  125. c1, c2, c3 = in_channels # the depths are from big to small
  126. # build fpn
  127. assert d_model % 3 == 0, "{} can't be divided by 3.".format(d_model)
  128. self.agg_p1 = PoolAggregate(n_r, c1, d_out=d_model // 3)
  129. self.agg_p2 = PoolAggregate(n_r, c2, d_out=d_model // 3)
  130. self.agg_p3 = PoolAggregate(n_r, c3, d_out=d_model // 3)
  131. self.agg_w1 = WeightAggregate(n_r, c1, 4 * c1, d_model // 3)
  132. self.agg_w2 = WeightAggregate(n_r, c2, 4 * c2, d_model // 3)
  133. self.agg_w3 = WeightAggregate(n_r, c3, 4 * c3, d_model // 3)
  134. self.gcn_pool = GCN(d_model, n_r, d_model, max_len, dropout)
  135. self.gcn_weight = GCN(d_model, n_r, d_model, max_len, dropout)
  136. self.out_channels = d_model
  137. def forward(self, inputs):
  138. f3, f5, f7 = inputs
  139. rp1 = self.agg_p1(f3)
  140. rp2 = self.agg_p2(f5)
  141. rp3 = self.agg_p3(f7)
  142. rp = paddle.concat([rp1, rp2, rp3], 2) # [b,nr,d]
  143. rw1 = self.agg_w1(f3)
  144. rw2 = self.agg_w2(f5)
  145. rw3 = self.agg_w3(f7)
  146. rw = paddle.concat([rw1, rw2, rw3], 2) # [b,nr,d]
  147. y1 = self.gcn_pool(rp)
  148. y2 = self.gcn_weight(rw)
  149. y = 0.5 * (y1 + y2)
  150. return y # [b,max_len,d]