| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188 |
- # Copyright (c) OpenMMLab. All rights reserved.
- """Modified from https://github.com/LikeLy-Journey/SegmenTron/blob/master/
- segmentron/solver/loss.py (Apache-2.0 License)"""
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torch.nn.modules.loss import BCEWithLogitsLoss
class BinaryDiceLoss(nn.Module):
    r"""Dice loss for a single (binary) class.

    Both inputs are flattened to ``[N, -1]`` and the per-sample loss is

        1 - (\sum{x * y} + smooth) / (\sum{x^p} + \sum{y^p} + smooth)

    Args:
        smooth: Value added to numerator and denominator to smooth the loss
            and avoid a NaN from a zero denominator. Default: 1.
        p: Exponent applied to ``predict`` and ``target`` in the
            denominator. Default: 2.
        reduction: ``'mean'`` returns the mean over the batch, ``'sum'``
            returns the sum, ``'none'`` returns a tensor of shape ``[N]``.

    Forward inputs:
        predict: A tensor of shape ``[N, *]``.
        target: A tensor with the same batch size as ``predict``; it must
            have the same number of elements per sample for the
            element-wise product to be valid.

    Returns:
        Loss tensor reduced according to ``reduction``.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
            (``ValueError`` is an ``Exception`` subclass, so callers that
            catch ``Exception`` keep working.)
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, smooth=1, p=2, reduction='mean'):
        super(BinaryDiceLoss, self).__init__()
        self.smooth = smooth
        self.p = p
        self.reduction = reduction

    def forward(self, predict, target):
        assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
        # Reject an unsupported reduction before doing any tensor work.
        if self.reduction not in self._REDUCTIONS:
            raise ValueError('Unexpected reduction {}'.format(self.reduction))

        # Flatten everything but the batch axis; reshape copies only when a
        # view is not possible (e.g. non-contiguous input).
        predict = predict.reshape(predict.shape[0], -1)
        target = target.reshape(target.shape[0], -1)

        num = torch.sum(torch.mul(predict, target), dim=1) + self.smooth
        den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=1) + self.smooth
        loss = 1 - num / den

        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss
class BalanceCrossEntropyLoss(nn.Module):
    """Balanced binary cross entropy with hard-negative selection.

    The element-wise loss is computed with
    ``binary_cross_entropy_with_logits``, so ``pred`` is treated as raw
    logits. All positive elements (``gt * mask``) contribute, while only
    the ``k`` negative elements with the largest loss contribute, where
    ``k = min(#negatives, int(#positives * negative_ratio))``.

    Args:
        negative_ratio: Maximum number of negatives kept per positive.
        eps: Value added to the denominator to avoid division by zero.

    NOTE(review): ``gt`` and ``mask`` are multiplied element-wise and
    ``pred``/``gt`` are fed to the BCE function together, so in practice
    all three are presumably the same shape -- confirm against callers.
    """

    def __init__(self, negative_ratio=3.0, eps=1e-6):
        super().__init__()
        self.negative_ratio = negative_ratio
        self.eps = eps

    def forward(self,
                pred: torch.Tensor,
                gt: torch.Tensor,
                mask: torch.Tensor,
                return_origin=False):
        """Compute the balanced loss.

        Args:
            pred: Network output (logits).
            gt: Target map.
            mask: Map multiplied with ``gt`` / ``1 - gt`` to select the
                elements that take part in the loss.
            return_origin: If true, also return the unreduced
                element-wise loss.

        Returns:
            The scalar balanced loss, or ``(balanced_loss, elementwise_loss)``
            when ``return_origin`` is true.
        """
        pos_mask = (gt * mask).byte()
        neg_mask = ((1 - gt) * mask).byte()

        num_pos = int(pos_mask.float().sum())
        num_neg_available = int(neg_mask.float().sum())
        # Cap the negatives at negative_ratio times the positives.
        num_neg = min(num_neg_available, int(num_pos * self.negative_ratio))

        loss = nn.functional.binary_cross_entropy_with_logits(pred, gt, reduction='none')

        pos_total = (loss * pos_mask.float()).sum()
        # Keep only the hardest negatives (largest loss values).
        hardest_neg = torch.topk((loss * neg_mask.float()).view(-1), num_neg)[0]
        neg_total = hardest_neg.sum()

        balance_loss = (pos_total + neg_total) / (num_pos + num_neg + self.eps)

        if not return_origin:
            return balance_loss
        return balance_loss, loss
class DiceLoss(nn.Module):
    """Masked Dice loss between two heatmaps.

    Loss function from https://arxiv.org/abs/1707.03237, where the overlap
    is computed in a heatmap manner to measure the difference between two
    heatmaps.

    Args:
        eps: Value added to the denominator to avoid division by zero.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, pred: torch.Tensor, gt, mask, weights=None):
        """Return the Dice loss of ``pred`` against ``gt`` under ``mask``.

        Args:
            pred: Heatmap, either 4-D (channel 0 is used) or already the
                same shape as ``mask``.
            gt: Target heatmap; sliced the same way as ``pred`` when
                ``pred`` is 4-D.
            mask: Element-wise weighting of which positions count.
            weights: Optional extra weights with the same shape as
                ``mask``; multiplied into the mask when given.
        """
        return self._compute(pred, gt, mask, weights)

    def _compute(self, pred, gt, mask, weights):
        # A 4-D prediction carries a channel axis; keep only channel 0.
        if pred.dim() == 4:
            pred, gt = pred[:, 0, :, :], gt[:, 0, :, :]
        assert pred.shape == gt.shape
        assert pred.shape == mask.shape

        if weights is not None:
            assert weights.shape == mask.shape
            mask = weights * mask

        overlap = torch.sum(pred * gt * mask)
        total = torch.sum(pred * mask) + torch.sum(gt * mask) + self.eps
        dice = 2.0 * overlap / total
        loss = 1 - dice
        assert loss <= 1
        return loss
class MaskL1Loss(nn.Module):
    """Mean absolute error restricted to (and normalised by) a mask.

    Args:
        eps: Value added to the mask sum to avoid division by zero.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, pred: torch.Tensor, gt, mask):
        """Return ``sum(|pred - gt| * mask) / (sum(mask) + eps)``."""
        masked_error = torch.abs(pred - gt) * mask
        normaliser = mask.sum() + self.eps
        return masked_error.sum() / normaliser
class DBLoss(nn.Module):
    """Combined loss over shrink, threshold and binary maps.

    The total loss is
    ``alpha * loss_shrink_maps + beta * loss_threshold_maps + loss_binary_maps``
    when the prediction has more than two channels; otherwise only the
    shrink-map loss is reported as the total.
    """

    def __init__(self, alpha=3.0, beta=1.0, ohem_ratio=3, reduction='mean', eps=1e-6):
        """
        :param alpha: coefficient applied to the shrink-map loss
        :param beta: coefficient applied to the threshold-map loss
        :param ohem_ratio: OHEM ratio, used as the negative ratio of the
            balanced cross entropy loss
        :param reduction: 'mean' or 'sum'; validated and stored, but not
            read by ``forward`` in this class
        :param eps: smoothing term passed to the Dice and masked L1 losses
        """
        super().__init__()
        assert reduction in ['mean', 'sum'], " reduction must in ['mean','sum']"
        self.alpha = alpha
        self.beta = beta
        self.bce_loss = BalanceCrossEntropyLoss(negative_ratio=ohem_ratio)
        self.dice_loss = DiceLoss(eps=eps)
        self.l1_loss = MaskL1Loss(eps=eps)
        self.ohem_ratio = ohem_ratio
        self.reduction = reduction

    def forward(self, pred, batch, use_bce=True):
        """Compute the individual loss terms and their weighted total.

        Args:
            pred: 4-D tensor indexed as ``pred[:, c, :, :]``. Channel 0 is
                the shrink map and channel 1 the threshold map. Channel 2
                (binary map) is used only when more than two channels are
                present, and channel 3 is read only when ``use_bce`` is
                true.
            batch: Mapping with keys ``'shrink_map'``, ``'shrink_mask'``,
                ``'threshold_map'`` and ``'threshold_mask'``.
            use_bce: When true, add the balanced BCE of channel 3 to the
                Dice loss of the shrink map.

        Returns:
            dict with ``loss_shrink_maps``, ``loss_threshold_maps``,
            ``loss`` and, when ``pred`` has more than two channels,
            ``loss_binary_maps``.
        """
        shrink_maps = pred[:, 0, :, :]
        threshold_maps = pred[:, 1, :, :]

        if use_bce:
            loss_shrink_maps = (
                self.bce_loss(pred[:, 3, :, :], batch['shrink_map'], batch['shrink_mask'])
                + self.dice_loss(shrink_maps, batch['shrink_map'], batch['shrink_mask'])
            )
        else:
            loss_shrink_maps = self.dice_loss(shrink_maps, batch['shrink_map'], batch['shrink_mask'])

        loss_threshold_maps = self.l1_loss(threshold_maps, batch['threshold_map'], batch['threshold_mask'])
        metrics = dict(loss_shrink_maps=loss_shrink_maps, loss_threshold_maps=loss_threshold_maps)

        if pred.size(1) > 2:
            # Slice the binary map only here: indexing channel 2 up front
            # raised IndexError for two-channel predictions and made the
            # else branch below unreachable.
            binary_maps = pred[:, 2, :, :]
            loss_binary_maps = (
                self.dice_loss(binary_maps, batch['shrink_map'], batch['shrink_mask'])
                + self.bce_loss(binary_maps, batch['shrink_map'], batch['shrink_mask'])
            )
            metrics['loss_binary_maps'] = loss_binary_maps
            metrics['loss'] = (
                self.alpha * loss_shrink_maps
                + self.beta * loss_threshold_maps
                + loss_binary_maps
            )
        else:
            metrics['loss'] = loss_shrink_maps
        return metrics
|