table_att_loss.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import paddle
  18. from paddle import nn
  19. from paddle.nn import functional as F
  20. class TableAttentionLoss(nn.Layer):
  21. def __init__(self, structure_weight=1.0, loc_weight=0.0, **kwargs):
  22. super(TableAttentionLoss, self).__init__()
  23. self.loss_func = nn.CrossEntropyLoss(weight=None, reduction="none")
  24. self.structure_weight = structure_weight
  25. self.loc_weight = loc_weight
  26. def forward(self, predicts, batch):
  27. structure_probs = predicts["structure_probs"]
  28. structure_targets = batch[1].astype("int64")
  29. structure_targets = structure_targets[:, 1:]
  30. structure_probs = paddle.reshape(
  31. structure_probs, [-1, structure_probs.shape[-1]]
  32. )
  33. structure_targets = paddle.reshape(structure_targets, [-1])
  34. structure_loss = self.loss_func(structure_probs, structure_targets)
  35. structure_loss = paddle.mean(structure_loss) * self.structure_weight
  36. loc_preds = predicts["loc_preds"]
  37. loc_targets = batch[2].astype("float32")
  38. loc_targets_mask = batch[3].astype("float32")
  39. loc_targets = loc_targets[:, 1:, :]
  40. loc_targets_mask = loc_targets_mask[:, 1:, :]
  41. loc_loss = (
  42. F.mse_loss(loc_preds * loc_targets_mask, loc_targets) * self.loc_weight
  43. )
  44. total_loss = structure_loss + loc_loss
  45. return {
  46. "loss": total_loss,
  47. "structure_loss": structure_loss,
  48. "loc_loss": loc_loss,
  49. }
  50. class SLALoss(nn.Layer):
  51. def __init__(self, structure_weight=1.0, loc_weight=0.0, loc_loss="mse", **kwargs):
  52. super(SLALoss, self).__init__()
  53. self.loss_func = nn.CrossEntropyLoss(weight=None, reduction="mean")
  54. self.structure_weight = structure_weight
  55. self.loc_weight = loc_weight
  56. self.loc_loss = loc_loss
  57. self.eps = 1e-12
  58. def forward(self, predicts, batch):
  59. structure_probs = predicts["structure_probs"]
  60. structure_targets = batch[1].astype("int64")
  61. max_len = batch[-2].max().astype("int32")
  62. structure_targets = structure_targets[:, 1 : max_len + 2]
  63. structure_loss = self.loss_func(structure_probs, structure_targets)
  64. structure_loss = paddle.mean(structure_loss) * self.structure_weight
  65. loc_preds = predicts["loc_preds"]
  66. loc_targets = batch[2].astype("float32")
  67. loc_targets_mask = batch[3].astype("float32")
  68. loc_targets = loc_targets[:, 1 : max_len + 2]
  69. loc_targets_mask = loc_targets_mask[:, 1 : max_len + 2]
  70. loc_loss = (
  71. F.smooth_l1_loss(
  72. loc_preds * loc_targets_mask,
  73. loc_targets * loc_targets_mask,
  74. reduction="sum",
  75. )
  76. * self.loc_weight
  77. )
  78. loc_loss = loc_loss / (loc_targets_mask.sum() + self.eps)
  79. total_loss = structure_loss + loc_loss
  80. return {
  81. "loss": total_loss,
  82. "structure_loss": structure_loss,
  83. "loc_loss": loc_loss,
  84. }