paddleocr_vl.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377
  1. # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from .._utils.cli import (
  15. add_simple_inference_args,
  16. get_subcommand_args,
  17. perform_simple_inference,
  18. str2bool,
  19. )
  20. from .base import PaddleXPipelineWrapper, PipelineCLISubcommandExecutor
  21. from .utils import create_config_from_structure
  22. _SUPPORTED_VL_BACKENDS = ["native", "vllm-server", "sglang-server", "fastdeploy-server"]
  23. class PaddleOCRVL(PaddleXPipelineWrapper):
  24. def __init__(
  25. self,
  26. layout_detection_model_name=None,
  27. layout_detection_model_dir=None,
  28. layout_threshold=None,
  29. layout_nms=None,
  30. layout_unclip_ratio=None,
  31. layout_merge_bboxes_mode=None,
  32. vl_rec_model_name=None,
  33. vl_rec_model_dir=None,
  34. vl_rec_backend=None,
  35. vl_rec_server_url=None,
  36. vl_rec_max_concurrency=None,
  37. doc_orientation_classify_model_name=None,
  38. doc_orientation_classify_model_dir=None,
  39. doc_unwarping_model_name=None,
  40. doc_unwarping_model_dir=None,
  41. use_doc_orientation_classify=None,
  42. use_doc_unwarping=None,
  43. use_layout_detection=None,
  44. use_chart_recognition=None,
  45. format_block_content=None,
  46. **kwargs,
  47. ):
  48. if vl_rec_backend is not None and vl_rec_backend not in _SUPPORTED_VL_BACKENDS:
  49. raise ValueError(
  50. f"Invalid backend for the VL recognition module: {vl_rec_backend}. Supported values are {_SUPPORTED_VL_BACKENDS}."
  51. )
  52. params = locals().copy()
  53. params.pop("self")
  54. params.pop("kwargs")
  55. self._params = params
  56. super().__init__(**kwargs)
  57. @property
  58. def _paddlex_pipeline_name(self):
  59. return "PaddleOCR-VL"
  60. def predict_iter(
  61. self,
  62. input,
  63. *,
  64. use_doc_orientation_classify=None,
  65. use_doc_unwarping=None,
  66. use_layout_detection=None,
  67. use_chart_recognition=None,
  68. layout_threshold=None,
  69. layout_nms=None,
  70. layout_unclip_ratio=None,
  71. layout_merge_bboxes_mode=None,
  72. use_queues=None,
  73. prompt_label=None,
  74. format_block_content=None,
  75. repetition_penalty=None,
  76. temperature=None,
  77. top_p=None,
  78. min_pixels=None,
  79. max_pixels=None,
  80. **kwargs,
  81. ):
  82. return self.paddlex_pipeline.predict(
  83. input,
  84. use_doc_orientation_classify=use_doc_orientation_classify,
  85. use_doc_unwarping=use_doc_unwarping,
  86. use_layout_detection=use_layout_detection,
  87. use_chart_recognition=use_chart_recognition,
  88. layout_threshold=layout_threshold,
  89. layout_nms=layout_nms,
  90. layout_unclip_ratio=layout_unclip_ratio,
  91. layout_merge_bboxes_mode=layout_merge_bboxes_mode,
  92. use_queues=use_queues,
  93. prompt_label=prompt_label,
  94. format_block_content=format_block_content,
  95. repetition_penalty=repetition_penalty,
  96. temperature=temperature,
  97. top_p=top_p,
  98. min_pixels=min_pixels,
  99. max_pixels=max_pixels,
  100. **kwargs,
  101. )
  102. def predict(
  103. self,
  104. input,
  105. *,
  106. use_doc_orientation_classify=None,
  107. use_doc_unwarping=None,
  108. use_layout_detection=None,
  109. use_chart_recognition=None,
  110. layout_threshold=None,
  111. layout_nms=None,
  112. layout_unclip_ratio=None,
  113. layout_merge_bboxes_mode=None,
  114. use_queues=None,
  115. prompt_label=None,
  116. format_block_content=None,
  117. repetition_penalty=None,
  118. temperature=None,
  119. top_p=None,
  120. min_pixels=None,
  121. max_pixels=None,
  122. **kwargs,
  123. ):
  124. return list(
  125. self.predict_iter(
  126. input,
  127. use_doc_orientation_classify=use_doc_orientation_classify,
  128. use_doc_unwarping=use_doc_unwarping,
  129. use_layout_detection=use_layout_detection,
  130. use_chart_recognition=use_chart_recognition,
  131. layout_threshold=layout_threshold,
  132. layout_nms=layout_nms,
  133. layout_unclip_ratio=layout_unclip_ratio,
  134. layout_merge_bboxes_mode=layout_merge_bboxes_mode,
  135. use_queues=use_queues,
  136. prompt_label=prompt_label,
  137. format_block_content=format_block_content,
  138. repetition_penalty=repetition_penalty,
  139. temperature=temperature,
  140. top_p=top_p,
  141. min_pixels=min_pixels,
  142. max_pixels=max_pixels,
  143. **kwargs,
  144. )
  145. )
  146. def concatenate_markdown_pages(self, markdown_list):
  147. return self.paddlex_pipeline.concatenate_markdown_pages(markdown_list)
  148. @classmethod
  149. def get_cli_subcommand_executor(cls):
  150. return PaddleOCRVLCLISubcommandExecutor()
  151. def _get_paddlex_config_overrides(self):
  152. STRUCTURE = {
  153. "SubPipelines.DocPreprocessor.use_doc_orientation_classify": self._params[
  154. "use_doc_orientation_classify"
  155. ],
  156. "SubPipelines.DocPreprocessor.use_doc_unwarping": self._params[
  157. "use_doc_unwarping"
  158. ],
  159. "use_doc_preprocessor": self._params["use_doc_orientation_classify"]
  160. or self._params["use_doc_unwarping"],
  161. "use_layout_detection": self._params["use_layout_detection"],
  162. "use_chart_recognition": self._params["use_chart_recognition"],
  163. "format_block_content": self._params["format_block_content"],
  164. "SubModules.LayoutDetection.model_name": self._params[
  165. "layout_detection_model_name"
  166. ],
  167. "SubModules.LayoutDetection.model_dir": self._params[
  168. "layout_detection_model_dir"
  169. ],
  170. "SubModules.LayoutDetection.threshold": self._params["layout_threshold"],
  171. "SubModules.LayoutDetection.layout_nms": self._params["layout_nms"],
  172. "SubModules.LayoutDetection.layout_unclip_ratio": self._params[
  173. "layout_unclip_ratio"
  174. ],
  175. "SubModules.LayoutDetection.layout_merge_bboxes_mode": self._params[
  176. "layout_merge_bboxes_mode"
  177. ],
  178. "SubModules.VLRecognition.model_name": self._params["vl_rec_model_name"],
  179. "SubModules.VLRecognition.model_dir": self._params["vl_rec_model_dir"],
  180. "SubModules.VLRecognition.genai_config.backend": self._params[
  181. "vl_rec_backend"
  182. ],
  183. "SubModules.VLRecognition.genai_config.server_url": self._params[
  184. "vl_rec_server_url"
  185. ],
  186. "SubPipelines.DocPreprocessor.SubModules.DocOrientationClassify.model_name": self._params[
  187. "doc_orientation_classify_model_name"
  188. ],
  189. "SubPipelines.DocPreprocessor.SubModules.DocOrientationClassify.model_dir": self._params[
  190. "doc_orientation_classify_model_dir"
  191. ],
  192. "SubPipelines.DocPreprocessor.SubModules.DocUnwarping.model_name": self._params[
  193. "doc_unwarping_model_name"
  194. ],
  195. "SubPipelines.DocPreprocessor.SubModules.DocUnwarping.model_dir": self._params[
  196. "doc_unwarping_model_dir"
  197. ],
  198. }
  199. return create_config_from_structure(STRUCTURE)
  200. class PaddleOCRVLCLISubcommandExecutor(PipelineCLISubcommandExecutor):
  201. @property
  202. def subparser_name(self):
  203. return "doc_parser"
  204. def _update_subparser(self, subparser):
  205. add_simple_inference_args(subparser)
  206. subparser.add_argument(
  207. "--layout_detection_model_name",
  208. type=str,
  209. help="Name of the layout detection model.",
  210. )
  211. subparser.add_argument(
  212. "--layout_detection_model_dir",
  213. type=str,
  214. help="Path to the layout detection model directory.",
  215. )
  216. subparser.add_argument(
  217. "--layout_threshold",
  218. type=float,
  219. help="Score threshold for the layout detection model.",
  220. )
  221. subparser.add_argument(
  222. "--layout_nms",
  223. type=str2bool,
  224. help="Whether to use NMS in layout detection.",
  225. )
  226. subparser.add_argument(
  227. "--layout_unclip_ratio",
  228. type=float,
  229. help="Expansion coefficient for layout detection.",
  230. )
  231. subparser.add_argument(
  232. "--layout_merge_bboxes_mode",
  233. type=str,
  234. help="Overlapping box filtering method.",
  235. )
  236. subparser.add_argument(
  237. "--vl_rec_model_name",
  238. type=str,
  239. help="Name of the VL recognition model.",
  240. )
  241. subparser.add_argument(
  242. "--vl_rec_model_dir",
  243. type=str,
  244. help="Path to the VL recognition model directory.",
  245. )
  246. subparser.add_argument(
  247. "--vl_rec_backend",
  248. type=str,
  249. help="Backend used by the VL recognition module.",
  250. choices=_SUPPORTED_VL_BACKENDS,
  251. )
  252. subparser.add_argument(
  253. "--vl_rec_server_url",
  254. type=str,
  255. help="Server URL used by the VL recognition module.",
  256. )
  257. subparser.add_argument(
  258. "--vl_rec_max_concurrency",
  259. type=str,
  260. help="Maximum concurrency for making VLM requests.",
  261. )
  262. subparser.add_argument(
  263. "--doc_orientation_classify_model_name",
  264. type=str,
  265. help="Name of the document image orientation classification model.",
  266. )
  267. subparser.add_argument(
  268. "--doc_orientation_classify_model_dir",
  269. type=str,
  270. help="Path to the document image orientation classification model directory.",
  271. )
  272. subparser.add_argument(
  273. "--doc_unwarping_model_name",
  274. type=str,
  275. help="Name of the text image unwarping model.",
  276. )
  277. subparser.add_argument(
  278. "--doc_unwarping_model_dir",
  279. type=str,
  280. help="Path to the image unwarping model directory.",
  281. )
  282. subparser.add_argument(
  283. "--use_doc_orientation_classify",
  284. type=str2bool,
  285. help="Whether to use document image orientation classification.",
  286. )
  287. subparser.add_argument(
  288. "--use_doc_unwarping",
  289. type=str2bool,
  290. help="Whether to use text image unwarping.",
  291. )
  292. subparser.add_argument(
  293. "--use_layout_detection",
  294. type=str2bool,
  295. help="Whether to use layout detection.",
  296. )
  297. subparser.add_argument(
  298. "--use_chart_recognition",
  299. type=str2bool,
  300. help="Whether to use chart recognition.",
  301. )
  302. subparser.add_argument(
  303. "--format_block_content",
  304. type=str2bool,
  305. help="Whether to format block content to Markdown.",
  306. )
  307. subparser.add_argument(
  308. "--use_queues",
  309. type=str2bool,
  310. help="Whether to use queues for asynchronous processing.",
  311. )
  312. subparser.add_argument(
  313. "--prompt_label",
  314. type=str,
  315. help="Prompt label for the VLM.",
  316. )
  317. subparser.add_argument(
  318. "--repetition_penalty",
  319. type=float,
  320. help="Repetition penalty used in sampling for the VLM.",
  321. )
  322. subparser.add_argument(
  323. "--temperature",
  324. type=float,
  325. help="Temperature parameter used in sampling for the VLM.",
  326. )
  327. subparser.add_argument(
  328. "--top_p",
  329. type=float,
  330. help="Top-p parameter used in sampling for the VLM.",
  331. )
  332. subparser.add_argument(
  333. "--min_pixels",
  334. type=int,
  335. help="Minimum pixels for image preprocessing for the VLM.",
  336. )
  337. subparser.add_argument(
  338. "--max_pixels",
  339. type=int,
  340. help="Maximum pixels for image preprocessing for the VLM.",
  341. )
  342. def execute_with_args(self, args):
  343. params = get_subcommand_args(args)
  344. perform_simple_inference(
  345. PaddleOCRVL,
  346. params,
  347. predict_param_names={
  348. "use_queues",
  349. "prompt_label",
  350. "repetition_penalty",
  351. "temperature",
  352. "top_p",
  353. "min_pixels",
  354. "max_pixels",
  355. },
  356. )