| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248 |
- import torch
- from torch import nn
- import torch.nn.functional as F
- from torch.autograd import Variable
- from math import exp
- from config import Config
class ContourLoss(torch.nn.Module):
    """Active-contour style loss: a contour-length term plus a region term."""

    def __init__(self):
        super().__init__()

    def forward(self, pred, target, weight=10):
        """Return ``weight * length + region`` as a scalar tensor.

        pred, target: tensors of shape (B, C, H, W); ``target`` is 1 inside
            the contour and 0 outside it.
        weight: scalar multiplier applied to the length term.
        """
        # Length term: finite differences between adjacent rows (dim 2) and
        # adjacent columns (dim 3) of the prediction.
        row_diff = pred[:, :, 1:, :] - pred[:, :, :-1, :]   # (B, C, H-1, W)
        col_diff = pred[:, :, :, 1:] - pred[:, :, :, :-1]   # (B, C, H, W-1)

        # Crop both difference maps to a common (B, C, H-2, W-2) window so
        # they can be added element-wise.
        row_sq = row_diff[:, :, 1:, :-2] ** 2
        col_sq = col_diff[:, :, :-2, 1:] ** 2
        grad_sq = torch.abs(row_sq + col_sq)

        # epsilon keeps the argument of the square root strictly positive.
        epsilon = 1e-8
        # eq.(11) in the paper, with a mean in place of the sum.
        length = torch.sqrt(grad_sq + epsilon).mean()

        # Region term, eq.(12) in the paper, again with means instead of sums.
        inside = torch.ones_like(pred)
        outside = torch.zeros_like(pred)
        region_in = (pred * (target - inside) ** 2).mean()
        region_out = ((1 - pred) * (target - outside) ** 2).mean()

        return weight * length + (region_in + region_out)
class IoULoss(torch.nn.Module):
    """Soft IoU loss, summed (not averaged) over the batch dimension."""

    def __init__(self):
        super(IoULoss, self).__init__()

    def forward(self, pred, target):
        """Return the sum over samples of ``1 - IoU(pred[i], target[i])``.

        pred, target: tensors indexed by sample along dim 0. The soft
            intersection is ``sum(target * pred)`` and the soft union is
            ``sum(target) + sum(pred) - intersection``.

        A sample whose union is exactly zero (prediction and target both
        empty) is treated as a perfect match (IoU = 1, loss 0) rather than
        producing 0/0 = NaN, which would poison the whole summed loss.
        """
        batch_size = pred.shape[0]
        total = 0.0
        for i in range(batch_size):
            sample_pred = pred[i]
            sample_target = target[i]

            # Soft IoU of the foreground for this sample.
            inter = torch.sum(sample_target * sample_pred)
            union = torch.sum(sample_target) + torch.sum(sample_pred) - inter

            # Double-where guard: divide by a safe denominator so neither the
            # forward value nor the gradient becomes NaN when union == 0.
            nonempty = union != 0
            safe_union = torch.where(nonempty, union, torch.ones_like(union))
            iou = torch.where(nonempty, inter / safe_union, torch.ones_like(union))

            total = total + (1 - iou)
        # The per-sample losses are summed; no division by the batch size.
        return total
class StructureLoss(torch.nn.Module):
    """Edge-weighted BCE plus edge-weighted soft IoU, averaged over the batch."""

    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        """Return the mean of weighted BCE + weighted IoU.

        pred: passed to ``binary_cross_entropy_with_logits`` and then through
            ``torch.sigmoid``, i.e. it is treated as raw logits.
        target: tensor of the same shape; reductions run over dims 2 and 3.
        """
        # Pixels whose value differs from their 31x31 local average get a
        # weight above 1 (up to 1 + 5 * |difference|).
        local_mean = F.avg_pool2d(target, kernel_size=31, stride=1, padding=15)
        weit = 1 + 5 * torch.abs(local_mean - target)

        # Weighted binary cross-entropy, normalised by the total weight.
        bce_map = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
        wbce = (weit * bce_map).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))

        # Weighted soft IoU with +1 smoothing in numerator and denominator.
        prob = torch.sigmoid(pred)
        inter = ((prob * target) * weit).sum(dim=(2, 3))
        union = ((prob + target) * weit).sum(dim=(2, 3))
        wiou = 1 - (inter + 1) / (union - inter + 1)

        return (wbce + wiou).mean()
class PatchIoULoss(torch.nn.Module):
    """IoU loss accumulated over non-overlapping spatial patches.

    win_y, win_x: patch height and width (default 64 x 64).
    """

    def __init__(self, win_y=64, win_x=64):
        super(PatchIoULoss, self).__init__()
        self.win_y = win_y
        self.win_x = win_x
        self.iou_loss = IoULoss()

    def forward(self, pred, target):
        """Return the sum of ``self.iou_loss`` over every patch.

        pred, target: tensors indexed as (batch, channel, y, x). Patches at
        the bottom/right border are smaller when the spatial size is not a
        multiple of the window size (slicing simply truncates).
        """
        # Tile over the spatial dims (2 and 3); dims 0 and 1 are batch and
        # channel and are kept whole in each patch.
        height, width = target.shape[2], target.shape[3]
        total = 0.
        for top in range(0, height, self.win_y):
            for left in range(0, width, self.win_x):
                patch_pred = pred[:, :, top:top + self.win_y, left:left + self.win_x]
                patch_target = target[:, :, top:top + self.win_y, left:left + self.win_x]
                total = total + self.iou_loss(patch_pred, patch_target)
        return total
class ThrReg_loss(torch.nn.Module):
    """Regulariser that is smallest when predictions sit at exactly 0 or 1."""

    def __init__(self):
        super().__init__()

    def forward(self, pred, gt=None):
        """Return ``mean(1 - (pred**2 + (pred - 1)**2))``.

        ``gt`` is accepted so the call signature matches the other criteria,
        but it is not used.
        """
        sq_dist_to_zero = (pred - 0) ** 2
        sq_dist_to_one = (pred - 1) ** 2
        return torch.mean(1 - (sq_dist_to_zero + sq_dist_to_one))
class ClsLoss(nn.Module):
    """
    Auxiliary classification loss for each refined class output.
    """

    def __init__(self):
        super().__init__()
        self.config = Config()
        # Per-criterion weights, keyed by criterion name.
        self.lambdas_cls = self.config.lambdas_cls
        self.criterions_last = {'ce': nn.CrossEntropyLoss()}

    def forward(self, preds, gt):
        """Sum the weighted criteria over every non-``None`` entry of ``preds``.

        preds: iterable of prediction tensors; ``None`` entries are skipped.
        gt: target passed unchanged to each criterion.
        Returns the accumulated loss (the float ``0.`` when nothing was added).
        """
        total = 0.
        for level_pred in preds:
            if level_pred is not None:
                for name, criterion in self.criterions_last.items():
                    total = total + criterion(level_pred, gt) * self.lambdas_cls[name]
        return total
class PixLoss(nn.Module):
    """
    Pixel loss for each refined map output.
    """

    def __init__(self):
        super().__init__()
        self.config = Config()
        # Per-criterion weights; a criterion is enabled only when its name is
        # present with a truthy weight.
        self.lambdas_pix_last = self.config.lambdas_pix_last
        self.criterions_last = {}

        # Ordered (name, constructor) table; the order fixes the insertion
        # order of ``criterions_last`` and hence the evaluation order.
        available = (
            ('bce', nn.BCELoss),
            ('iou', IoULoss),
            ('iou_patch', PatchIoULoss),
            ('ssim', SSIMLoss),
            ('mae', nn.L1Loss),
            ('mse', nn.MSELoss),
            ('reg', ThrReg_loss),
            ('cnt', ContourLoss),
            ('structure', StructureLoss),
        )
        for name, factory in available:
            if name in self.lambdas_pix_last and self.lambdas_pix_last[name]:
                self.criterions_last[name] = factory()

    def forward(self, scaled_preds, gt, pix_loss_lambda=1.0):
        """Accumulate every enabled criterion over every prediction level.

        scaled_preds: sequence of prediction tensors; a level whose shape
            differs from ``gt`` is bilinearly resized to ``gt.shape[2:]``.
        gt: ground-truth tensor.
        pix_loss_lambda: extra global multiplier applied to every term.

        Returns ``(loss, loss_dict)``: the summed loss, and a dict mapping
        each criterion name to its weighted value averaged over the levels.

        NOTE(review): every criterion receives ``pred.sigmoid()``, but
        StructureLoss applies a logits-based BCE and its own sigmoid to its
        input - confirm the double sigmoid is intended.
        """
        total = 0.
        per_criterion = {}
        for level_pred in scaled_preds:
            if level_pred.shape != gt.shape:
                level_pred = nn.functional.interpolate(
                    level_pred, size=gt.shape[2:], mode='bilinear', align_corners=True
                )
            for name, criterion in self.criterions_last.items():
                term = criterion(level_pred.sigmoid(), gt) * self.lambdas_pix_last[name] * pix_loss_lambda
                total = total + term
                running = per_criterion.get(name, 0.)
                per_criterion[name] = running + term.item() / len(scaled_preds)
        return total, per_criterion
class SSIMLoss(torch.nn.Module):
    """SSIM-based loss: ``1 - (1 + ssim) / 2`` using a cached Gaussian window."""

    def __init__(self, window_size=11, size_average=True):
        super().__init__()
        self.window_size = window_size
        self.size_average = size_average
        # Start with a single-channel window; forward() rebuilds it on demand.
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        """Return ``1 - (1 + _ssim(img1, img2, ...)) / 2``.

        img1, img2: 4-D tensors; the channel count is read from dim 1 of
        ``img1``. The cached window is rebuilt whenever the channel count or
        the tensor type of ``img1`` no longer matches it.
        """
        _, channel, _, _ = img1.size()

        cache_stale = (
            channel != self.channel
            or self.window.data.type() != img1.data.type()
        )
        if cache_stale:
            fresh = create_window(self.window_size, channel)
            if img1.is_cuda:
                fresh = fresh.cuda(img1.get_device())
            self.window = fresh.type_as(img1)
            self.channel = channel

        similarity = _ssim(
            img1, img2, self.window, self.window_size, channel, self.size_average
        )
        return 1 - (1 + similarity) / 2
def gaussian(window_size, sigma):
    """Return a 1-D Gaussian of length ``window_size`` normalised to sum to 1.

    The curve is centred on index ``window_size // 2`` and has standard
    deviation ``sigma``.
    """
    center = window_size // 2
    two_sigma_sq = float(2 * sigma ** 2)
    weights = [exp(-(x - center) ** 2 / two_sigma_sq) for x in range(window_size)]
    gauss = torch.Tensor(weights)
    return gauss / gauss.sum()
def create_window(window_size, channel):
    """Build a 2-D Gaussian window of shape (channel, 1, window_size, window_size).

    The window is the outer product of a 1-D Gaussian (sigma = 1.5) with
    itself, repeated once per channel and returned as a contiguous float
    tensor.
    """
    kernel_1d = gaussian(window_size, 1.5).unsqueeze(1)            # (K, 1)
    kernel_2d = kernel_1d.mm(kernel_1d.t()).float().unsqueeze(0).unsqueeze(0)
    # A plain tensor is returned: the deprecated autograd.Variable wrapper is
    # a no-op on tensors, so it is not needed here.
    return kernel_2d.expand(channel, 1, window_size, window_size).contiguous()
- def _ssim(img1, img2, window, window_size, channel, size_average=True):
- mu1 = F.conv2d(img1, window, padding = window_size//2, groups=channel)
- mu2 = F.conv2d(img2, window, padding = window_size//2, groups=channel)
- mu1_sq = mu1.pow(2)
- mu2_sq = mu2.pow(2)
- mu1_mu2 = mu1*mu2
- sigma1_sq = F.conv2d(img1*img1, window, padding=window_size//2, groups=channel) - mu1_sq
- sigma2_sq = F.conv2d(img2*img2, window, padding=window_size//2, groups=channel) - mu2_sq
- sigma12 = F.conv2d(img1*img2, window, padding=window_size//2, groups=channel) - mu1_mu2
- C1 = 0.01**2
- C2 = 0.03**2
- ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
- if size_average:
- return ssim_map.mean()
- else:
- return ssim_map.mean(1).mean(1).mean(1)
def SSIM(x, y):
    """Return the per-pixel SSIM dissimilarity ``clamp((1 - ssim) / 2, 0, 1)``.

    Local statistics use a 3x3 average pool with stride 1 and padding 1, so
    the output has the same spatial size as the inputs.
    """
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2
    pool = nn.AvgPool2d(3, 1, 1)

    mu_x = pool(x)
    mu_y = pool(y)
    mu_x_mu_y = mu_x * mu_y
    mu_x_sq = mu_x.pow(2)
    mu_y_sq = mu_y.pow(2)

    # Local variances and covariance: E[ab] - E[a]E[b].
    sigma_x = pool(x * x) - mu_x_sq
    sigma_y = pool(y * y) - mu_y_sq
    sigma_xy = pool(x * y) - mu_x_mu_y

    numerator = (2 * mu_x_mu_y + C1) * (2 * sigma_xy + C2)
    denominator = (mu_x_sq + mu_y_sq + C1) * (sigma_x + sigma_y + C2)
    ssim_map = numerator / denominator
    return torch.clamp((1 - ssim_map) / 2, 0, 1)
def saliency_structure_consistency(x, y):
    """Return the mean of ``SSIM(x, y)`` over all elements.

    ``SSIM`` here yields the clamped ``(1 - ssim) / 2`` map, so the result is
    a dissimilarity: smaller means the two inputs are more alike.
    """
    return SSIM(x, y).mean()
|