ct_fpn.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import paddle
  18. from paddle import nn
  19. import paddle.nn.functional as F
  20. from paddle import ParamAttr
  21. import os
  22. import sys
  23. import math
  24. from paddle.nn.initializer import TruncatedNormal, Constant, Normal
  25. ones_ = Constant(value=1.0)
  26. zeros_ = Constant(value=0.0)
  27. __dir__ = os.path.dirname(os.path.abspath(__file__))
  28. sys.path.append(__dir__)
  29. sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../..")))
  30. class Conv_BN_ReLU(nn.Layer):
  31. def __init__(self, in_planes, out_planes, kernel_size=1, stride=1, padding=0):
  32. super(Conv_BN_ReLU, self).__init__()
  33. self.conv = nn.Conv2D(
  34. in_planes,
  35. out_planes,
  36. kernel_size=kernel_size,
  37. stride=stride,
  38. padding=padding,
  39. bias_attr=False,
  40. )
  41. self.bn = nn.BatchNorm2D(out_planes)
  42. self.relu = nn.ReLU()
  43. for m in self.sublayers():
  44. if isinstance(m, nn.Conv2D):
  45. n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
  46. normal_ = Normal(mean=0.0, std=math.sqrt(2.0 / n))
  47. normal_(m.weight)
  48. elif isinstance(m, nn.BatchNorm2D):
  49. zeros_(m.bias)
  50. ones_(m.weight)
  51. def forward(self, x):
  52. return self.relu(self.bn(self.conv(x)))
  53. class FPEM(nn.Layer):
  54. def __init__(self, in_channels, out_channels):
  55. super(FPEM, self).__init__()
  56. planes = out_channels
  57. self.dwconv3_1 = nn.Conv2D(
  58. planes,
  59. planes,
  60. kernel_size=3,
  61. stride=1,
  62. padding=1,
  63. groups=planes,
  64. bias_attr=False,
  65. )
  66. self.smooth_layer3_1 = Conv_BN_ReLU(planes, planes)
  67. self.dwconv2_1 = nn.Conv2D(
  68. planes,
  69. planes,
  70. kernel_size=3,
  71. stride=1,
  72. padding=1,
  73. groups=planes,
  74. bias_attr=False,
  75. )
  76. self.smooth_layer2_1 = Conv_BN_ReLU(planes, planes)
  77. self.dwconv1_1 = nn.Conv2D(
  78. planes,
  79. planes,
  80. kernel_size=3,
  81. stride=1,
  82. padding=1,
  83. groups=planes,
  84. bias_attr=False,
  85. )
  86. self.smooth_layer1_1 = Conv_BN_ReLU(planes, planes)
  87. self.dwconv2_2 = nn.Conv2D(
  88. planes,
  89. planes,
  90. kernel_size=3,
  91. stride=2,
  92. padding=1,
  93. groups=planes,
  94. bias_attr=False,
  95. )
  96. self.smooth_layer2_2 = Conv_BN_ReLU(planes, planes)
  97. self.dwconv3_2 = nn.Conv2D(
  98. planes,
  99. planes,
  100. kernel_size=3,
  101. stride=2,
  102. padding=1,
  103. groups=planes,
  104. bias_attr=False,
  105. )
  106. self.smooth_layer3_2 = Conv_BN_ReLU(planes, planes)
  107. self.dwconv4_2 = nn.Conv2D(
  108. planes,
  109. planes,
  110. kernel_size=3,
  111. stride=2,
  112. padding=1,
  113. groups=planes,
  114. bias_attr=False,
  115. )
  116. self.smooth_layer4_2 = Conv_BN_ReLU(planes, planes)
  117. def _upsample_add(self, x, y):
  118. return F.upsample(x, scale_factor=2, mode="bilinear") + y
  119. def forward(self, f1, f2, f3, f4):
  120. # up-down
  121. f3 = self.smooth_layer3_1(self.dwconv3_1(self._upsample_add(f4, f3)))
  122. f2 = self.smooth_layer2_1(self.dwconv2_1(self._upsample_add(f3, f2)))
  123. f1 = self.smooth_layer1_1(self.dwconv1_1(self._upsample_add(f2, f1)))
  124. # down-up
  125. f2 = self.smooth_layer2_2(self.dwconv2_2(self._upsample_add(f2, f1)))
  126. f3 = self.smooth_layer3_2(self.dwconv3_2(self._upsample_add(f3, f2)))
  127. f4 = self.smooth_layer4_2(self.dwconv4_2(self._upsample_add(f4, f3)))
  128. return f1, f2, f3, f4
  129. class CTFPN(nn.Layer):
  130. def __init__(self, in_channels, out_channel=128):
  131. super(CTFPN, self).__init__()
  132. self.out_channels = out_channel * 4
  133. self.reduce_layer1 = Conv_BN_ReLU(in_channels[0], 128)
  134. self.reduce_layer2 = Conv_BN_ReLU(in_channels[1], 128)
  135. self.reduce_layer3 = Conv_BN_ReLU(in_channels[2], 128)
  136. self.reduce_layer4 = Conv_BN_ReLU(in_channels[3], 128)
  137. self.fpem1 = FPEM(in_channels=(64, 128, 256, 512), out_channels=128)
  138. self.fpem2 = FPEM(in_channels=(64, 128, 256, 512), out_channels=128)
  139. def _upsample(self, x, scale=1):
  140. return F.upsample(x, scale_factor=scale, mode="bilinear")
  141. def forward(self, f):
  142. # # reduce channel
  143. f1 = self.reduce_layer1(f[0]) # N,64,160,160 --> N, 128, 160, 160
  144. f2 = self.reduce_layer2(f[1]) # N, 128, 80, 80 --> N, 128, 80, 80
  145. f3 = self.reduce_layer3(f[2]) # N, 256, 40, 40 --> N, 128, 40, 40
  146. f4 = self.reduce_layer4(f[3]) # N, 512, 20, 20 --> N, 128, 20, 20
  147. # FPEM
  148. f1_1, f2_1, f3_1, f4_1 = self.fpem1(f1, f2, f3, f4)
  149. f1_2, f2_2, f3_2, f4_2 = self.fpem2(f1_1, f2_1, f3_1, f4_1)
  150. # FFM
  151. f1 = f1_1 + f1_2
  152. f2 = f2_1 + f2_2
  153. f3 = f3_1 + f3_2
  154. f4 = f4_1 + f4_2
  155. f2 = self._upsample(f2, scale=2)
  156. f3 = self._upsample(f3, scale=4)
  157. f4 = self._upsample(f4, scale=8)
  158. ff = paddle.concat((f1, f2, f3, f4), 1) # N,512, 160,160
  159. return ff