DB_loss.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. import paddle
  2. from models.losses.basic_loss import BalanceCrossEntropyLoss, MaskL1Loss, DiceLoss
  3. class DBLoss(paddle.nn.Layer):
  4. def __init__(self, alpha=1.0, beta=10, ohem_ratio=3, reduction="mean", eps=1e-06):
  5. """
  6. Implement PSE Loss.
  7. :param alpha: binary_map loss 前面的系数
  8. :param beta: threshold_map loss 前面的系数
  9. :param ohem_ratio: OHEM的比例
  10. :param reduction: 'mean' or 'sum'对 batch里的loss 算均值或求和
  11. """
  12. super().__init__()
  13. assert reduction in ["mean", "sum"], " reduction must in ['mean','sum']"
  14. self.alpha = alpha
  15. self.beta = beta
  16. self.bce_loss = BalanceCrossEntropyLoss(negative_ratio=ohem_ratio)
  17. self.dice_loss = DiceLoss(eps=eps)
  18. self.l1_loss = MaskL1Loss(eps=eps)
  19. self.ohem_ratio = ohem_ratio
  20. self.reduction = reduction
  21. def forward(self, pred, batch):
  22. shrink_maps = pred[:, 0, :, :]
  23. threshold_maps = pred[:, 1, :, :]
  24. binary_maps = pred[:, 2, :, :]
  25. loss_shrink_maps = self.bce_loss(
  26. shrink_maps, batch["shrink_map"], batch["shrink_mask"]
  27. )
  28. loss_threshold_maps = self.l1_loss(
  29. threshold_maps, batch["threshold_map"], batch["threshold_mask"]
  30. )
  31. metrics = dict(
  32. loss_shrink_maps=loss_shrink_maps, loss_threshold_maps=loss_threshold_maps
  33. )
  34. if pred.shape[1] > 2:
  35. loss_binary_maps = self.dice_loss(
  36. binary_maps, batch["shrink_map"], batch["shrink_mask"]
  37. )
  38. metrics["loss_binary_maps"] = loss_binary_maps
  39. loss_all = (
  40. self.alpha * loss_shrink_maps
  41. + self.beta * loss_threshold_maps
  42. + loss_binary_maps
  43. )
  44. metrics["loss"] = loss_all
  45. else:
  46. metrics["loss"] = loss_shrink_maps
  47. return metrics