__init__.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Public API of this package: only the `build_head` factory is exported by
# `from ... import *`; the individual head classes are imported lazily inside it.
__all__ = ["build_head"]
  15. def build_head(config):
  16. # det head
  17. from .det_db_head import DBHead, PFHeadLocal
  18. from .det_east_head import EASTHead
  19. from .det_sast_head import SASTHead
  20. from .det_pse_head import PSEHead
  21. from .det_fce_head import FCEHead
  22. from .e2e_pg_head import PGHead
  23. from .det_ct_head import CT_Head
  24. # rec head
  25. from .rec_ctc_head import CTCHead
  26. from .rec_att_head import AttentionHead
  27. from .rec_srn_head import SRNHead
  28. from .rec_nrtr_head import Transformer
  29. from .rec_sar_head import SARHead
  30. from .rec_aster_head import AsterHead
  31. from .rec_pren_head import PRENHead
  32. from .rec_multi_head import MultiHead
  33. from .rec_spin_att_head import SPINAttentionHead
  34. from .rec_abinet_head import ABINetHead
  35. from .rec_robustscanner_head import RobustScannerHead
  36. from .rec_visionlan_head import VLHead
  37. from .rec_rfl_head import RFLHead
  38. from .rec_can_head import CANHead
  39. from .rec_latexocr_head import LaTeXOCRHead
  40. from .rec_satrn_head import SATRNHead
  41. from .rec_parseq_head import ParseQHead
  42. from .rec_cppd_head import CPPDHead
  43. from .rec_unimernet_head import UniMERNetHead
  44. from .rec_ppformulanet_head import PPFormulaNet_Head
  45. # cls head
  46. from .cls_head import ClsHead
  47. # kie head
  48. from .kie_sdmgr_head import SDMGRHead
  49. from .table_att_head import TableAttentionHead, SLAHead
  50. from .table_master_head import TableMasterHead
  51. support_dict = [
  52. "DBHead",
  53. "PSEHead",
  54. "FCEHead",
  55. "EASTHead",
  56. "SASTHead",
  57. "CTCHead",
  58. "ClsHead",
  59. "AttentionHead",
  60. "SRNHead",
  61. "PGHead",
  62. "Transformer",
  63. "TableAttentionHead",
  64. "SARHead",
  65. "AsterHead",
  66. "SDMGRHead",
  67. "PRENHead",
  68. "MultiHead",
  69. "ABINetHead",
  70. "TableMasterHead",
  71. "SPINAttentionHead",
  72. "VLHead",
  73. "SLAHead",
  74. "RobustScannerHead",
  75. "CT_Head",
  76. "RFLHead",
  77. "DRRGHead",
  78. "CANHead",
  79. "LaTeXOCRHead",
  80. "SATRNHead",
  81. "PFHeadLocal",
  82. "ParseQHead",
  83. "CPPDHead",
  84. "UniMERNetHead",
  85. "PPFormulaNet_Head",
  86. ]
  87. if config["name"] == "DRRGHead":
  88. from .det_drrg_head import DRRGHead
  89. support_dict.append("DRRGHead")
  90. # table head
  91. module_name = config.pop("name")
  92. assert module_name in support_dict, Exception(
  93. "head only support {}".format(support_dict)
  94. )
  95. module_class = eval(module_name)(**config)
  96. return module_class