table_recognition_v2.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438
  1. # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from .._utils.cli import (
  15. add_simple_inference_args,
  16. get_subcommand_args,
  17. perform_simple_inference,
  18. str2bool,
  19. )
  20. from .base import PaddleXPipelineWrapper, PipelineCLISubcommandExecutor
  21. from .utils import create_config_from_structure
  22. class TableRecognitionPipelineV2(PaddleXPipelineWrapper):
  23. def __init__(
  24. self,
  25. layout_detection_model_name=None,
  26. layout_detection_model_dir=None,
  27. table_classification_model_name=None,
  28. table_classification_model_dir=None,
  29. wired_table_structure_recognition_model_name=None,
  30. wired_table_structure_recognition_model_dir=None,
  31. wireless_table_structure_recognition_model_name=None,
  32. wireless_table_structure_recognition_model_dir=None,
  33. wired_table_cells_detection_model_name=None,
  34. wired_table_cells_detection_model_dir=None,
  35. wireless_table_cells_detection_model_name=None,
  36. wireless_table_cells_detection_model_dir=None,
  37. doc_orientation_classify_model_name=None,
  38. doc_orientation_classify_model_dir=None,
  39. doc_unwarping_model_name=None,
  40. doc_unwarping_model_dir=None,
  41. text_detection_model_name=None,
  42. text_detection_model_dir=None,
  43. text_det_limit_side_len=None,
  44. text_det_limit_type=None,
  45. text_det_thresh=None,
  46. text_det_box_thresh=None,
  47. text_det_unclip_ratio=None,
  48. text_recognition_model_name=None,
  49. text_recognition_model_dir=None,
  50. text_recognition_batch_size=None,
  51. text_rec_score_thresh=None,
  52. use_doc_orientation_classify=None,
  53. use_doc_unwarping=None,
  54. use_layout_detection=None,
  55. use_ocr_model=None,
  56. **kwargs,
  57. ):
  58. params = locals().copy()
  59. params.pop("self")
  60. params.pop("kwargs")
  61. self._params = params
  62. super().__init__(**kwargs)
  63. @property
  64. def _paddlex_pipeline_name(self):
  65. return "table_recognition_v2"
  66. def predict_iter(
  67. self,
  68. input,
  69. *,
  70. use_doc_orientation_classify=None,
  71. use_doc_unwarping=None,
  72. use_layout_detection=None,
  73. use_ocr_model=None,
  74. overall_ocr_res=None,
  75. layout_det_res=None,
  76. text_det_limit_side_len=None,
  77. text_det_limit_type=None,
  78. text_det_thresh=None,
  79. text_det_box_thresh=None,
  80. text_det_unclip_ratio=None,
  81. text_rec_score_thresh=None,
  82. use_e2e_wired_table_rec_model=False,
  83. use_e2e_wireless_table_rec_model=False,
  84. use_wired_table_cells_trans_to_html=False,
  85. use_wireless_table_cells_trans_to_html=False,
  86. use_table_orientation_classify=True,
  87. use_ocr_results_with_table_cells=True,
  88. **kwargs,
  89. ):
  90. return self.paddlex_pipeline.predict(
  91. input,
  92. use_doc_orientation_classify=use_doc_orientation_classify,
  93. use_doc_unwarping=use_doc_unwarping,
  94. use_layout_detection=use_layout_detection,
  95. use_ocr_model=use_ocr_model,
  96. overall_ocr_res=overall_ocr_res,
  97. layout_det_res=layout_det_res,
  98. text_det_limit_side_len=text_det_limit_side_len,
  99. text_det_limit_type=text_det_limit_type,
  100. text_det_thresh=text_det_thresh,
  101. text_det_box_thresh=text_det_box_thresh,
  102. text_det_unclip_ratio=text_det_unclip_ratio,
  103. text_rec_score_thresh=text_rec_score_thresh,
  104. use_e2e_wired_table_rec_model=use_e2e_wired_table_rec_model,
  105. use_e2e_wireless_table_rec_model=use_e2e_wireless_table_rec_model,
  106. use_wired_table_cells_trans_to_html=use_wired_table_cells_trans_to_html,
  107. use_wireless_table_cells_trans_to_html=use_wireless_table_cells_trans_to_html,
  108. use_table_orientation_classify=use_table_orientation_classify,
  109. use_ocr_results_with_table_cells=use_ocr_results_with_table_cells,
  110. **kwargs,
  111. )
  112. def predict(
  113. self,
  114. input,
  115. *,
  116. use_doc_orientation_classify=None,
  117. use_doc_unwarping=None,
  118. use_layout_detection=None,
  119. use_ocr_model=None,
  120. overall_ocr_res=None,
  121. layout_det_res=None,
  122. text_det_limit_side_len=None,
  123. text_det_limit_type=None,
  124. text_det_thresh=None,
  125. text_det_box_thresh=None,
  126. text_det_unclip_ratio=None,
  127. text_rec_score_thresh=None,
  128. use_e2e_wired_table_rec_model=False,
  129. use_e2e_wireless_table_rec_model=False,
  130. use_wired_table_cells_trans_to_html=False,
  131. use_wireless_table_cells_trans_to_html=False,
  132. use_table_orientation_classify=True,
  133. use_ocr_results_with_table_cells=True,
  134. **kwargs,
  135. ):
  136. return list(
  137. self.predict_iter(
  138. input,
  139. use_doc_orientation_classify=use_doc_orientation_classify,
  140. use_doc_unwarping=use_doc_unwarping,
  141. use_layout_detection=use_layout_detection,
  142. use_ocr_model=use_ocr_model,
  143. overall_ocr_res=overall_ocr_res,
  144. layout_det_res=layout_det_res,
  145. text_det_limit_side_len=text_det_limit_side_len,
  146. text_det_limit_type=text_det_limit_type,
  147. text_det_thresh=text_det_thresh,
  148. text_det_box_thresh=text_det_box_thresh,
  149. text_det_unclip_ratio=text_det_unclip_ratio,
  150. text_rec_score_thresh=text_rec_score_thresh,
  151. use_e2e_wired_table_rec_model=use_e2e_wired_table_rec_model,
  152. use_e2e_wireless_table_rec_model=use_e2e_wireless_table_rec_model,
  153. use_wired_table_cells_trans_to_html=use_wired_table_cells_trans_to_html,
  154. use_wireless_table_cells_trans_to_html=use_wireless_table_cells_trans_to_html,
  155. use_table_orientation_classify=use_table_orientation_classify,
  156. use_ocr_results_with_table_cells=use_ocr_results_with_table_cells,
  157. **kwargs,
  158. )
  159. )
  160. @classmethod
  161. def get_cli_subcommand_executor(cls):
  162. return TableRecognitionPipelineV2CLISubcommandExecutor()
  163. def _get_paddlex_config_overrides(self):
  164. STRUCTURE = {
  165. "SubPipelines.DocPreprocessor.use_doc_orientation_classify": self._params[
  166. "use_doc_orientation_classify"
  167. ],
  168. "SubPipelines.DocPreprocessor.use_doc_unwarping": self._params[
  169. "use_doc_unwarping"
  170. ],
  171. "use_doc_preprocessor": self._params["use_doc_orientation_classify"]
  172. or self._params["use_doc_unwarping"],
  173. "use_layout_detection": self._params["use_layout_detection"],
  174. "use_ocr_model": self._params["use_ocr_model"],
  175. "SubModules.LayoutDetection.model_name": self._params[
  176. "layout_detection_model_name"
  177. ],
  178. "SubModules.LayoutDetection.model_dir": self._params[
  179. "layout_detection_model_dir"
  180. ],
  181. "SubModules.TableClassification.model_name": self._params[
  182. "table_classification_model_name"
  183. ],
  184. "SubModules.TableClassification.model_dir": self._params[
  185. "table_classification_model_dir"
  186. ],
  187. "SubModules.WiredTableStructureRecognition.model_name": self._params[
  188. "wired_table_structure_recognition_model_name"
  189. ],
  190. "SubModules.WiredTableStructureRecognition.model_dir": self._params[
  191. "wired_table_structure_recognition_model_dir"
  192. ],
  193. "SubModules.WirelessTableStructureRecognition.model_name": self._params[
  194. "wireless_table_structure_recognition_model_name"
  195. ],
  196. "SubModules.WirelessTableStructureRecognition.model_dir": self._params[
  197. "wireless_table_structure_recognition_model_dir"
  198. ],
  199. "SubModules.WiredTableCellsDetection.model_name": self._params[
  200. "wired_table_cells_detection_model_name"
  201. ],
  202. "SubModules.WiredTableCellsDetection.model_dir": self._params[
  203. "wired_table_cells_detection_model_dir"
  204. ],
  205. "SubModules.WirelessTableCellsDetection.model_name": self._params[
  206. "wireless_table_cells_detection_model_name"
  207. ],
  208. "SubModules.WirelessTableCellsDetection.model_dir": self._params[
  209. "wireless_table_cells_detection_model_dir"
  210. ],
  211. "SubPipelines.DocPreprocessor.SubModules.DocOrientationClassify.model_name": self._params[
  212. "doc_orientation_classify_model_name"
  213. ],
  214. "SubPipelines.DocPreprocessor.SubModules.DocOrientationClassify.model_dir": self._params[
  215. "doc_orientation_classify_model_dir"
  216. ],
  217. "SubPipelines.DocPreprocessor.SubModules.DocUnwarping.model_name": self._params[
  218. "doc_unwarping_model_name"
  219. ],
  220. "SubPipelines.DocPreprocessor.SubModules.DocUnwarping.model_dir": self._params[
  221. "doc_unwarping_model_dir"
  222. ],
  223. "SubPipelines.GeneralOCR.SubModules.TextDetection.model_name": self._params[
  224. "text_detection_model_name"
  225. ],
  226. "SubPipelines.GeneralOCR.SubModules.TextDetection.model_dir": self._params[
  227. "text_detection_model_dir"
  228. ],
  229. "SubPipelines.GeneralOCR.SubModules.TextDetection.limit_side_len": self._params[
  230. "text_det_limit_side_len"
  231. ],
  232. "SubPipelines.GeneralOCR.SubModules.TextDetection.limit_type": self._params[
  233. "text_det_limit_type"
  234. ],
  235. "SubPipelines.GeneralOCR.SubModules.TextDetection.thresh": self._params[
  236. "text_det_thresh"
  237. ],
  238. "SubPipelines.GeneralOCR.SubModules.TextDetection.box_thresh": self._params[
  239. "text_det_box_thresh"
  240. ],
  241. "SubPipelines.GeneralOCR.SubModules.TextDetection.unclip_ratio": self._params[
  242. "text_det_unclip_ratio"
  243. ],
  244. "SubPipelines.GeneralOCR.SubModules.TextRecognition.model_name": self._params[
  245. "text_recognition_model_name"
  246. ],
  247. "SubPipelines.GeneralOCR.SubModules.TextRecognition.model_dir": self._params[
  248. "text_recognition_model_dir"
  249. ],
  250. "SubPipelines.GeneralOCR.SubModules.TextRecognition.batch_size": self._params[
  251. "text_recognition_batch_size"
  252. ],
  253. "SubPipelines.GeneralOCR.SubModules.TextRecognition.score_thresh": self._params[
  254. "text_rec_score_thresh"
  255. ],
  256. }
  257. return create_config_from_structure(STRUCTURE)
  258. class TableRecognitionPipelineV2CLISubcommandExecutor(PipelineCLISubcommandExecutor):
  259. @property
  260. def subparser_name(self):
  261. return "table_recognition_v2"
  262. def _update_subparser(self, subparser):
  263. add_simple_inference_args(subparser)
  264. subparser.add_argument(
  265. "--layout_detection_model_name",
  266. type=str,
  267. help="Name of the layout detection model.",
  268. )
  269. subparser.add_argument(
  270. "--layout_detection_model_dir",
  271. type=str,
  272. help="Path to the layout detection model directory.",
  273. )
  274. subparser.add_argument(
  275. "--table_classification_model_name",
  276. type=str,
  277. help="Name of the table classification model.",
  278. )
  279. subparser.add_argument(
  280. "--table_classification_model_dir",
  281. type=str,
  282. help="Path to the table classification model directory.",
  283. )
  284. subparser.add_argument(
  285. "--wired_table_structure_recognition_model_name",
  286. type=str,
  287. help="Name of the wired table structure recognition model.",
  288. )
  289. subparser.add_argument(
  290. "--wired_table_structure_recognition_model_dir",
  291. type=str,
  292. help="Path to the wired table structure recognition model directory.",
  293. )
  294. subparser.add_argument(
  295. "--wireless_table_structure_recognition_model_name",
  296. type=str,
  297. help="Name of the wireless table structure recognition model.",
  298. )
  299. subparser.add_argument(
  300. "--wireless_table_structure_recognition_model_dir",
  301. type=str,
  302. help="Path to the wired table structure recognition model directory.",
  303. )
  304. subparser.add_argument(
  305. "--wired_table_cells_detection_model_name",
  306. type=str,
  307. help="Name of the wired table cells detection model.",
  308. )
  309. subparser.add_argument(
  310. "--wired_table_cells_detection_model_dir",
  311. type=str,
  312. help="Path to the wired table cells detection model directory.",
  313. )
  314. subparser.add_argument(
  315. "--wireless_table_cells_detection_model_name",
  316. type=str,
  317. help="Name of the wireless table cells detection model.",
  318. )
  319. subparser.add_argument(
  320. "--wireless_table_cells_detection_model_dir",
  321. type=str,
  322. help="Path to the wireless table cells detection model directory.",
  323. )
  324. subparser.add_argument(
  325. "--doc_orientation_classify_model_name",
  326. type=str,
  327. help="Name of the document image orientation classification model.",
  328. )
  329. subparser.add_argument(
  330. "--doc_orientation_classify_model_dir",
  331. type=str,
  332. help="Path to the document image orientation classification model directory.",
  333. )
  334. subparser.add_argument(
  335. "--doc_unwarping_model_name",
  336. type=str,
  337. help="Name of the text image unwarping model.",
  338. )
  339. subparser.add_argument(
  340. "--doc_unwarping_model_dir",
  341. type=str,
  342. help="Path to the image unwarping model directory.",
  343. )
  344. subparser.add_argument(
  345. "--text_detection_model_name",
  346. type=str,
  347. help="Name of the text detection model.",
  348. )
  349. subparser.add_argument(
  350. "--text_detection_model_dir",
  351. type=str,
  352. help="Path to the text detection model directory.",
  353. )
  354. subparser.add_argument(
  355. "--text_det_limit_side_len",
  356. type=int,
  357. help="This sets a limit on the side length of the input image for the text detection model.",
  358. )
  359. subparser.add_argument(
  360. "--text_det_limit_type",
  361. type=str,
  362. help="This determines how the side length limit is applied to the input image before feeding it into the text deteciton model.",
  363. )
  364. subparser.add_argument(
  365. "--text_det_thresh",
  366. type=float,
  367. help="Detection pixel threshold for the text detection model. Pixels with scores greater than this threshold in the output probability map are considered text pixels.",
  368. )
  369. subparser.add_argument(
  370. "--text_det_box_thresh",
  371. type=float,
  372. help="Detection box threshold for the text detection model. A detection result is considered a text region if the average score of all pixels within the border of the result is greater than this threshold.",
  373. )
  374. subparser.add_argument(
  375. "--text_det_unclip_ratio",
  376. type=float,
  377. help="Text detection expansion coefficient, which expands the text region using this method. The larger the value, the larger the expansion area.",
  378. )
  379. subparser.add_argument(
  380. "--text_recognition_model_name",
  381. type=str,
  382. help="Name of the text recognition model.",
  383. )
  384. subparser.add_argument(
  385. "--text_recognition_model_dir",
  386. type=str,
  387. help="Path to the text recognition model directory.",
  388. )
  389. subparser.add_argument(
  390. "--text_recognition_batch_size",
  391. type=int,
  392. help="Batch size for the text recognition model.",
  393. )
  394. subparser.add_argument(
  395. "--text_rec_score_thresh",
  396. type=float,
  397. help="Text recognition threshold used in general OCR. Text results with scores greater than this threshold are retained.",
  398. )
  399. subparser.add_argument(
  400. "--use_doc_orientation_classify",
  401. type=str2bool,
  402. help="Whether to use document image orientation classification.",
  403. )
  404. subparser.add_argument(
  405. "--use_doc_unwarping",
  406. type=str2bool,
  407. help="Whether to use text image unwarping.",
  408. )
  409. subparser.add_argument(
  410. "--use_layout_detection",
  411. type=str2bool,
  412. help="Whether to use layout detection.",
  413. )
  414. subparser.add_argument(
  415. "--use_ocr_model",
  416. type=str2bool,
  417. help="Whether to use OCR models.",
  418. )
  419. def execute_with_args(self, args):
  420. params = get_subcommand_args(args)
  421. perform_simple_inference(TableRecognitionPipelineV2, params)