combined_loss.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. from .rec_ctc_loss import CTCLoss
  17. from .center_loss import CenterLoss
  18. from .ace_loss import ACELoss
  19. from .rec_sar_loss import SARLoss
  20. from .distillation_loss import DistillationCTCLoss, DistillCTCLogits
  21. from .distillation_loss import DistillationSARLoss, DistillationNRTRLoss
  22. from .distillation_loss import (
  23. DistillationDMLLoss,
  24. DistillationKLDivLoss,
  25. DistillationDKDLoss,
  26. )
  27. from .distillation_loss import (
  28. DistillationDistanceLoss,
  29. DistillationDBLoss,
  30. DistillationDilaDBLoss,
  31. )
  32. from .distillation_loss import (
  33. DistillationVQASerTokenLayoutLMLoss,
  34. DistillationSERDMLLoss,
  35. )
  36. from .distillation_loss import DistillationLossFromOutput
  37. from .distillation_loss import DistillationVQADistanceLoss
  38. class CombinedLoss(nn.Layer):
  39. """
  40. CombinedLoss:
  41. a combionation of loss function
  42. """
  43. def __init__(self, loss_config_list=None):
  44. super().__init__()
  45. self.loss_func = []
  46. self.loss_weight = []
  47. assert isinstance(loss_config_list, list), "operator config should be a list"
  48. for config in loss_config_list:
  49. assert isinstance(config, dict) and len(config) == 1, "yaml format error"
  50. name = list(config)[0]
  51. param = config[name]
  52. assert (
  53. "weight" in param
  54. ), "weight must be in param, but param just contains {}".format(
  55. param.keys()
  56. )
  57. self.loss_weight.append(param.pop("weight"))
  58. self.loss_func.append(eval(name)(**param))
  59. def forward(self, input, batch, **kargs):
  60. loss_dict = {}
  61. loss_all = 0.0
  62. for idx, loss_func in enumerate(self.loss_func):
  63. loss = loss_func(input, batch, **kargs)
  64. if isinstance(loss, paddle.Tensor):
  65. loss = {"loss_{}_{}".format(str(loss), idx): loss}
  66. weight = self.loss_weight[idx]
  67. loss = {key: loss[key] * weight for key in loss}
  68. if "loss" in loss:
  69. loss_all += loss["loss"]
  70. else:
  71. loss_all += paddle.add_n(list(loss.values()))
  72. loss_dict.update(loss)
  73. loss_dict["loss"] = loss_all
  74. return loss_dict