basic_loss.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. # -*- coding: utf-8 -*-
  2. # @Time : 2019/12/4 14:39
  3. # @Author : zhoujun
  4. import paddle
  5. import paddle.nn as nn
  6. class BalanceCrossEntropyLoss(nn.Layer):
  7. """
  8. Balanced cross entropy loss.
  9. Shape:
  10. - Input: :math:`(N, 1, H, W)`
  11. - GT: :math:`(N, 1, H, W)`, same shape as the input
  12. - Mask: :math:`(N, H, W)`, same spatial shape as the input
  13. - Output: scalar.
  14. """
  15. def __init__(self, negative_ratio=3.0, eps=1e-6):
  16. super(BalanceCrossEntropyLoss, self).__init__()
  17. self.negative_ratio = negative_ratio
  18. self.eps = eps
  19. def forward(
  20. self,
  21. pred: paddle.Tensor,
  22. gt: paddle.Tensor,
  23. mask: paddle.Tensor,
  24. return_origin=False,
  25. ):
  26. """
  27. Args:
  28. pred: shape :math:`(N, 1, H, W)`, the prediction of network
  29. gt: shape :math:`(N, 1, H, W)`, the target
  30. mask: shape :math:`(N, H, W)`, the mask indicates positive regions
  31. """
  32. positive = gt * mask
  33. negative = (1 - gt) * mask
  34. positive_count = int(positive.sum())
  35. negative_count = min(
  36. int(negative.sum()), int(positive_count * self.negative_ratio)
  37. )
  38. loss = nn.functional.binary_cross_entropy(pred, gt, reduction="none")
  39. positive_loss = loss * positive
  40. negative_loss = loss * negative
  41. negative_loss, _ = negative_loss.reshape([-1]).topk(negative_count)
  42. balance_loss = (positive_loss.sum() + negative_loss.sum()) / (
  43. positive_count + negative_count + self.eps
  44. )
  45. if return_origin:
  46. return balance_loss, loss
  47. return balance_loss
  48. class DiceLoss(nn.Layer):
  49. """
  50. Loss function from https://arxiv.org/abs/1707.03237,
  51. where iou computation is introduced heatmap manner to measure the
  52. diversity between tow heatmaps.
  53. """
  54. def __init__(self, eps=1e-6):
  55. super(DiceLoss, self).__init__()
  56. self.eps = eps
  57. def forward(self, pred: paddle.Tensor, gt, mask, weights=None):
  58. """
  59. pred: one or two heatmaps of shape (N, 1, H, W),
  60. the losses of tow heatmaps are added together.
  61. gt: (N, 1, H, W)
  62. mask: (N, H, W)
  63. """
  64. return self._compute(pred, gt, mask, weights)
  65. def _compute(self, pred, gt, mask, weights):
  66. if len(pred.shape) == 4:
  67. pred = pred[:, 0, :, :]
  68. gt = gt[:, 0, :, :]
  69. assert pred.shape == gt.shape
  70. assert pred.shape == mask.shape
  71. if weights is not None:
  72. assert weights.shape == mask.shape
  73. mask = weights * mask
  74. intersection = (pred * gt * mask).sum()
  75. union = (pred * mask).sum() + (gt * mask).sum() + self.eps
  76. loss = 1 - 2.0 * intersection / union
  77. assert loss <= 1
  78. return loss
  79. class MaskL1Loss(nn.Layer):
  80. def __init__(self, eps=1e-6):
  81. super(MaskL1Loss, self).__init__()
  82. self.eps = eps
  83. def forward(self, pred: paddle.Tensor, gt, mask):
  84. loss = (paddle.abs(pred - gt) * mask).sum() / (mask.sum() + self.eps)
  85. return loss