rec_can_loss.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/LBH1024/CAN/models/can.py
  17. """
  18. import paddle
  19. import paddle.nn as nn
  20. import numpy as np
  21. class CANLoss(nn.Layer):
  22. """
  23. CANLoss is consist of two part:
  24. word_average_loss: average accuracy of the symbol
  25. counting_loss: counting loss of every symbol
  26. """
  27. def __init__(self):
  28. super(CANLoss, self).__init__()
  29. self.use_label_mask = False
  30. self.out_channel = 111
  31. self.cross = (
  32. nn.CrossEntropyLoss(reduction="none")
  33. if self.use_label_mask
  34. else nn.CrossEntropyLoss()
  35. )
  36. self.counting_loss = nn.SmoothL1Loss(reduction="mean")
  37. self.ratio = 16
  38. def forward(self, preds, batch):
  39. word_probs = preds[0]
  40. counting_preds = preds[1]
  41. counting_preds1 = preds[2]
  42. counting_preds2 = preds[3]
  43. labels = batch[2]
  44. labels_mask = batch[3]
  45. counting_labels = gen_counting_label(labels, self.out_channel, True)
  46. counting_loss = (
  47. self.counting_loss(counting_preds1, counting_labels)
  48. + self.counting_loss(counting_preds2, counting_labels)
  49. + self.counting_loss(counting_preds, counting_labels)
  50. )
  51. word_loss = self.cross(
  52. paddle.reshape(word_probs, [-1, word_probs.shape[-1]]),
  53. paddle.reshape(labels, [-1]),
  54. )
  55. word_average_loss = (
  56. paddle.sum(paddle.reshape(word_loss * labels_mask, [-1]))
  57. / (paddle.sum(labels_mask) + 1e-10)
  58. if self.use_label_mask
  59. else word_loss
  60. )
  61. loss = word_average_loss + counting_loss
  62. return {"loss": loss}
  63. def gen_counting_label(labels, channel, tag):
  64. b, t = labels.shape
  65. counting_labels = np.zeros([b, channel])
  66. if tag:
  67. ignore = [0, 1, 107, 108, 109, 110]
  68. else:
  69. ignore = []
  70. for i in range(b):
  71. for j in range(t):
  72. k = labels[i][j]
  73. if k in ignore:
  74. continue
  75. else:
  76. counting_labels[i][k] += 1
  77. counting_labels = paddle.to_tensor(counting_labels, dtype="float32")
  78. return counting_labels