basic_loss.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from paddle.nn import L1Loss
  18. from paddle.nn import MSELoss as L2Loss
  19. from paddle.nn import SmoothL1Loss
  20. class CELoss(nn.Layer):
  21. def __init__(self, epsilon=None):
  22. super().__init__()
  23. if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
  24. epsilon = None
  25. self.epsilon = epsilon
  26. def _labelsmoothing(self, target, class_num):
  27. if target.shape[-1] != class_num:
  28. one_hot_target = F.one_hot(target, class_num)
  29. else:
  30. one_hot_target = target
  31. soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon)
  32. soft_target = paddle.reshape(soft_target, shape=[-1, class_num])
  33. return soft_target
  34. def forward(self, x, label):
  35. loss_dict = {}
  36. if self.epsilon is not None:
  37. class_num = x.shape[-1]
  38. label = self._labelsmoothing(label, class_num)
  39. x = -F.log_softmax(x, axis=-1)
  40. loss = paddle.sum(x * label, axis=-1)
  41. else:
  42. if label.shape[-1] == x.shape[-1]:
  43. label = F.softmax(label, axis=-1)
  44. soft_label = True
  45. else:
  46. soft_label = False
  47. loss = F.cross_entropy(x, label=label, soft_label=soft_label)
  48. return loss
  49. class KLJSLoss(object):
  50. def __init__(self, mode="kl"):
  51. assert mode in [
  52. "kl",
  53. "js",
  54. "KL",
  55. "JS",
  56. ], "mode can only be one of ['kl', 'KL', 'js', 'JS']"
  57. self.mode = mode
  58. def __call__(self, p1, p2, reduction="mean", eps=1e-5):
  59. if self.mode.lower() == "kl":
  60. loss = paddle.multiply(p2, paddle.log((p2 + eps) / (p1 + eps) + eps))
  61. loss += paddle.multiply(p1, paddle.log((p1 + eps) / (p2 + eps) + eps))
  62. loss *= 0.5
  63. elif self.mode.lower() == "js":
  64. loss = paddle.multiply(
  65. p2, paddle.log((2 * p2 + eps) / (p1 + p2 + eps) + eps)
  66. )
  67. loss += paddle.multiply(
  68. p1, paddle.log((2 * p1 + eps) / (p1 + p2 + eps) + eps)
  69. )
  70. loss *= 0.5
  71. else:
  72. raise ValueError(
  73. "The mode.lower() if KLJSLoss should be one of ['kl', 'js']"
  74. )
  75. if reduction == "mean":
  76. loss = paddle.mean(loss, axis=[1, 2])
  77. elif reduction == "none" or reduction is None:
  78. return loss
  79. else:
  80. loss = paddle.sum(loss, axis=[1, 2])
  81. return loss
  82. class DMLLoss(nn.Layer):
  83. """
  84. DMLLoss
  85. """
  86. def __init__(self, act=None, use_log=False):
  87. super().__init__()
  88. if act is not None:
  89. assert act in ["softmax", "sigmoid"]
  90. if act == "softmax":
  91. self.act = nn.Softmax(axis=-1)
  92. elif act == "sigmoid":
  93. self.act = nn.Sigmoid()
  94. else:
  95. self.act = None
  96. self.use_log = use_log
  97. self.jskl_loss = KLJSLoss(mode="kl")
  98. def _kldiv(self, x, target):
  99. eps = 1.0e-10
  100. loss = target * (paddle.log(target + eps) - x)
  101. # batch mean loss
  102. loss = paddle.sum(loss) / loss.shape[0]
  103. return loss
  104. def forward(self, out1, out2):
  105. if self.act is not None:
  106. out1 = self.act(out1) + 1e-10
  107. out2 = self.act(out2) + 1e-10
  108. if self.use_log:
  109. # for recognition distillation, log is needed for feature map
  110. log_out1 = paddle.log(out1)
  111. log_out2 = paddle.log(out2)
  112. loss = (self._kldiv(log_out1, out2) + self._kldiv(log_out2, out1)) / 2.0
  113. else:
  114. # for detection distillation log is not needed
  115. loss = self.jskl_loss(out1, out2)
  116. return loss
  117. class DistanceLoss(nn.Layer):
  118. """
  119. DistanceLoss:
  120. mode: loss mode
  121. """
  122. def __init__(self, mode="l2", **kargs):
  123. super().__init__()
  124. assert mode in ["l1", "l2", "smooth_l1"]
  125. if mode == "l1":
  126. self.loss_func = nn.L1Loss(**kargs)
  127. elif mode == "l2":
  128. self.loss_func = nn.MSELoss(**kargs)
  129. elif mode == "smooth_l1":
  130. self.loss_func = nn.SmoothL1Loss(**kargs)
  131. def forward(self, x, y):
  132. return self.loss_func(x, y)
  133. class LossFromOutput(nn.Layer):
  134. def __init__(self, key="loss", reduction="none"):
  135. super().__init__()
  136. self.key = key
  137. self.reduction = reduction
  138. def forward(self, predicts, batch):
  139. loss = predicts
  140. if self.key is not None and isinstance(predicts, dict):
  141. loss = loss[self.key]
  142. if self.reduction == "mean":
  143. loss = paddle.mean(loss)
  144. elif self.reduction == "sum":
  145. loss = paddle.sum(loss)
  146. return {"loss": loss}
  147. class KLDivLoss(nn.Layer):
  148. """
  149. KLDivLoss
  150. """
  151. def __init__(self):
  152. super().__init__()
  153. def _kldiv(self, x, target, mask=None):
  154. eps = 1.0e-10
  155. loss = target * (paddle.log(target + eps) - x)
  156. if mask is not None:
  157. loss = loss.flatten(0, 1).sum(axis=1)
  158. loss = loss.masked_select(mask).mean()
  159. else:
  160. # batch mean loss
  161. loss = paddle.sum(loss) / loss.shape[0]
  162. return loss
  163. def forward(self, logits_s, logits_t, mask=None):
  164. log_out_s = F.log_softmax(logits_s, axis=-1)
  165. out_t = F.softmax(logits_t, axis=-1)
  166. loss = self._kldiv(log_out_s, out_t, mask)
  167. return loss
  168. class DKDLoss(nn.Layer):
  169. """
  170. KLDivLoss
  171. """
  172. def __init__(self, temperature=1.0, alpha=1.0, beta=1.0):
  173. super().__init__()
  174. self.temperature = temperature
  175. self.alpha = alpha
  176. self.beta = beta
  177. def _cat_mask(self, t, mask1, mask2):
  178. t1 = (t * mask1).sum(axis=1, keepdim=True)
  179. t2 = (t * mask2).sum(axis=1, keepdim=True)
  180. rt = paddle.concat([t1, t2], axis=1)
  181. return rt
  182. def _kl_div(self, x, label, mask=None):
  183. y = (label * (paddle.log(label + 1e-10) - x)).sum(axis=1)
  184. if mask is not None:
  185. y = y.masked_select(mask).mean()
  186. else:
  187. y = y.mean()
  188. return y
  189. def forward(self, logits_student, logits_teacher, target, mask=None):
  190. gt_mask = F.one_hot(target.reshape([-1]), num_classes=logits_student.shape[-1])
  191. other_mask = 1 - gt_mask
  192. logits_student = logits_student.flatten(0, 1)
  193. logits_teacher = logits_teacher.flatten(0, 1)
  194. pred_student = F.softmax(logits_student / self.temperature, axis=1)
  195. pred_teacher = F.softmax(logits_teacher / self.temperature, axis=1)
  196. pred_student = self._cat_mask(pred_student, gt_mask, other_mask)
  197. pred_teacher = self._cat_mask(pred_teacher, gt_mask, other_mask)
  198. log_pred_student = paddle.log(pred_student)
  199. tckd_loss = self._kl_div(log_pred_student, pred_teacher) * (self.temperature**2)
  200. pred_teacher_part2 = F.softmax(
  201. logits_teacher / self.temperature - 1000.0 * gt_mask, axis=1
  202. )
  203. log_pred_student_part2 = F.log_softmax(
  204. logits_student / self.temperature - 1000.0 * gt_mask, axis=1
  205. )
  206. nckd_loss = self._kl_div(log_pred_student_part2, pred_teacher_part2) * (
  207. self.temperature**2
  208. )
  209. loss = self.alpha * tckd_loss + self.beta * nckd_loss
  210. return loss