det_basic_loss.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. # copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/WenmuZhou/DBNet.pytorch/blob/master/models/losses/basic_loss.py
  17. """
  18. from __future__ import absolute_import
  19. from __future__ import division
  20. from __future__ import print_function
  21. import numpy as np
  22. import paddle
  23. from paddle import nn
  24. import paddle.nn.functional as F
  class BalanceLoss(nn.Layer):
      """
      Balanced loss for Differentiable Binarization text detection.

      Wraps a base loss selected by ``main_loss_type``. When balancing is
      enabled, the result combines the loss over all positive (masked) pixels
      with the loss over the highest-loss negative pixels, keeping at most
      ``negative_ratio`` negatives per positive.
      """

      # Base losses whose forward() only accepts (input, label); passing them a
      # ``mask`` keyword would raise TypeError.
      _MASK_FREE_LOSSES = ("CrossEntropy", "Euclidean")

      def __init__(
          self,
          balance_loss=True,
          main_loss_type="DiceLoss",
          negative_ratio=3,
          return_origin=False,
          eps=1e-6,
          **kwargs,
      ):
          """
          The BalanceLoss for Differentiable Binarization text detection
          args:
              balance_loss (bool): whether balance loss or not, default is True
              main_loss_type (str): can only be one of ['CrossEntropy','DiceLoss',
                  'Euclidean','BCELoss', 'MaskL1Loss'], default is 'DiceLoss'.
              negative_ratio (int|float): maximum number of negatives kept per
                  positive, default is 3.
              return_origin (bool): whether return unbalanced loss or not, default is False.
              eps (float): added to denominators, default is 1e-6.
              **kwargs: accepted for call-site compatibility and ignored.
          raises:
              ValueError: if main_loss_type is not one of the supported names.
          """
          super(BalanceLoss, self).__init__()
          self.balance_loss = balance_loss
          self.main_loss_type = main_loss_type
          self.negative_ratio = negative_ratio
          self.return_origin = return_origin
          self.eps = eps

          if self.main_loss_type == "CrossEntropy":
              self.loss = nn.CrossEntropyLoss()
          elif self.main_loss_type == "Euclidean":
              self.loss = nn.MSELoss()
          elif self.main_loss_type == "DiceLoss":
              self.loss = DiceLoss(self.eps)
          elif self.main_loss_type == "BCELoss":
              # Per-element loss is needed so it can be split into pos/neg parts.
              self.loss = BCELoss(reduction="none")
          elif self.main_loss_type == "MaskL1Loss":
              self.loss = MaskL1Loss(self.eps)
          else:
              loss_type = [
                  "CrossEntropy",
                  "DiceLoss",
                  "Euclidean",
                  "BCELoss",
                  "MaskL1Loss",
              ]
              # ValueError is a subclass of Exception, so existing
              # ``except Exception`` handlers keep working.
              raise ValueError(
                  "main_loss_type in BalanceLoss() can only be one of {}".format(
                      loss_type
                  )
              )

      def forward(self, pred, gt, mask=None):
          """
          The BalanceLoss for Differentiable Binarization text detection
          args:
              pred (variable): predicted feature maps.
              gt (variable): ground truth feature maps.
              mask (variable): masked maps. When omitted, an all-ones mask
                  shaped like gt is used, i.e. every element is counted.
          return:
              the raw base loss if balance_loss is False; otherwise the
              balanced loss, or the tuple (balanced loss, raw base loss) when
              return_origin is True.
          """
          if mask is None:
              # The signature advertises mask as optional; without this the
              # multiplications below fail on None.
              mask = paddle.ones_like(gt)

          positive = gt * mask
          negative = (1 - gt) * mask

          positive_count = int(positive.sum())
          # Keep at most negative_ratio negatives for every positive.
          negative_count = int(
              min(negative.sum(), positive_count * self.negative_ratio)
          )

          if self.main_loss_type in self._MASK_FREE_LOSSES:
              # Built-in paddle losses do not take a mask argument.
              loss = self.loss(pred, gt)
          else:
              loss = self.loss(pred, gt, mask=mask)

          if not self.balance_loss:
              return loss

          positive_loss = positive * loss
          negative_loss = negative * loss
          negative_loss = paddle.reshape(negative_loss, shape=[-1])

          if negative_count > 0:
              # Hard-negative selection: keep only the largest negative losses.
              sort_loss = negative_loss.sort(descending=True)
              negative_loss = sort_loss[:negative_count]
              balance_loss = (positive_loss.sum() + negative_loss.sum()) / (
                  positive_count + negative_count + self.eps
              )
          else:
              balance_loss = positive_loss.sum() / (positive_count + self.eps)

          if self.return_origin:
              return balance_loss, loss
          return balance_loss
  class DiceLoss(nn.Layer):
      """Dice loss evaluated over the masked part of a prediction map."""

      def __init__(self, eps=1e-6):
          super().__init__()
          # Added to the union term in forward().
          self.eps = eps

      def forward(self, pred, gt, mask, weights=None):
          """
          Compute ``1 - 2 * intersection / union`` over the masked maps.

          args:
              pred (variable): predicted map.
              gt (variable): ground truth map; must have the same shape as pred.
              mask (variable): masked map; must have the same shape as pred.
              weights (variable): optional map multiplied into mask; must have
                  the same shape as mask.
          return: (variable) the dice loss.
          """
          assert pred.shape == gt.shape
          assert pred.shape == mask.shape

          effective_mask = mask
          if weights is not None:
              assert weights.shape == mask.shape
              effective_mask = weights * mask

          overlap = paddle.sum(pred * gt * effective_mask)
          pred_area = paddle.sum(pred * effective_mask)
          gt_area = paddle.sum(gt * effective_mask)
          denominator = pred_area + gt_area + self.eps

          dice = 1 - 2.0 * overlap / denominator
          assert dice <= 1
          return dice
  class MaskL1Loss(nn.Layer):
      """L1 loss restricted to, and normalised by, a mask."""

      def __init__(self, eps=1e-6):
          super().__init__()
          # Added to the mask sum used as the divisor in forward().
          self.eps = eps

      def forward(self, pred, gt, mask):
          """
          Mask L1 Loss

          args:
              pred (variable): predicted map.
              gt (variable): ground truth map.
              mask (variable): map multiplied into the absolute error.
          return: (variable) sum of masked absolute errors divided by
              (mask.sum() + eps).
          """
          masked_abs_error = paddle.abs(pred - gt) * mask
          normalizer = mask.sum() + self.eps
          averaged = masked_abs_error.sum() / normalizer
          return paddle.mean(averaged)
  class BCELoss(nn.Layer):
      """Thin layer around ``F.binary_cross_entropy`` with a fixed reduction."""

      def __init__(self, reduction="mean"):
          super().__init__()
          # Forwarded unchanged as the ``reduction`` argument in forward().
          self.reduction = reduction

      def forward(self, input, label, mask=None, weight=None, name=None):
          """
          Binary cross entropy between input and label.

          args:
              input (variable): predictions passed to F.binary_cross_entropy.
              label (variable): targets passed to F.binary_cross_entropy.
              mask, weight, name: accepted so the call signature matches the
                  other losses in this file; they are not used here.
          return: (variable) the loss, reduced according to self.reduction.
          """
          return F.binary_cross_entropy(input, label, reduction=self.reduction)