# rec_ce_loss.py (2.5 KB)

  1. import paddle
  2. from paddle import nn
  3. import paddle.nn.functional as F
  4. class CELoss(nn.Layer):
  5. def __init__(self, smoothing=False, with_all=False, ignore_index=-1, **kwargs):
  6. super(CELoss, self).__init__()
  7. if ignore_index >= 0:
  8. self.loss_func = nn.CrossEntropyLoss(
  9. reduction="mean", ignore_index=ignore_index
  10. )
  11. else:
  12. self.loss_func = nn.CrossEntropyLoss(reduction="mean")
  13. self.smoothing = smoothing
  14. self.with_all = with_all
  15. def forward(self, pred, batch):
  16. if isinstance(pred, dict): # for ABINet
  17. loss = {}
  18. loss_sum = []
  19. for name, logits in pred.items():
  20. if isinstance(logits, list):
  21. logit_num = len(logits)
  22. all_tgt = paddle.concat([batch[1]] * logit_num, 0)
  23. all_logits = paddle.concat(logits, 0)
  24. flt_logtis = all_logits.reshape([-1, all_logits.shape[2]])
  25. flt_tgt = all_tgt.reshape([-1])
  26. else:
  27. flt_logtis = logits.reshape([-1, logits.shape[2]])
  28. flt_tgt = batch[1].reshape([-1])
  29. loss[name + "_loss"] = self.loss_func(flt_logtis, flt_tgt)
  30. loss_sum.append(loss[name + "_loss"])
  31. loss["loss"] = sum(loss_sum)
  32. return loss
  33. else:
  34. if self.with_all: # for ViTSTR
  35. tgt = batch[1]
  36. pred = pred.reshape([-1, pred.shape[2]])
  37. tgt = tgt.reshape([-1])
  38. loss = self.loss_func(pred, tgt)
  39. return {"loss": loss}
  40. else: # for NRTR
  41. max_len = batch[2].max()
  42. tgt = batch[1][:, 1 : 2 + max_len]
  43. pred = pred.reshape([-1, pred.shape[2]])
  44. tgt = tgt.reshape([-1])
  45. if self.smoothing:
  46. eps = 0.1
  47. n_class = pred.shape[1]
  48. one_hot = F.one_hot(tgt, pred.shape[1])
  49. one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
  50. log_prb = F.log_softmax(pred, axis=1)
  51. non_pad_mask = paddle.not_equal(
  52. tgt, paddle.zeros(tgt.shape, dtype=tgt.dtype)
  53. )
  54. loss = -(one_hot * log_prb).sum(axis=1)
  55. loss = loss.masked_select(non_pad_mask).mean()
  56. else:
  57. loss = self.loss_func(pred, tgt)
  58. return {"loss": loss}