paddleocr_vl.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from .._utils.cli import (
  15. add_simple_inference_args,
  16. get_subcommand_args,
  17. perform_simple_inference,
  18. str2bool,
  19. )
  20. from .base import PaddleXPipelineWrapper, PipelineCLISubcommandExecutor
  21. from .utils import create_config_from_structure
  22. _SUPPORTED_VL_BACKENDS = ["native", "vllm-server", "sglang-server", "fastdeploy-server"]
  23. class PaddleOCRVL(PaddleXPipelineWrapper):
  24. def __init__(
  25. self,
  26. layout_detection_model_name=None,
  27. layout_detection_model_dir=None,
  28. layout_threshold=None,
  29. layout_nms=None,
  30. layout_unclip_ratio=None,
  31. layout_merge_bboxes_mode=None,
  32. vl_rec_model_name=None,
  33. vl_rec_model_dir=None,
  34. vl_rec_backend=None,
  35. vl_rec_server_url=None,
  36. vl_rec_max_concurrency=None,
  37. vl_rec_api_model_name=None,
  38. vl_rec_api_key=None,
  39. doc_orientation_classify_model_name=None,
  40. doc_orientation_classify_model_dir=None,
  41. doc_unwarping_model_name=None,
  42. doc_unwarping_model_dir=None,
  43. use_doc_orientation_classify=None,
  44. use_doc_unwarping=None,
  45. use_layout_detection=None,
  46. use_chart_recognition=None,
  47. format_block_content=None,
  48. **kwargs,
  49. ):
  50. if vl_rec_backend is not None and vl_rec_backend not in _SUPPORTED_VL_BACKENDS:
  51. raise ValueError(
  52. f"Invalid backend for the VL recognition module: {vl_rec_backend}. Supported values are {_SUPPORTED_VL_BACKENDS}."
  53. )
  54. params = locals().copy()
  55. params.pop("self")
  56. params.pop("kwargs")
  57. self._params = params
  58. super().__init__(**kwargs)
  59. @property
  60. def _paddlex_pipeline_name(self):
  61. return "PaddleOCR-VL"
  62. def predict_iter(
  63. self,
  64. input,
  65. *,
  66. use_doc_orientation_classify=None,
  67. use_doc_unwarping=None,
  68. use_layout_detection=None,
  69. use_chart_recognition=None,
  70. layout_threshold=None,
  71. layout_nms=None,
  72. layout_unclip_ratio=None,
  73. layout_merge_bboxes_mode=None,
  74. use_queues=None,
  75. prompt_label=None,
  76. format_block_content=None,
  77. repetition_penalty=None,
  78. temperature=None,
  79. top_p=None,
  80. min_pixels=None,
  81. max_pixels=None,
  82. **kwargs,
  83. ):
  84. return self.paddlex_pipeline.predict(
  85. input,
  86. use_doc_orientation_classify=use_doc_orientation_classify,
  87. use_doc_unwarping=use_doc_unwarping,
  88. use_layout_detection=use_layout_detection,
  89. use_chart_recognition=use_chart_recognition,
  90. layout_threshold=layout_threshold,
  91. layout_nms=layout_nms,
  92. layout_unclip_ratio=layout_unclip_ratio,
  93. layout_merge_bboxes_mode=layout_merge_bboxes_mode,
  94. use_queues=use_queues,
  95. prompt_label=prompt_label,
  96. format_block_content=format_block_content,
  97. repetition_penalty=repetition_penalty,
  98. temperature=temperature,
  99. top_p=top_p,
  100. min_pixels=min_pixels,
  101. max_pixels=max_pixels,
  102. **kwargs,
  103. )
  104. def predict(
  105. self,
  106. input,
  107. *,
  108. use_doc_orientation_classify=None,
  109. use_doc_unwarping=None,
  110. use_layout_detection=None,
  111. use_chart_recognition=None,
  112. layout_threshold=None,
  113. layout_nms=None,
  114. layout_unclip_ratio=None,
  115. layout_merge_bboxes_mode=None,
  116. use_queues=None,
  117. prompt_label=None,
  118. format_block_content=None,
  119. repetition_penalty=None,
  120. temperature=None,
  121. top_p=None,
  122. min_pixels=None,
  123. max_pixels=None,
  124. **kwargs,
  125. ):
  126. return list(
  127. self.predict_iter(
  128. input,
  129. use_doc_orientation_classify=use_doc_orientation_classify,
  130. use_doc_unwarping=use_doc_unwarping,
  131. use_layout_detection=use_layout_detection,
  132. use_chart_recognition=use_chart_recognition,
  133. layout_threshold=layout_threshold,
  134. layout_nms=layout_nms,
  135. layout_unclip_ratio=layout_unclip_ratio,
  136. layout_merge_bboxes_mode=layout_merge_bboxes_mode,
  137. use_queues=use_queues,
  138. prompt_label=prompt_label,
  139. format_block_content=format_block_content,
  140. repetition_penalty=repetition_penalty,
  141. temperature=temperature,
  142. top_p=top_p,
  143. min_pixels=min_pixels,
  144. max_pixels=max_pixels,
  145. **kwargs,
  146. )
  147. )
  148. def concatenate_markdown_pages(self, markdown_list):
  149. return self.paddlex_pipeline.concatenate_markdown_pages(markdown_list)
  150. @classmethod
  151. def get_cli_subcommand_executor(cls):
  152. return PaddleOCRVLCLISubcommandExecutor()
  153. def _get_paddlex_config_overrides(self):
  154. STRUCTURE = {
  155. "SubPipelines.DocPreprocessor.use_doc_orientation_classify": self._params[
  156. "use_doc_orientation_classify"
  157. ],
  158. "SubPipelines.DocPreprocessor.use_doc_unwarping": self._params[
  159. "use_doc_unwarping"
  160. ],
  161. "use_doc_preprocessor": self._params["use_doc_orientation_classify"]
  162. or self._params["use_doc_unwarping"],
  163. "use_layout_detection": self._params["use_layout_detection"],
  164. "use_chart_recognition": self._params["use_chart_recognition"],
  165. "format_block_content": self._params["format_block_content"],
  166. "SubModules.LayoutDetection.model_name": self._params[
  167. "layout_detection_model_name"
  168. ],
  169. "SubModules.LayoutDetection.model_dir": self._params[
  170. "layout_detection_model_dir"
  171. ],
  172. "SubModules.LayoutDetection.threshold": self._params["layout_threshold"],
  173. "SubModules.LayoutDetection.layout_nms": self._params["layout_nms"],
  174. "SubModules.LayoutDetection.layout_unclip_ratio": self._params[
  175. "layout_unclip_ratio"
  176. ],
  177. "SubModules.LayoutDetection.layout_merge_bboxes_mode": self._params[
  178. "layout_merge_bboxes_mode"
  179. ],
  180. "SubModules.VLRecognition.model_name": self._params["vl_rec_model_name"],
  181. "SubModules.VLRecognition.model_dir": self._params["vl_rec_model_dir"],
  182. "SubModules.VLRecognition.genai_config.backend": self._params[
  183. "vl_rec_backend"
  184. ],
  185. "SubModules.VLRecognition.genai_config.server_url": self._params[
  186. "vl_rec_server_url"
  187. ],
  188. "SubModules.VLRecognition.genai_config.max_concurrency": self._params[
  189. "vl_rec_max_concurrency"
  190. ],
  191. "SubModules.VLRecognition.genai_config.client_kwargs.model_name": self._params[
  192. "vl_rec_api_model_name"
  193. ],
  194. "SubModules.VLRecognition.genai_config.client_kwargs.api_key": self._params[
  195. "vl_rec_api_key"
  196. ],
  197. "SubPipelines.DocPreprocessor.SubModules.DocOrientationClassify.model_name": self._params[
  198. "doc_orientation_classify_model_name"
  199. ],
  200. "SubPipelines.DocPreprocessor.SubModules.DocOrientationClassify.model_dir": self._params[
  201. "doc_orientation_classify_model_dir"
  202. ],
  203. "SubPipelines.DocPreprocessor.SubModules.DocUnwarping.model_name": self._params[
  204. "doc_unwarping_model_name"
  205. ],
  206. "SubPipelines.DocPreprocessor.SubModules.DocUnwarping.model_dir": self._params[
  207. "doc_unwarping_model_dir"
  208. ],
  209. }
  210. return create_config_from_structure(STRUCTURE)
  211. class PaddleOCRVLCLISubcommandExecutor(PipelineCLISubcommandExecutor):
  212. @property
  213. def subparser_name(self):
  214. return "doc_parser"
  215. def _update_subparser(self, subparser):
  216. add_simple_inference_args(subparser)
  217. subparser.add_argument(
  218. "--layout_detection_model_name",
  219. type=str,
  220. help="Name of the layout detection model.",
  221. )
  222. subparser.add_argument(
  223. "--layout_detection_model_dir",
  224. type=str,
  225. help="Path to the layout detection model directory.",
  226. )
  227. subparser.add_argument(
  228. "--layout_threshold",
  229. type=float,
  230. help="Score threshold for the layout detection model.",
  231. )
  232. subparser.add_argument(
  233. "--layout_nms",
  234. type=str2bool,
  235. help="Whether to use NMS in layout detection.",
  236. )
  237. subparser.add_argument(
  238. "--layout_unclip_ratio",
  239. type=float,
  240. help="Expansion coefficient for layout detection.",
  241. )
  242. subparser.add_argument(
  243. "--layout_merge_bboxes_mode",
  244. type=str,
  245. help="Overlapping box filtering method.",
  246. )
  247. subparser.add_argument(
  248. "--vl_rec_model_name",
  249. type=str,
  250. help="Name of the VL recognition model.",
  251. )
  252. subparser.add_argument(
  253. "--vl_rec_model_dir",
  254. type=str,
  255. help="Path to the VL recognition model directory.",
  256. )
  257. subparser.add_argument(
  258. "--vl_rec_backend",
  259. type=str,
  260. help="Backend used by the VL recognition module.",
  261. choices=_SUPPORTED_VL_BACKENDS,
  262. )
  263. subparser.add_argument(
  264. "--vl_rec_server_url",
  265. type=str,
  266. help="Server URL used by the VL recognition module.",
  267. )
  268. subparser.add_argument(
  269. "--vl_rec_max_concurrency",
  270. type=int,
  271. help="Maximum concurrency for making VLM requests.",
  272. )
  273. subparser.add_argument(
  274. "--vl_rec_api_model_name",
  275. type=str,
  276. help="Model name for the VLM server.",
  277. )
  278. subparser.add_argument(
  279. "--vl_rec_api_key",
  280. type=str,
  281. help="API key for the VLM server.",
  282. )
  283. subparser.add_argument(
  284. "--doc_orientation_classify_model_name",
  285. type=str,
  286. help="Name of the document image orientation classification model.",
  287. )
  288. subparser.add_argument(
  289. "--doc_orientation_classify_model_dir",
  290. type=str,
  291. help="Path to the document image orientation classification model directory.",
  292. )
  293. subparser.add_argument(
  294. "--doc_unwarping_model_name",
  295. type=str,
  296. help="Name of the text image unwarping model.",
  297. )
  298. subparser.add_argument(
  299. "--doc_unwarping_model_dir",
  300. type=str,
  301. help="Path to the image unwarping model directory.",
  302. )
  303. subparser.add_argument(
  304. "--use_doc_orientation_classify",
  305. type=str2bool,
  306. help="Whether to use document image orientation classification.",
  307. )
  308. subparser.add_argument(
  309. "--use_doc_unwarping",
  310. type=str2bool,
  311. help="Whether to use text image unwarping.",
  312. )
  313. subparser.add_argument(
  314. "--use_layout_detection",
  315. type=str2bool,
  316. help="Whether to use layout detection.",
  317. )
  318. subparser.add_argument(
  319. "--use_chart_recognition",
  320. type=str2bool,
  321. help="Whether to use chart recognition.",
  322. )
  323. subparser.add_argument(
  324. "--format_block_content",
  325. type=str2bool,
  326. help="Whether to format block content to Markdown.",
  327. )
  328. subparser.add_argument(
  329. "--use_queues",
  330. type=str2bool,
  331. help="Whether to use queues for asynchronous processing.",
  332. )
  333. subparser.add_argument(
  334. "--prompt_label",
  335. type=str,
  336. help="Prompt label for the VLM.",
  337. )
  338. subparser.add_argument(
  339. "--repetition_penalty",
  340. type=float,
  341. help="Repetition penalty used in sampling for the VLM.",
  342. )
  343. subparser.add_argument(
  344. "--temperature",
  345. type=float,
  346. help="Temperature parameter used in sampling for the VLM.",
  347. )
  348. subparser.add_argument(
  349. "--top_p",
  350. type=float,
  351. help="Top-p parameter used in sampling for the VLM.",
  352. )
  353. subparser.add_argument(
  354. "--min_pixels",
  355. type=int,
  356. help="Minimum pixels for image preprocessing for the VLM.",
  357. )
  358. subparser.add_argument(
  359. "--max_pixels",
  360. type=int,
  361. help="Maximum pixels for image preprocessing for the VLM.",
  362. )
  363. def execute_with_args(self, args):
  364. params = get_subcommand_args(args)
  365. perform_simple_inference(
  366. PaddleOCRVL,
  367. params,
  368. predict_param_names={
  369. "use_queues",
  370. "prompt_label",
  371. "repetition_penalty",
  372. "temperature",
  373. "top_p",
  374. "min_pixels",
  375. "max_pixels",
  376. },
  377. )