| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950 |
- import paddle
- from models.losses.basic_loss import BalanceCrossEntropyLoss, MaskL1Loss, DiceLoss
class DBLoss(paddle.nn.Layer):
    """Combined loss for a DB-style text-detection head.

    The prediction tensor's channels are used as follows:
    channel 0 is the shrink map (scored with a balanced BCE loss),
    channel 1 is the threshold map (scored with a masked L1 loss),
    and, when present, channel 2 is the binary map (scored with a Dice loss).
    """

    def __init__(self, alpha=1.0, beta=10, ohem_ratio=3, reduction="mean", eps=1e-06):
        """
        Build the DB loss and its component losses.

        :param alpha: weight applied to the shrink-map (balanced BCE) loss.
        :param beta: weight applied to the threshold-map (masked L1) loss.
        :param ohem_ratio: OHEM ratio, passed to BalanceCrossEntropyLoss as
            ``negative_ratio``.
        :param reduction: 'mean' or 'sum' (intended to average or sum the loss
            over the batch). It is validated and stored on ``self.reduction``,
            but ``forward`` in this class does not read it.
        :param eps: small constant forwarded to DiceLoss and MaskL1Loss.
        """
        super().__init__()
        # NOTE(review): assert-based validation is stripped under `python -O`.
        assert reduction in ["mean", "sum"], " reduction must in ['mean','sum']"
        self.alpha = alpha
        self.beta = beta
        self.bce_loss = BalanceCrossEntropyLoss(negative_ratio=ohem_ratio)
        self.dice_loss = DiceLoss(eps=eps)
        self.l1_loss = MaskL1Loss(eps=eps)
        self.ohem_ratio = ohem_ratio
        self.reduction = reduction

    def forward(self, pred, batch):
        """
        Compute the individual loss terms and the total loss.

        :param pred: 4-D prediction tensor, sliced as ``pred[:, c, :, :]``
            (presumably laid out as (N, C, H, W) -- confirm against the head).
            Channel 0 is the shrink map and channel 1 the threshold map.
            When ``pred.shape[1] > 2``, channel 2 is the binary map.
        :param batch: mapping providing the "shrink_map", "shrink_mask",
            "threshold_map" and "threshold_mask" targets/masks.
        :return: dict with "loss_shrink_maps", "loss_threshold_maps" and
            "loss". It also has "loss_binary_maps" when the binary channel
            is present.
        """
        shrink_maps = pred[:, 0, :, :]
        threshold_maps = pred[:, 1, :, :]

        loss_shrink_maps = self.bce_loss(
            shrink_maps, batch["shrink_map"], batch["shrink_mask"]
        )
        loss_threshold_maps = self.l1_loss(
            threshold_maps, batch["threshold_map"], batch["threshold_mask"]
        )
        metrics = dict(
            loss_shrink_maps=loss_shrink_maps, loss_threshold_maps=loss_threshold_maps
        )

        if pred.shape[1] > 2:
            # Slice the binary channel only when it exists. Indexing channel 2
            # unconditionally fails for 2-channel predictions, and that also
            # made the fallback branch below unreachable.
            binary_maps = pred[:, 2, :, :]
            loss_binary_maps = self.dice_loss(
                binary_maps, batch["shrink_map"], batch["shrink_mask"]
            )
            metrics["loss_binary_maps"] = loss_binary_maps
            metrics["loss"] = (
                self.alpha * loss_shrink_maps
                + self.beta * loss_threshold_maps
                + loss_binary_maps
            )
        else:
            # Without a binary channel, the total is the shrink-map loss alone
            # (not scaled by alpha). The threshold loss is still reported in
            # metrics.
            metrics["loss"] = loss_shrink_maps
        return metrics
|