doc_understanding.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from paddlex.utils.pipeline_arguments import custom_type
  15. from .._utils.cli import (
  16. add_simple_inference_args,
  17. get_subcommand_args,
  18. perform_simple_inference,
  19. )
  20. from .base import PaddleXPipelineWrapper, PipelineCLISubcommandExecutor
  21. from .utils import create_config_from_structure
  22. class DocUnderstanding(PaddleXPipelineWrapper):
  23. def __init__(
  24. self,
  25. doc_understanding_model_name=None,
  26. doc_understanding_model_dir=None,
  27. doc_understanding_batch_size=None,
  28. **kwargs,
  29. ):
  30. self._params = {
  31. "doc_understanding_model_name": doc_understanding_model_name,
  32. "doc_understanding_model_dir": doc_understanding_model_dir,
  33. "doc_understanding_batch_size": doc_understanding_batch_size,
  34. }
  35. super().__init__(**kwargs)
  36. @property
  37. def _paddlex_pipeline_name(self):
  38. return "doc_understanding"
  39. def predict_iter(self, input, **kwargs):
  40. return self.paddlex_pipeline.predict(input, **kwargs)
  41. def predict(
  42. self,
  43. input,
  44. **kwargs,
  45. ):
  46. return list(self.predict_iter(input, **kwargs))
  47. @classmethod
  48. def get_cli_subcommand_executor(cls):
  49. return DocUnderstandingCLISubcommandExecutor()
  50. def _get_paddlex_config_overrides(self):
  51. STRUCTURE = {
  52. "SubModules.DocUnderstanding.model_name": self._params[
  53. "doc_understanding_model_name"
  54. ],
  55. "SubModules.DocUnderstanding.model_dir": self._params[
  56. "doc_understanding_model_dir"
  57. ],
  58. "SubModules.DocUnderstanding.batch_size": self._params[
  59. "doc_understanding_batch_size"
  60. ],
  61. }
  62. return create_config_from_structure(STRUCTURE)
  63. class DocUnderstandingCLISubcommandExecutor(PipelineCLISubcommandExecutor):
  64. input_validator = staticmethod(custom_type(dict))
  65. @property
  66. def subparser_name(self):
  67. return "doc_understanding"
  68. def _update_subparser(self, subparser):
  69. add_simple_inference_args(
  70. subparser,
  71. input_help='Input dict, e.g. `{"image": "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/medal_table.png", "query": "Recognize this table"}`.',
  72. )
  73. subparser.add_argument(
  74. "--doc_understanding_model_name",
  75. type=str,
  76. help="Name of the document understanding model.",
  77. )
  78. subparser.add_argument(
  79. "--doc_understanding_model_dir",
  80. type=str,
  81. help="Path to the document understanding model directory.",
  82. )
  83. subparser.add_argument(
  84. "--doc_understanding_batch_size",
  85. type=str,
  86. help="Batch size for the document understanding model.",
  87. )
  88. def execute_with_args(self, args):
  89. params = get_subcommand_args(args)
  90. params["input"] = self.input_validator(params["input"])
  91. perform_simple_inference(DocUnderstanding, params)