__init__.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from __future__ import unicode_literals
  18. import os
  19. import copy
  20. __all__ = ["build_post_process"]
  21. from .db_postprocess import DBPostProcess, DistillationDBPostProcess
  22. from .east_postprocess import EASTPostProcess
  23. from .sast_postprocess import SASTPostProcess
  24. from .fce_postprocess import FCEPostProcess
  25. from .rec_postprocess import (
  26. CTCLabelDecode,
  27. AttnLabelDecode,
  28. SRNLabelDecode,
  29. DistillationCTCLabelDecode,
  30. NRTRLabelDecode,
  31. SARLabelDecode,
  32. SEEDLabelDecode,
  33. PRENLabelDecode,
  34. ViTSTRLabelDecode,
  35. ABINetLabelDecode,
  36. SPINLabelDecode,
  37. VLLabelDecode,
  38. RFLLabelDecode,
  39. SATRNLabelDecode,
  40. ParseQLabelDecode,
  41. CPPDLabelDecode,
  42. LaTeXOCRDecode,
  43. UniMERNetDecode,
  44. )
  45. from .cls_postprocess import ClsPostProcess
  46. from .pg_postprocess import PGPostProcess
  47. from .vqa_token_ser_layoutlm_postprocess import (
  48. VQASerTokenLayoutLMPostProcess,
  49. DistillationSerPostProcess,
  50. )
  51. from .vqa_token_re_layoutlm_postprocess import (
  52. VQAReTokenLayoutLMPostProcess,
  53. DistillationRePostProcess,
  54. )
  55. from .table_postprocess import TableMasterLabelDecode, TableLabelDecode
  56. from .picodet_postprocess import PicoDetPostProcess
  57. from .ct_postprocess import CTPostProcess
  58. from .drrg_postprocess import DRRGPostprocess
  59. from .rec_postprocess import CANLabelDecode
  60. def build_post_process(config, global_config=None):
  61. support_dict = [
  62. "DBPostProcess",
  63. "EASTPostProcess",
  64. "SASTPostProcess",
  65. "FCEPostProcess",
  66. "CTCLabelDecode",
  67. "AttnLabelDecode",
  68. "ClsPostProcess",
  69. "SRNLabelDecode",
  70. "PGPostProcess",
  71. "DistillationCTCLabelDecode",
  72. "TableLabelDecode",
  73. "DistillationDBPostProcess",
  74. "NRTRLabelDecode",
  75. "SARLabelDecode",
  76. "SEEDLabelDecode",
  77. "VQASerTokenLayoutLMPostProcess",
  78. "VQAReTokenLayoutLMPostProcess",
  79. "PRENLabelDecode",
  80. "DistillationSARLabelDecode",
  81. "ViTSTRLabelDecode",
  82. "ABINetLabelDecode",
  83. "TableMasterLabelDecode",
  84. "SPINLabelDecode",
  85. "DistillationSerPostProcess",
  86. "DistillationRePostProcess",
  87. "VLLabelDecode",
  88. "PicoDetPostProcess",
  89. "CTPostProcess",
  90. "RFLLabelDecode",
  91. "DRRGPostprocess",
  92. "CANLabelDecode",
  93. "SATRNLabelDecode",
  94. "ParseQLabelDecode",
  95. "CPPDLabelDecode",
  96. "LaTeXOCRDecode",
  97. "UniMERNetDecode",
  98. ]
  99. if config["name"] == "PSEPostProcess":
  100. from .pse_postprocess import PSEPostProcess
  101. support_dict.append("PSEPostProcess")
  102. config = copy.deepcopy(config)
  103. module_name = config.pop("name")
  104. if module_name == "None":
  105. return
  106. if global_config is not None:
  107. config.update(global_config)
  108. assert module_name in support_dict, Exception(
  109. "post process only support {}".format(support_dict)
  110. )
  111. module_class = eval(module_name)(**config)
  112. return module_class