table_fpn.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import paddle
  18. from paddle import nn
  19. import paddle.nn.functional as F
  20. from paddle import ParamAttr
  21. class TableFPN(nn.Layer):
  22. def __init__(self, in_channels, out_channels, **kwargs):
  23. super(TableFPN, self).__init__()
  24. self.out_channels = 512
  25. weight_attr = paddle.nn.initializer.KaimingUniform()
  26. self.in2_conv = nn.Conv2D(
  27. in_channels=in_channels[0],
  28. out_channels=self.out_channels,
  29. kernel_size=1,
  30. weight_attr=ParamAttr(initializer=weight_attr),
  31. bias_attr=False,
  32. )
  33. self.in3_conv = nn.Conv2D(
  34. in_channels=in_channels[1],
  35. out_channels=self.out_channels,
  36. kernel_size=1,
  37. stride=1,
  38. weight_attr=ParamAttr(initializer=weight_attr),
  39. bias_attr=False,
  40. )
  41. self.in4_conv = nn.Conv2D(
  42. in_channels=in_channels[2],
  43. out_channels=self.out_channels,
  44. kernel_size=1,
  45. weight_attr=ParamAttr(initializer=weight_attr),
  46. bias_attr=False,
  47. )
  48. self.in5_conv = nn.Conv2D(
  49. in_channels=in_channels[3],
  50. out_channels=self.out_channels,
  51. kernel_size=1,
  52. weight_attr=ParamAttr(initializer=weight_attr),
  53. bias_attr=False,
  54. )
  55. self.p5_conv = nn.Conv2D(
  56. in_channels=self.out_channels,
  57. out_channels=self.out_channels // 4,
  58. kernel_size=3,
  59. padding=1,
  60. weight_attr=ParamAttr(initializer=weight_attr),
  61. bias_attr=False,
  62. )
  63. self.p4_conv = nn.Conv2D(
  64. in_channels=self.out_channels,
  65. out_channels=self.out_channels // 4,
  66. kernel_size=3,
  67. padding=1,
  68. weight_attr=ParamAttr(initializer=weight_attr),
  69. bias_attr=False,
  70. )
  71. self.p3_conv = nn.Conv2D(
  72. in_channels=self.out_channels,
  73. out_channels=self.out_channels // 4,
  74. kernel_size=3,
  75. padding=1,
  76. weight_attr=ParamAttr(initializer=weight_attr),
  77. bias_attr=False,
  78. )
  79. self.p2_conv = nn.Conv2D(
  80. in_channels=self.out_channels,
  81. out_channels=self.out_channels // 4,
  82. kernel_size=3,
  83. padding=1,
  84. weight_attr=ParamAttr(initializer=weight_attr),
  85. bias_attr=False,
  86. )
  87. self.fuse_conv = nn.Conv2D(
  88. in_channels=self.out_channels * 4,
  89. out_channels=512,
  90. kernel_size=3,
  91. padding=1,
  92. weight_attr=ParamAttr(initializer=weight_attr),
  93. bias_attr=False,
  94. )
  95. def forward(self, x):
  96. c2, c3, c4, c5 = x
  97. in5 = self.in5_conv(c5)
  98. in4 = self.in4_conv(c4)
  99. in3 = self.in3_conv(c3)
  100. in2 = self.in2_conv(c2)
  101. out4 = in4 + F.upsample(
  102. in5, size=in4.shape[2:4], mode="nearest", align_mode=1
  103. ) # 1/16
  104. out3 = in3 + F.upsample(
  105. out4, size=in3.shape[2:4], mode="nearest", align_mode=1
  106. ) # 1/8
  107. out2 = in2 + F.upsample(
  108. out3, size=in2.shape[2:4], mode="nearest", align_mode=1
  109. ) # 1/4
  110. p4 = F.upsample(out4, size=in5.shape[2:4], mode="nearest", align_mode=1)
  111. p3 = F.upsample(out3, size=in5.shape[2:4], mode="nearest", align_mode=1)
  112. p2 = F.upsample(out2, size=in5.shape[2:4], mode="nearest", align_mode=1)
  113. fuse = paddle.concat([in5, p4, p3, p2], axis=1)
  114. fuse_conv = self.fuse_conv(fuse) * 0.005
  115. return [c5 + fuse_conv]