DBHead.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. # -*- coding: utf-8 -*-
  2. # @Time : 2019/12/4 14:54
  3. # @Author : zhoujun
  4. import paddle
  5. from paddle import nn, ParamAttr
  6. class DBHead(nn.Layer):
  7. def __init__(self, in_channels, out_channels, k=50):
  8. super().__init__()
  9. self.k = k
  10. self.binarize = nn.Sequential(
  11. nn.Conv2D(
  12. in_channels,
  13. in_channels // 4,
  14. 3,
  15. padding=1,
  16. weight_attr=ParamAttr(initializer=nn.initializer.KaimingNormal()),
  17. ),
  18. nn.BatchNorm2D(
  19. in_channels // 4,
  20. weight_attr=ParamAttr(initializer=nn.initializer.Constant(1)),
  21. bias_attr=ParamAttr(initializer=nn.initializer.Constant(1e-4)),
  22. ),
  23. nn.ReLU(),
  24. nn.Conv2DTranspose(
  25. in_channels // 4,
  26. in_channels // 4,
  27. 2,
  28. 2,
  29. weight_attr=ParamAttr(initializer=nn.initializer.KaimingNormal()),
  30. ),
  31. nn.BatchNorm2D(
  32. in_channels // 4,
  33. weight_attr=ParamAttr(initializer=nn.initializer.Constant(1)),
  34. bias_attr=ParamAttr(initializer=nn.initializer.Constant(1e-4)),
  35. ),
  36. nn.ReLU(),
  37. nn.Conv2DTranspose(
  38. in_channels // 4, 1, 2, 2, weight_attr=nn.initializer.KaimingNormal()
  39. ),
  40. nn.Sigmoid(),
  41. )
  42. self.thresh = self._init_thresh(in_channels)
  43. def forward(self, x):
  44. shrink_maps = self.binarize(x)
  45. threshold_maps = self.thresh(x)
  46. if self.training:
  47. binary_maps = self.step_function(shrink_maps, threshold_maps)
  48. y = paddle.concat((shrink_maps, threshold_maps, binary_maps), axis=1)
  49. else:
  50. y = paddle.concat((shrink_maps, threshold_maps), axis=1)
  51. return y
  52. def _init_thresh(self, inner_channels, serial=False, smooth=False, bias=False):
  53. in_channels = inner_channels
  54. if serial:
  55. in_channels += 1
  56. self.thresh = nn.Sequential(
  57. nn.Conv2D(
  58. in_channels,
  59. inner_channels // 4,
  60. 3,
  61. padding=1,
  62. bias_attr=bias,
  63. weight_attr=ParamAttr(initializer=nn.initializer.KaimingNormal()),
  64. ),
  65. nn.BatchNorm2D(
  66. inner_channels // 4,
  67. weight_attr=ParamAttr(initializer=nn.initializer.Constant(1)),
  68. bias_attr=ParamAttr(initializer=nn.initializer.Constant(1e-4)),
  69. ),
  70. nn.ReLU(),
  71. self._init_upsample(
  72. inner_channels // 4, inner_channels // 4, smooth=smooth, bias=bias
  73. ),
  74. nn.BatchNorm2D(
  75. inner_channels // 4,
  76. weight_attr=ParamAttr(initializer=nn.initializer.Constant(1)),
  77. bias_attr=ParamAttr(initializer=nn.initializer.Constant(1e-4)),
  78. ),
  79. nn.ReLU(),
  80. self._init_upsample(inner_channels // 4, 1, smooth=smooth, bias=bias),
  81. nn.Sigmoid(),
  82. )
  83. return self.thresh
  84. def _init_upsample(self, in_channels, out_channels, smooth=False, bias=False):
  85. if smooth:
  86. inter_out_channels = out_channels
  87. if out_channels == 1:
  88. inter_out_channels = in_channels
  89. module_list = [
  90. nn.Upsample(scale_factor=2, mode="nearest"),
  91. nn.Conv2D(
  92. in_channels,
  93. inter_out_channels,
  94. 3,
  95. 1,
  96. 1,
  97. bias_attr=bias,
  98. weight_attr=ParamAttr(initializer=nn.initializer.KaimingNormal()),
  99. ),
  100. ]
  101. if out_channels == 1:
  102. module_list.append(
  103. nn.Conv2D(
  104. in_channels,
  105. out_channels,
  106. kernel_size=1,
  107. stride=1,
  108. padding=1,
  109. bias_attr=True,
  110. weight_attr=ParamAttr(
  111. initializer=nn.initializer.KaimingNormal()
  112. ),
  113. )
  114. )
  115. return nn.Sequential(module_list)
  116. else:
  117. return nn.Conv2DTranspose(
  118. in_channels,
  119. out_channels,
  120. 2,
  121. 2,
  122. weight_attr=ParamAttr(initializer=nn.initializer.KaimingNormal()),
  123. )
  124. def step_function(self, x, y):
  125. return paddle.reciprocal(1 + paddle.exp(-self.k * (x - y)))