utility.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import random
  15. import ast
  16. import PIL
  17. from PIL import Image, ImageDraw, ImageFont
  18. import numpy as np
  19. from tools.infer.utility import (
  20. draw_ocr_box_txt,
  21. str2bool,
  22. str2int_tuple,
  23. init_args as infer_args,
  24. )
  25. import math
  26. def init_args():
  27. parser = infer_args()
  28. # params for output
  29. parser.add_argument("--output", type=str, default="./output")
  30. # params for table structure
  31. parser.add_argument("--table_max_len", type=int, default=488)
  32. parser.add_argument("--table_algorithm", type=str, default="TableAttn")
  33. parser.add_argument("--table_model_dir", type=str)
  34. parser.add_argument("--merge_no_span_structure", type=str2bool, default=True)
  35. parser.add_argument(
  36. "--table_char_dict_path",
  37. type=str,
  38. default="../ppocr/utils/dict/table_structure_dict_ch.txt",
  39. )
  40. # params for formula recognition
  41. parser.add_argument("--formula_algorithm", type=str, default="LaTeXOCR")
  42. parser.add_argument("--formula_model_dir", type=str)
  43. parser.add_argument(
  44. "--formula_char_dict_path",
  45. type=str,
  46. default="../ppocr/utils/dict/latex_ocr_tokenizer.json",
  47. )
  48. parser.add_argument("--formula_batch_num", type=int, default=1)
  49. # params for layout
  50. parser.add_argument("--layout_model_dir", type=str)
  51. parser.add_argument(
  52. "--layout_dict_path",
  53. type=str,
  54. default="../ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt",
  55. )
  56. parser.add_argument(
  57. "--layout_score_threshold", type=float, default=0.5, help="Threshold of score."
  58. )
  59. parser.add_argument(
  60. "--layout_nms_threshold", type=float, default=0.5, help="Threshold of nms."
  61. )
  62. # params for kie
  63. parser.add_argument("--kie_algorithm", type=str, default="LayoutXLM")
  64. parser.add_argument("--ser_model_dir", type=str)
  65. parser.add_argument("--re_model_dir", type=str)
  66. parser.add_argument("--use_visual_backbone", type=str2bool, default=True)
  67. parser.add_argument(
  68. "--ser_dict_path", type=str, default="../train_data/XFUND/class_list_xfun.txt"
  69. )
  70. # need to be None or tb-yx
  71. parser.add_argument("--ocr_order_method", type=str, default=None)
  72. # params for inference
  73. parser.add_argument(
  74. "--mode",
  75. type=str,
  76. choices=["structure", "kie"],
  77. default="structure",
  78. help="structure and kie is supported",
  79. )
  80. parser.add_argument(
  81. "--image_orientation",
  82. type=bool,
  83. default=False,
  84. help="Whether to enable image orientation recognition",
  85. )
  86. parser.add_argument(
  87. "--layout",
  88. type=str2bool,
  89. default=True,
  90. help="Whether to enable layout analysis",
  91. )
  92. parser.add_argument(
  93. "--table",
  94. type=str2bool,
  95. default=True,
  96. help="In the forward, whether the table area uses table recognition",
  97. )
  98. parser.add_argument(
  99. "--formula",
  100. type=str2bool,
  101. default=False,
  102. help="Whether to enable formula recognition",
  103. )
  104. parser.add_argument(
  105. "--ocr",
  106. type=str2bool,
  107. default=True,
  108. help="In the forward, whether the non-table area is recognition by ocr",
  109. )
  110. # param for recovery
  111. parser.add_argument(
  112. "--recovery",
  113. type=str2bool,
  114. default=False,
  115. help="Whether to enable layout of recovery",
  116. )
  117. parser.add_argument(
  118. "--recovery_to_markdown",
  119. type=str2bool,
  120. default=False,
  121. help="Whether to enable layout of recovery to markdown",
  122. )
  123. parser.add_argument(
  124. "--use_pdf2docx_api",
  125. type=str2bool,
  126. default=False,
  127. help="Whether to use pdf2docx api",
  128. )
  129. parser.add_argument(
  130. "--invert",
  131. type=str2bool,
  132. default=False,
  133. help="Whether to invert image before processing",
  134. )
  135. parser.add_argument(
  136. "--binarize",
  137. type=str2bool,
  138. default=False,
  139. help="Whether to threshold binarize image before processing",
  140. )
  141. parser.add_argument(
  142. "--alphacolor",
  143. type=str2int_tuple,
  144. default=(255, 255, 255),
  145. help="Replacement color for the alpha channel, if the latter is present; R,G,B integers",
  146. )
  147. return parser
  148. def parse_args():
  149. parser = init_args()
  150. return parser.parse_args()
  151. def draw_structure_result(image, result, font_path):
  152. if isinstance(image, np.ndarray):
  153. image = Image.fromarray(image)
  154. boxes, txts, scores = [], [], []
  155. img_layout = image.copy()
  156. draw_layout = ImageDraw.Draw(img_layout)
  157. text_color = (255, 255, 255)
  158. text_background_color = (80, 127, 255)
  159. catid2color = {}
  160. font_size = 15
  161. font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
  162. for region in result:
  163. if region["type"] not in catid2color:
  164. box_color = (
  165. random.randint(0, 255),
  166. random.randint(0, 255),
  167. random.randint(0, 255),
  168. )
  169. catid2color[region["type"]] = box_color
  170. else:
  171. box_color = catid2color[region["type"]]
  172. box_layout = region["bbox"]
  173. draw_layout.rectangle(
  174. [(box_layout[0], box_layout[1]), (box_layout[2], box_layout[3])],
  175. outline=box_color,
  176. width=3,
  177. )
  178. if int(PIL.__version__.split(".")[0]) < 10:
  179. text_w, text_h = font.getsize(region["type"])
  180. else:
  181. left, top, right, bottom = font.getbbox(region["type"])
  182. text_w, text_h = right - left, bottom - top
  183. draw_layout.rectangle(
  184. [
  185. (box_layout[0], box_layout[1]),
  186. (box_layout[0] + text_w, box_layout[1] + text_h),
  187. ],
  188. fill=text_background_color,
  189. )
  190. draw_layout.text(
  191. (box_layout[0], box_layout[1]), region["type"], fill=text_color, font=font
  192. )
  193. if region["type"] == "table" or (
  194. region["type"] == "equation" and "latex" in region["res"]
  195. ):
  196. pass
  197. else:
  198. for text_result in region["res"]:
  199. boxes.append(np.array(text_result["text_region"]))
  200. txts.append(text_result["text"])
  201. scores.append(text_result["confidence"])
  202. if "text_word_region" in text_result:
  203. for word_region in text_result["text_word_region"]:
  204. char_box = word_region
  205. box_height = int(
  206. math.sqrt(
  207. (char_box[0][0] - char_box[3][0]) ** 2
  208. + (char_box[0][1] - char_box[3][1]) ** 2
  209. )
  210. )
  211. box_width = int(
  212. math.sqrt(
  213. (char_box[0][0] - char_box[1][0]) ** 2
  214. + (char_box[0][1] - char_box[1][1]) ** 2
  215. )
  216. )
  217. if box_height == 0 or box_width == 0:
  218. continue
  219. boxes.append(word_region)
  220. txts.append("")
  221. scores.append(1.0)
  222. im_show = draw_ocr_box_txt(
  223. img_layout, boxes, txts, scores, font_path=font_path, drop_score=0
  224. )
  225. return im_show
  226. def cal_ocr_word_box(rec_str, box, rec_word_info):
  227. """Calculate the detection frame for each word based on the results of recognition and detection of ocr"""
  228. col_num, word_list, word_col_list, state_list = rec_word_info
  229. box = box.tolist()
  230. bbox_x_start = box[0][0]
  231. bbox_x_end = box[1][0]
  232. bbox_y_start = box[0][1]
  233. bbox_y_end = box[2][1]
  234. cell_width = (bbox_x_end - bbox_x_start) / col_num
  235. word_box_list = []
  236. word_box_content_list = []
  237. cn_width_list = []
  238. cn_col_list = []
  239. for word, word_col, state in zip(word_list, word_col_list, state_list):
  240. if state == "cn":
  241. if len(word_col) != 1:
  242. char_seq_length = (word_col[-1] - word_col[0] + 1) * cell_width
  243. char_width = char_seq_length / (len(word_col) - 1)
  244. cn_width_list.append(char_width)
  245. cn_col_list += word_col
  246. word_box_content_list += word
  247. else:
  248. cell_x_start = bbox_x_start + int(word_col[0] * cell_width)
  249. cell_x_end = bbox_x_start + int((word_col[-1] + 1) * cell_width)
  250. cell = (
  251. (cell_x_start, bbox_y_start),
  252. (cell_x_end, bbox_y_start),
  253. (cell_x_end, bbox_y_end),
  254. (cell_x_start, bbox_y_end),
  255. )
  256. word_box_list.append(cell)
  257. word_box_content_list.append("".join(word))
  258. if len(cn_col_list) != 0:
  259. if len(cn_width_list) != 0:
  260. avg_char_width = np.mean(cn_width_list)
  261. else:
  262. avg_char_width = (bbox_x_end - bbox_x_start) / len(rec_str)
  263. for center_idx in cn_col_list:
  264. center_x = (center_idx + 0.5) * cell_width
  265. cell_x_start = max(int(center_x - avg_char_width / 2), 0) + bbox_x_start
  266. cell_x_end = (
  267. min(int(center_x + avg_char_width / 2), bbox_x_end - bbox_x_start)
  268. + bbox_x_start
  269. )
  270. cell = (
  271. (cell_x_start, bbox_y_start),
  272. (cell_x_end, bbox_y_start),
  273. (cell_x_end, bbox_y_end),
  274. (cell_x_start, bbox_y_end),
  275. )
  276. word_box_list.append(cell)
  277. return word_box_content_list, word_box_list