loss.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. import torch
  2. from torch import nn
  3. import torch.nn.functional as F
  4. from torch.autograd import Variable
  5. from math import exp
  6. from config import Config
  7. class ContourLoss(torch.nn.Module):
  8. def __init__(self):
  9. super(ContourLoss, self).__init__()
  10. def forward(self, pred, target, weight=10):
  11. '''
  12. target, pred: tensor of shape (B, C, H, W), where target[:,:,region_in_contour] == 1,
  13. target[:,:,region_out_contour] == 0.
  14. weight: scalar, length term weight.
  15. '''
  16. # length term
  17. delta_r = pred[:,:,1:,:] - pred[:,:,:-1,:] # horizontal gradient (B, C, H-1, W)
  18. delta_c = pred[:,:,:,1:] - pred[:,:,:,:-1] # vertical gradient (B, C, H, W-1)
  19. delta_r = delta_r[:,:,1:,:-2]**2 # (B, C, H-2, W-2)
  20. delta_c = delta_c[:,:,:-2,1:]**2 # (B, C, H-2, W-2)
  21. delta_pred = torch.abs(delta_r + delta_c)
  22. epsilon = 1e-8 # where is a parameter to avoid square root is zero in practice.
  23. length = torch.mean(torch.sqrt(delta_pred + epsilon)) # eq.(11) in the paper, mean is used instead of sum.
  24. c_in = torch.ones_like(pred)
  25. c_out = torch.zeros_like(pred)
  26. region_in = torch.mean( pred * (target - c_in )**2 ) # equ.(12) in the paper, mean is used instead of sum.
  27. region_out = torch.mean( (1-pred) * (target - c_out)**2 )
  28. region = region_in + region_out
  29. loss = weight * length + region
  30. return loss
  31. class IoULoss(torch.nn.Module):
  32. def __init__(self):
  33. super(IoULoss, self).__init__()
  34. def forward(self, pred, target):
  35. b = pred.shape[0]
  36. IoU = 0.0
  37. for i in range(0, b):
  38. # compute the IoU of the foreground
  39. Iand1 = torch.sum(target[i, :, :, :] * pred[i, :, :, :])
  40. Ior1 = torch.sum(target[i, :, :, :]) + torch.sum(pred[i, :, :, :]) - Iand1
  41. IoU1 = Iand1 / Ior1
  42. # IoU loss is (1-IoU1)
  43. IoU = IoU + (1-IoU1)
  44. # return IoU/b
  45. return IoU
  46. class StructureLoss(torch.nn.Module):
  47. def __init__(self):
  48. super(StructureLoss, self).__init__()
  49. def forward(self, pred, target):
  50. weit = 1+5*torch.abs(F.avg_pool2d(target, kernel_size=31, stride=1, padding=15)-target)
  51. wbce = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
  52. wbce = (weit*wbce).sum(dim=(2,3))/weit.sum(dim=(2,3))
  53. pred = torch.sigmoid(pred)
  54. inter = ((pred * target) * weit).sum(dim=(2, 3))
  55. union = ((pred + target) * weit).sum(dim=(2, 3))
  56. wiou = 1-(inter+1)/(union-inter+1)
  57. return (wbce+wiou).mean()
  58. class PatchIoULoss(torch.nn.Module):
  59. def __init__(self):
  60. super(PatchIoULoss, self).__init__()
  61. self.iou_loss = IoULoss()
  62. def forward(self, pred, target):
  63. win_y, win_x = 64, 64
  64. iou_loss = 0.
  65. for anchor_y in range(0, target.shape[0], win_y):
  66. for anchor_x in range(0, target.shape[1], win_y):
  67. patch_pred = pred[:, :, anchor_y:anchor_y+win_y, anchor_x:anchor_x+win_x]
  68. patch_target = target[:, :, anchor_y:anchor_y+win_y, anchor_x:anchor_x+win_x]
  69. patch_iou_loss = self.iou_loss(patch_pred, patch_target)
  70. iou_loss += patch_iou_loss
  71. return iou_loss
  72. class ThrReg_loss(torch.nn.Module):
  73. def __init__(self):
  74. super(ThrReg_loss, self).__init__()
  75. def forward(self, pred, gt=None):
  76. return torch.mean(1 - ((pred - 0) ** 2 + (pred - 1) ** 2))
  77. class ClsLoss(nn.Module):
  78. """
  79. Auxiliary classification loss for each refined class output.
  80. """
  81. def __init__(self):
  82. super(ClsLoss, self).__init__()
  83. self.config = Config()
  84. self.lambdas_cls = self.config.lambdas_cls
  85. self.criterions_last = {
  86. 'ce': nn.CrossEntropyLoss()
  87. }
  88. def forward(self, preds, gt):
  89. loss = 0.
  90. for _, pred_lvl in enumerate(preds):
  91. if pred_lvl is None:
  92. continue
  93. for criterion_name, criterion in self.criterions_last.items():
  94. loss += criterion(pred_lvl, gt) * self.lambdas_cls[criterion_name]
  95. return loss
  96. class PixLoss(nn.Module):
  97. """
  98. Pixel loss for each refined map output.
  99. """
  100. def __init__(self):
  101. super(PixLoss, self).__init__()
  102. self.config = Config()
  103. self.lambdas_pix_last = self.config.lambdas_pix_last
  104. self.criterions_last = {}
  105. if 'bce' in self.lambdas_pix_last and self.lambdas_pix_last['bce']:
  106. self.criterions_last['bce'] = nn.BCELoss()
  107. if 'iou' in self.lambdas_pix_last and self.lambdas_pix_last['iou']:
  108. self.criterions_last['iou'] = IoULoss()
  109. if 'iou_patch' in self.lambdas_pix_last and self.lambdas_pix_last['iou_patch']:
  110. self.criterions_last['iou_patch'] = PatchIoULoss()
  111. if 'ssim' in self.lambdas_pix_last and self.lambdas_pix_last['ssim']:
  112. self.criterions_last['ssim'] = SSIMLoss()
  113. if 'mae' in self.lambdas_pix_last and self.lambdas_pix_last['mae']:
  114. self.criterions_last['mae'] = nn.L1Loss()
  115. if 'mse' in self.lambdas_pix_last and self.lambdas_pix_last['mse']:
  116. self.criterions_last['mse'] = nn.MSELoss()
  117. if 'reg' in self.lambdas_pix_last and self.lambdas_pix_last['reg']:
  118. self.criterions_last['reg'] = ThrReg_loss()
  119. if 'cnt' in self.lambdas_pix_last and self.lambdas_pix_last['cnt']:
  120. self.criterions_last['cnt'] = ContourLoss()
  121. if 'structure' in self.lambdas_pix_last and self.lambdas_pix_last['structure']:
  122. self.criterions_last['structure'] = StructureLoss()
  123. def forward(self, scaled_preds, gt, pix_loss_lambda=1.0):
  124. loss = 0.
  125. loss_dict = {}
  126. for _, pred_lvl in enumerate(scaled_preds):
  127. if pred_lvl.shape != gt.shape:
  128. pred_lvl = nn.functional.interpolate(pred_lvl, size=gt.shape[2:], mode='bilinear', align_corners=True)
  129. for criterion_name, criterion in self.criterions_last.items():
  130. _loss = criterion(pred_lvl.sigmoid(), gt) * self.lambdas_pix_last[criterion_name] * pix_loss_lambda
  131. loss += _loss
  132. loss_dict[criterion_name] = loss_dict.get(criterion_name, 0.) + _loss.item() / len(scaled_preds)
  133. # print(criterion_name, _loss.item())
  134. return loss, loss_dict
  135. class SSIMLoss(torch.nn.Module):
  136. def __init__(self, window_size=11, size_average=True):
  137. super(SSIMLoss, self).__init__()
  138. self.window_size = window_size
  139. self.size_average = size_average
  140. self.channel = 1
  141. self.window = create_window(window_size, self.channel)
  142. def forward(self, img1, img2):
  143. (_, channel, _, _) = img1.size()
  144. if channel == self.channel and self.window.data.type() == img1.data.type():
  145. window = self.window
  146. else:
  147. window = create_window(self.window_size, channel)
  148. if img1.is_cuda:
  149. window = window.cuda(img1.get_device())
  150. window = window.type_as(img1)
  151. self.window = window
  152. self.channel = channel
  153. return 1 - (1 + _ssim(img1, img2, window, self.window_size, channel, self.size_average)) / 2
  154. def gaussian(window_size, sigma):
  155. gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
  156. return gauss/gauss.sum()
  157. def create_window(window_size, channel):
  158. _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
  159. _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
  160. window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
  161. return window
  162. def _ssim(img1, img2, window, window_size, channel, size_average=True):
  163. mu1 = F.conv2d(img1, window, padding = window_size//2, groups=channel)
  164. mu2 = F.conv2d(img2, window, padding = window_size//2, groups=channel)
  165. mu1_sq = mu1.pow(2)
  166. mu2_sq = mu2.pow(2)
  167. mu1_mu2 = mu1*mu2
  168. sigma1_sq = F.conv2d(img1*img1, window, padding=window_size//2, groups=channel) - mu1_sq
  169. sigma2_sq = F.conv2d(img2*img2, window, padding=window_size//2, groups=channel) - mu2_sq
  170. sigma12 = F.conv2d(img1*img2, window, padding=window_size//2, groups=channel) - mu1_mu2
  171. C1 = 0.01**2
  172. C2 = 0.03**2
  173. ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
  174. if size_average:
  175. return ssim_map.mean()
  176. else:
  177. return ssim_map.mean(1).mean(1).mean(1)
  178. def SSIM(x, y):
  179. C1 = 0.01 ** 2
  180. C2 = 0.03 ** 2
  181. mu_x = nn.AvgPool2d(3, 1, 1)(x)
  182. mu_y = nn.AvgPool2d(3, 1, 1)(y)
  183. mu_x_mu_y = mu_x * mu_y
  184. mu_x_sq = mu_x.pow(2)
  185. mu_y_sq = mu_y.pow(2)
  186. sigma_x = nn.AvgPool2d(3, 1, 1)(x * x) - mu_x_sq
  187. sigma_y = nn.AvgPool2d(3, 1, 1)(y * y) - mu_y_sq
  188. sigma_xy = nn.AvgPool2d(3, 1, 1)(x * y) - mu_x_mu_y
  189. SSIM_n = (2 * mu_x_mu_y + C1) * (2 * sigma_xy + C2)
  190. SSIM_d = (mu_x_sq + mu_y_sq + C1) * (sigma_x + sigma_y + C2)
  191. SSIM = SSIM_n / SSIM_d
  192. return torch.clamp((1 - SSIM) / 2, 0, 1)
  193. def saliency_structure_consistency(x, y):
  194. ssim = torch.mean(SSIM(x,y))
  195. return ssim