FPN.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. # -*- coding: utf-8 -*-
  2. # @Time : 2019/9/13 10:29
  3. # @Author : zhoujun
  4. import paddle
  5. import paddle.nn.functional as F
  6. from paddle import nn
  7. from models.basic import ConvBnRelu
  8. class FPN(nn.Layer):
  9. def __init__(self, in_channels, inner_channels=256, **kwargs):
  10. """
  11. :param in_channels: 基础网络输出的维度
  12. :param kwargs:
  13. """
  14. super().__init__()
  15. inplace = True
  16. self.conv_out = inner_channels
  17. inner_channels = inner_channels // 4
  18. # reduce layers
  19. self.reduce_conv_c2 = ConvBnRelu(
  20. in_channels[0], inner_channels, kernel_size=1, inplace=inplace
  21. )
  22. self.reduce_conv_c3 = ConvBnRelu(
  23. in_channels[1], inner_channels, kernel_size=1, inplace=inplace
  24. )
  25. self.reduce_conv_c4 = ConvBnRelu(
  26. in_channels[2], inner_channels, kernel_size=1, inplace=inplace
  27. )
  28. self.reduce_conv_c5 = ConvBnRelu(
  29. in_channels[3], inner_channels, kernel_size=1, inplace=inplace
  30. )
  31. # Smooth layers
  32. self.smooth_p4 = ConvBnRelu(
  33. inner_channels, inner_channels, kernel_size=3, padding=1, inplace=inplace
  34. )
  35. self.smooth_p3 = ConvBnRelu(
  36. inner_channels, inner_channels, kernel_size=3, padding=1, inplace=inplace
  37. )
  38. self.smooth_p2 = ConvBnRelu(
  39. inner_channels, inner_channels, kernel_size=3, padding=1, inplace=inplace
  40. )
  41. self.conv = nn.Sequential(
  42. nn.Conv2D(self.conv_out, self.conv_out, kernel_size=3, padding=1, stride=1),
  43. nn.BatchNorm2D(self.conv_out),
  44. nn.ReLU(),
  45. )
  46. self.out_channels = self.conv_out
  47. def forward(self, x):
  48. c2, c3, c4, c5 = x
  49. # Top-down
  50. p5 = self.reduce_conv_c5(c5)
  51. p4 = self._upsample_add(p5, self.reduce_conv_c4(c4))
  52. p4 = self.smooth_p4(p4)
  53. p3 = self._upsample_add(p4, self.reduce_conv_c3(c3))
  54. p3 = self.smooth_p3(p3)
  55. p2 = self._upsample_add(p3, self.reduce_conv_c2(c2))
  56. p2 = self.smooth_p2(p2)
  57. x = self._upsample_cat(p2, p3, p4, p5)
  58. x = self.conv(x)
  59. return x
  60. def _upsample_add(self, x, y):
  61. return F.interpolate(x, size=y.shape[2:]) + y
  62. def _upsample_cat(self, p2, p3, p4, p5):
  63. h, w = p2.shape[2:]
  64. p3 = F.interpolate(p3, size=(h, w))
  65. p4 = F.interpolate(p4, size=(h, w))
  66. p5 = F.interpolate(p5, size=(h, w))
  67. return paddle.concat([p2, p3, p4, p5], axis=1)