doc_preprocessor.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from .._utils.cli import (
  15. add_simple_inference_args,
  16. get_subcommand_args,
  17. perform_simple_inference,
  18. str2bool,
  19. )
  20. from .base import PaddleXPipelineWrapper, PipelineCLISubcommandExecutor
  21. from .utils import create_config_from_structure
  22. class DocPreprocessor(PaddleXPipelineWrapper):
  23. def __init__(
  24. self,
  25. doc_orientation_classify_model_name=None,
  26. doc_orientation_classify_model_dir=None,
  27. doc_unwarping_model_name=None,
  28. doc_unwarping_model_dir=None,
  29. use_doc_orientation_classify=None,
  30. use_doc_unwarping=None,
  31. **kwargs,
  32. ):
  33. self._params = {
  34. "doc_orientation_classify_model_name": doc_orientation_classify_model_name,
  35. "doc_orientation_classify_model_dir": doc_orientation_classify_model_dir,
  36. "doc_unwarping_model_name": doc_unwarping_model_name,
  37. "doc_unwarping_model_dir": doc_unwarping_model_dir,
  38. "use_doc_orientation_classify": use_doc_orientation_classify,
  39. "use_doc_unwarping": use_doc_unwarping,
  40. }
  41. super().__init__(**kwargs)
  42. @property
  43. def _paddlex_pipeline_name(self):
  44. return "doc_preprocessor"
  45. def predict_iter(
  46. self,
  47. input,
  48. *,
  49. use_doc_orientation_classify=None,
  50. use_doc_unwarping=None,
  51. ):
  52. return self.paddlex_pipeline.predict(
  53. input,
  54. use_doc_orientation_classify=use_doc_orientation_classify,
  55. use_doc_unwarping=use_doc_unwarping,
  56. )
  57. def predict(
  58. self,
  59. input,
  60. *,
  61. use_doc_orientation_classify=None,
  62. use_doc_unwarping=None,
  63. ):
  64. return list(
  65. self.predict_iter(
  66. input,
  67. use_doc_orientation_classify=use_doc_orientation_classify,
  68. use_doc_unwarping=use_doc_unwarping,
  69. )
  70. )
  71. @classmethod
  72. def get_cli_subcommand_executor(cls):
  73. return DocPreprocessorCLISubcommandExecutor()
  74. def _get_paddlex_config_overrides(self):
  75. STRUCTURE = {
  76. "SubModules.DocOrientationClassify.model_name": self._params[
  77. "doc_orientation_classify_model_name"
  78. ],
  79. "SubModules.DocOrientationClassify.model_dir": self._params[
  80. "doc_orientation_classify_model_dir"
  81. ],
  82. "SubModules.DocUnwarping.model_name": self._params[
  83. "doc_unwarping_model_name"
  84. ],
  85. "SubModules.DocUnwarping.model_dir": self._params[
  86. "doc_unwarping_model_dir"
  87. ],
  88. "use_doc_orientation_classify": self._params[
  89. "use_doc_orientation_classify"
  90. ],
  91. "use_doc_unwarping": self._params["use_doc_unwarping"],
  92. }
  93. return create_config_from_structure(STRUCTURE)
  94. class DocPreprocessorCLISubcommandExecutor(PipelineCLISubcommandExecutor):
  95. @property
  96. def subparser_name(self):
  97. return "doc_preprocessor"
  98. def _update_subparser(self, subparser):
  99. add_simple_inference_args(subparser)
  100. subparser.add_argument(
  101. "--doc_orientation_classify_model_name",
  102. type=str,
  103. help="Name of the document image orientation classification model.",
  104. )
  105. subparser.add_argument(
  106. "--doc_orientation_classify_model_dir",
  107. type=str,
  108. help="Path to the document image orientation classification model directory.",
  109. )
  110. subparser.add_argument(
  111. "--doc_unwarping_model_name",
  112. type=str,
  113. help="Name of the document image unwarping model.",
  114. )
  115. subparser.add_argument(
  116. "--doc_unwarping_model_dir",
  117. type=str,
  118. help="Path to the document image unwarping model directory.",
  119. )
  120. subparser.add_argument(
  121. "--use_doc_orientation_classify",
  122. type=str2bool,
  123. help="Whether to use document image orientation classification.",
  124. )
  125. subparser.add_argument(
  126. "--use_doc_unwarping",
  127. type=str2bool,
  128. help="Whether to use text image unwarping.",
  129. )
  130. def execute_with_args(self, args):
  131. params = get_subcommand_args(args)
  132. perform_simple_inference(DocPreprocessor, params)