rec_sar_loss.py 1.2 KB

123456789101112131415161718192021222324252627282930313233343536
  1. from __future__ import absolute_import
  2. from __future__ import division
  3. from __future__ import print_function
  4. import paddle
  5. from paddle import nn
  6. class SARLoss(nn.Layer):
  7. def __init__(self, **kwargs):
  8. super(SARLoss, self).__init__()
  9. ignore_index = kwargs.get("ignore_index", 92) # 6626
  10. self.loss_func = paddle.nn.loss.CrossEntropyLoss(
  11. reduction="mean", ignore_index=ignore_index
  12. )
  13. def forward(self, predicts, batch):
  14. predict = predicts[
  15. :, :-1, :
  16. ] # ignore last index of outputs to be in same seq_len with targets
  17. label = batch[1].astype("int64")[
  18. :, 1:
  19. ] # ignore first index of target in loss calculation
  20. batch_size, num_steps, num_classes = (
  21. predict.shape[0],
  22. predict.shape[1],
  23. predict.shape[2],
  24. )
  25. assert (
  26. len(label.shape) == len(list(predict.shape)) - 1
  27. ), "The target's shape and inputs's shape is [N, d] and [N, num_steps]"
  28. inputs = paddle.reshape(predict, [-1, num_classes])
  29. targets = paddle.reshape(label, [-1])
  30. loss = self.loss_func(inputs, targets)
  31. return {"loss": loss}