loss.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. """Modified from https://github.com/LikeLy-Journey/SegmenTron/blob/master/
  3. segmentron/solver/loss.py (Apache-2.0 License)"""
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from torch.nn.modules.loss import BCEWithLogitsLoss
  8. class BinaryDiceLoss(nn.Module):
  9. """Dice loss of binary class
  10. Args:
  11. smooth: A float number to smooth loss, and avoid NaN error, default: 1
  12. p: Denominator value: \sum{x^p} + \sum{y^p}, default: 2
  13. predict: A tensor of shape [N, *]
  14. target: A tensor of shape same with predict
  15. reduction: Reduction method to apply, return mean over batch if 'mean',
  16. return sum if 'sum', return a tensor of shape [N,] if 'none'
  17. Returns:
  18. Loss tensor according to arg reduction
  19. Raise:
  20. Exception if unexpected reduction
  21. """
  22. def __init__(self, smooth=1, p=2, reduction='mean'):
  23. super(BinaryDiceLoss, self).__init__()
  24. self.smooth = smooth
  25. self.p = p
  26. self.reduction = reduction
  27. def forward(self, predict, target):
  28. assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
  29. predict = predict.contiguous().view(predict.shape[0], -1)
  30. target = target.contiguous().view(target.shape[0], -1)
  31. num = torch.sum(torch.mul(predict, target), dim=1) + self.smooth
  32. den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=1) + self.smooth
  33. loss = 1 - num / den
  34. if self.reduction == 'mean':
  35. return loss.mean()
  36. elif self.reduction == 'sum':
  37. return loss.sum()
  38. elif self.reduction == 'none':
  39. return loss
  40. else:
  41. raise Exception('Unexpected reduction {}'.format(self.reduction))
  42. class BalanceCrossEntropyLoss(nn.Module):
  43. '''
  44. Balanced cross entropy loss.
  45. Shape:
  46. - Input: :math:`(N, 1, H, W)`
  47. - GT: :math:`(N, 1, H, W)`, same shape as the input
  48. - Mask: :math:`(N, H, W)`, same spatial shape as the input
  49. - Output: scalar.
  50. Examples::
  51. >>> m = nn.Sigmoid()
  52. >>> loss = nn.BCELoss()
  53. >>> input = torch.randn(3, requires_grad=True)
  54. >>> target = torch.empty(3).random_(2)
  55. >>> output = loss(m(input), target)
  56. >>> output.backward()
  57. '''
  58. def __init__(self, negative_ratio=3.0, eps=1e-6):
  59. super(BalanceCrossEntropyLoss, self).__init__()
  60. self.negative_ratio = negative_ratio
  61. self.eps = eps
  62. def forward(self,
  63. pred: torch.Tensor,
  64. gt: torch.Tensor,
  65. mask: torch.Tensor,
  66. return_origin=False):
  67. '''
  68. Args:
  69. pred: shape :math:`(N, 1, H, W)`, the prediction of network
  70. gt: shape :math:`(N, 1, H, W)`, the target
  71. mask: shape :math:`(N, H, W)`, the mask indicates positive regions
  72. '''
  73. positive = (gt * mask).byte()
  74. negative = ((1 - gt) * mask).byte()
  75. positive_count = int(positive.float().sum())
  76. negative_count = min(int(negative.float().sum()), int(positive_count * self.negative_ratio))
  77. # loss = nn.functional.binary_cross_entropy(pred, gt, reduction='none')
  78. loss = nn.functional.binary_cross_entropy_with_logits(pred, gt, reduction='none')
  79. positive_loss = loss * positive.float()
  80. negative_loss = loss * negative.float()
  81. # negative_loss, _ = torch.topk(negative_loss.view(-1).contiguous(), negative_count)
  82. negative_loss, _ = negative_loss.view(-1).topk(negative_count)
  83. balance_loss = (positive_loss.sum() + negative_loss.sum()) / (positive_count + negative_count + self.eps)
  84. if return_origin:
  85. return balance_loss, loss
  86. return balance_loss
  87. class DiceLoss(nn.Module):
  88. '''
  89. Loss function from https://arxiv.org/abs/1707.03237,
  90. where iou computation is introduced heatmap manner to measure the
  91. diversity between tow heatmaps.
  92. '''
  93. def __init__(self, eps=1e-6):
  94. super(DiceLoss, self).__init__()
  95. self.eps = eps
  96. def forward(self, pred: torch.Tensor, gt, mask, weights=None):
  97. '''
  98. pred: one or two heatmaps of shape (N, 1, H, W),
  99. the losses of tow heatmaps are added together.
  100. gt: (N, 1, H, W)
  101. mask: (N, H, W)
  102. '''
  103. return self._compute(pred, gt, mask, weights)
  104. def _compute(self, pred, gt, mask, weights):
  105. if pred.dim() == 4:
  106. pred = pred[:, 0, :, :]
  107. gt = gt[:, 0, :, :]
  108. assert pred.shape == gt.shape
  109. assert pred.shape == mask.shape
  110. if weights is not None:
  111. assert weights.shape == mask.shape
  112. mask = weights * mask
  113. intersection = (pred * gt * mask).sum()
  114. union = (pred * mask).sum() + (gt * mask).sum() + self.eps
  115. loss = 1 - 2.0 * intersection / union
  116. assert loss <= 1
  117. return loss
  118. class MaskL1Loss(nn.Module):
  119. def __init__(self, eps=1e-6):
  120. super(MaskL1Loss, self).__init__()
  121. self.eps = eps
  122. def forward(self, pred: torch.Tensor, gt, mask):
  123. loss = (torch.abs(pred - gt) * mask).sum() / (mask.sum() + self.eps)
  124. return loss
  125. class DBLoss(nn.Module):
  126. def __init__(self, alpha=3.0, beta=1.0, ohem_ratio=3, reduction='mean', eps=1e-6):
  127. """
  128. Implement PSE Loss.
  129. :param alpha: binary_map loss 前面的系数
  130. :param beta: threshold_map loss 前面的系数
  131. :param ohem_ratio: OHEM的比例
  132. :param reduction: 'mean' or 'sum'对 batch里的loss 算均值或求和
  133. """
  134. super().__init__()
  135. assert reduction in ['mean', 'sum'], " reduction must in ['mean','sum']"
  136. self.alpha = alpha
  137. self.beta = beta
  138. self.bce_loss = BalanceCrossEntropyLoss(negative_ratio=ohem_ratio)
  139. self.dice_loss = DiceLoss(eps=eps)
  140. self.l1_loss = MaskL1Loss(eps=eps)
  141. self.ohem_ratio = ohem_ratio
  142. self.reduction = reduction
  143. def forward(self, pred, batch, use_bce=True):
  144. shrink_maps = pred[:, 0, :, :]
  145. threshold_maps = pred[:, 1, :, :]
  146. binary_maps = pred[:, 2, :, :]
  147. if use_bce:
  148. loss_shrink_maps = self.bce_loss(pred[:, 3, :, :], batch['shrink_map'], batch['shrink_mask']) + self.dice_loss(shrink_maps, batch['shrink_map'], batch['shrink_mask'])
  149. else:
  150. loss_shrink_maps = self.dice_loss(shrink_maps, batch['shrink_map'], batch['shrink_mask'])
  151. loss_threshold_maps = self.l1_loss(threshold_maps, batch['threshold_map'], batch['threshold_mask'])
  152. metrics = dict(loss_shrink_maps=loss_shrink_maps, loss_threshold_maps=loss_threshold_maps)
  153. if pred.size()[1] > 2:
  154. loss_binary_maps = self.dice_loss(binary_maps, batch['shrink_map'], batch['shrink_mask']) + self.bce_loss(binary_maps, batch['shrink_map'], batch['shrink_mask'])
  155. metrics['loss_binary_maps'] = loss_binary_maps
  156. loss_all = self.alpha * loss_shrink_maps + self.beta * loss_threshold_maps + loss_binary_maps
  157. metrics['loss'] = loss_all
  158. else:
  159. metrics['loss'] = loss_shrink_maps
  160. return metrics