__init__.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
  15. import paddle
  16. import paddle.nn as nn
  17. # basic_loss
  18. from .basic_loss import LossFromOutput
  19. # det loss
  20. from .det_db_loss import DBLoss
  21. from .det_east_loss import EASTLoss
  22. from .det_sast_loss import SASTLoss
  23. from .det_pse_loss import PSELoss
  24. from .det_fce_loss import FCELoss
  25. from .det_ct_loss import CTLoss
  26. from .det_drrg_loss import DRRGLoss
  27. # rec loss
  28. from .rec_ctc_loss import CTCLoss
  29. from .rec_att_loss import AttentionLoss
  30. from .rec_srn_loss import SRNLoss
  31. from .rec_ce_loss import CELoss
  32. from .rec_sar_loss import SARLoss
  33. from .rec_aster_loss import AsterLoss
  34. from .rec_pren_loss import PRENLoss
  35. from .rec_multi_loss import MultiLoss
  36. from .rec_vl_loss import VLLoss
  37. from .rec_spin_att_loss import SPINAttentionLoss
  38. from .rec_rfl_loss import RFLLoss
  39. from .rec_can_loss import CANLoss
  40. from .rec_satrn_loss import SATRNLoss
  41. from .rec_nrtr_loss import NRTRLoss
  42. from .rec_parseq_loss import ParseQLoss
  43. from .rec_cppd_loss import CPPDLoss
  44. from .rec_latexocr_loss import LaTeXOCRLoss
  45. from .rec_unimernet_loss import UniMERNetLoss
  46. from .rec_ppformulanet_loss import PPFormulaNet_S_Loss, PPFormulaNet_L_Loss
  47. # cls loss
  48. from .cls_loss import ClsLoss
  49. # e2e loss
  50. from .e2e_pg_loss import PGLoss
  51. from .kie_sdmgr_loss import SDMGRLoss
  52. # basic loss function
  53. from .basic_loss import DistanceLoss
  54. # combined loss function
  55. from .combined_loss import CombinedLoss
  56. # table loss
  57. from .table_att_loss import TableAttentionLoss, SLALoss
  58. from .table_master_loss import TableMasterLoss
  59. # vqa token loss
  60. from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
  61. # sr loss
  62. from .stroke_focus_loss import StrokeFocusLoss
  63. from .text_focus_loss import TelescopeLoss
  64. def build_loss(config):
  65. support_dict = [
  66. "DBLoss",
  67. "PSELoss",
  68. "EASTLoss",
  69. "SASTLoss",
  70. "FCELoss",
  71. "CTCLoss",
  72. "ClsLoss",
  73. "AttentionLoss",
  74. "SRNLoss",
  75. "PGLoss",
  76. "CombinedLoss",
  77. "CELoss",
  78. "TableAttentionLoss",
  79. "SARLoss",
  80. "AsterLoss",
  81. "SDMGRLoss",
  82. "VQASerTokenLayoutLMLoss",
  83. "LossFromOutput",
  84. "PRENLoss",
  85. "MultiLoss",
  86. "TableMasterLoss",
  87. "SPINAttentionLoss",
  88. "VLLoss",
  89. "StrokeFocusLoss",
  90. "SLALoss",
  91. "CTLoss",
  92. "RFLLoss",
  93. "DRRGLoss",
  94. "CANLoss",
  95. "TelescopeLoss",
  96. "SATRNLoss",
  97. "NRTRLoss",
  98. "ParseQLoss",
  99. "CPPDLoss",
  100. "LaTeXOCRLoss",
  101. "UniMERNetLoss",
  102. "PPFormulaNet_S_Loss",
  103. "PPFormulaNet_L_Loss",
  104. ]
  105. config = copy.deepcopy(config)
  106. module_name = config.pop("name")
  107. assert module_name in support_dict, Exception(
  108. "loss only support {}".format(support_dict)
  109. )
  110. module_class = eval(module_name)(**config)
  111. return module_class