table_master_match.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995
  1. # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/table_recognition/match.py
  17. """
  18. import os
  19. import re
  20. import cv2
  21. import glob
  22. import copy
  23. import math
  24. import pickle
  25. import numpy as np
  26. from shapely.geometry import Polygon, MultiPoint
  27. """
  28. Utility functions used for matching.
  29. """
  30. def remove_empty_bboxes(bboxes):
  31. """
  32. remove [0., 0., 0., 0.] in structure master bboxes.
  33. len(bboxes.shape) must be 2.
  34. :param bboxes:
  35. :return:
  36. """
  37. new_bboxes = []
  38. for bbox in bboxes:
  39. if sum(bbox) == 0.0:
  40. continue
  41. new_bboxes.append(bbox)
  42. return np.array(new_bboxes)
  43. def xywh2xyxy(bboxes):
  44. if len(bboxes.shape) == 1:
  45. new_bboxes = np.empty_like(bboxes)
  46. new_bboxes[0] = bboxes[0] - bboxes[2] / 2
  47. new_bboxes[1] = bboxes[1] - bboxes[3] / 2
  48. new_bboxes[2] = bboxes[0] + bboxes[2] / 2
  49. new_bboxes[3] = bboxes[1] + bboxes[3] / 2
  50. return new_bboxes
  51. elif len(bboxes.shape) == 2:
  52. new_bboxes = np.empty_like(bboxes)
  53. new_bboxes[:, 0] = bboxes[:, 0] - bboxes[:, 2] / 2
  54. new_bboxes[:, 1] = bboxes[:, 1] - bboxes[:, 3] / 2
  55. new_bboxes[:, 2] = bboxes[:, 0] + bboxes[:, 2] / 2
  56. new_bboxes[:, 3] = bboxes[:, 1] + bboxes[:, 3] / 2
  57. return new_bboxes
  58. else:
  59. raise ValueError
  60. def xyxy2xywh(bboxes):
  61. if len(bboxes.shape) == 1:
  62. new_bboxes = np.empty_like(bboxes)
  63. new_bboxes[0] = bboxes[0] + (bboxes[2] - bboxes[0]) / 2
  64. new_bboxes[1] = bboxes[1] + (bboxes[3] - bboxes[1]) / 2
  65. new_bboxes[2] = bboxes[2] - bboxes[0]
  66. new_bboxes[3] = bboxes[3] - bboxes[1]
  67. return new_bboxes
  68. elif len(bboxes.shape) == 2:
  69. new_bboxes = np.empty_like(bboxes)
  70. new_bboxes[:, 0] = bboxes[:, 0] + (bboxes[:, 2] - bboxes[:, 0]) / 2
  71. new_bboxes[:, 1] = bboxes[:, 1] + (bboxes[:, 3] - bboxes[:, 1]) / 2
  72. new_bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0]
  73. new_bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1]
  74. return new_bboxes
  75. else:
  76. raise ValueError
  77. def pickle_load(path, prefix="end2end"):
  78. if os.path.isfile(path):
  79. data = pickle.load(open(path, "rb"))
  80. elif os.path.isdir(path):
  81. data = dict()
  82. search_path = os.path.join(path, "{}_*.pkl".format(prefix))
  83. pkls = glob.glob(search_path)
  84. for pkl in pkls:
  85. this_data = pickle.load(open(pkl, "rb"))
  86. data.update(this_data)
  87. else:
  88. raise ValueError
  89. return data
  90. def convert_coord(xyxy):
  91. """
  92. Convert two points format to four points format.
  93. :param xyxy:
  94. :return:
  95. """
  96. new_bbox = np.zeros([4, 2], dtype=np.float32)
  97. new_bbox[0, 0], new_bbox[0, 1] = xyxy[0], xyxy[1]
  98. new_bbox[1, 0], new_bbox[1, 1] = xyxy[2], xyxy[1]
  99. new_bbox[2, 0], new_bbox[2, 1] = xyxy[2], xyxy[3]
  100. new_bbox[3, 0], new_bbox[3, 1] = xyxy[0], xyxy[3]
  101. return new_bbox
  102. def cal_iou(bbox1, bbox2):
  103. bbox1_poly = Polygon(bbox1).convex_hull
  104. bbox2_poly = Polygon(bbox2).convex_hull
  105. union_poly = np.concatenate((bbox1, bbox2))
  106. if not bbox1_poly.intersects(bbox2_poly):
  107. iou = 0
  108. else:
  109. inter_area = bbox1_poly.intersection(bbox2_poly).area
  110. union_area = MultiPoint(union_poly).convex_hull.area
  111. if union_area == 0:
  112. iou = 0
  113. else:
  114. iou = float(inter_area) / union_area
  115. return iou
  116. def cal_distance(p1, p2):
  117. delta_x = p1[0] - p2[0]
  118. delta_y = p1[1] - p2[1]
  119. d = math.sqrt((delta_x**2) + (delta_y**2))
  120. return d
  121. def is_inside(center_point, corner_point):
  122. """
  123. Find if center_point inside the bbox(corner_point) or not.
  124. :param center_point: center point (x, y)
  125. :param corner_point: corner point ((x1,y1),(x2,y2))
  126. :return:
  127. """
  128. x_flag = False
  129. y_flag = False
  130. if (center_point[0] >= corner_point[0][0]) and (
  131. center_point[0] <= corner_point[1][0]
  132. ):
  133. x_flag = True
  134. if (center_point[1] >= corner_point[0][1]) and (
  135. center_point[1] <= corner_point[1][1]
  136. ):
  137. y_flag = True
  138. if x_flag and y_flag:
  139. return True
  140. else:
  141. return False
  142. def find_no_match(match_list, all_end2end_nums, type="end2end"):
  143. """
  144. Find out no match end2end bbox in previous match list.
  145. :param match_list: matching pairs.
  146. :param all_end2end_nums: numbers of end2end_xywh
  147. :param type: 'end2end' corresponding to idx 0, 'master' corresponding to idx 1.
  148. :return: no match pse bbox index list
  149. """
  150. if type == "end2end":
  151. idx = 0
  152. elif type == "master":
  153. idx = 1
  154. else:
  155. raise ValueError
  156. no_match_indexs = []
  157. # m[0] is end2end index m[1] is master index
  158. matched_bbox_indexs = [m[idx] for m in match_list]
  159. for n in range(all_end2end_nums):
  160. if n not in matched_bbox_indexs:
  161. no_match_indexs.append(n)
  162. return no_match_indexs
  163. def is_abs_lower_than_threshold(this_bbox, target_bbox, threshold=3):
  164. # only consider y axis, for grouping in row.
  165. delta = abs(this_bbox[1] - target_bbox[1])
  166. if delta < threshold:
  167. return True
  168. else:
  169. return False
  170. def sort_line_bbox(g, bg):
  171. """
  172. Sorted the bbox in the same line(group)
  173. compare coord 'x' value, where 'y' value is closed in the same group.
  174. :param g: index in the same group
  175. :param bg: bbox in the same group
  176. :return:
  177. """
  178. xs = [bg_item[0] for bg_item in bg]
  179. xs_sorted = sorted(xs)
  180. g_sorted = [None] * len(xs_sorted)
  181. bg_sorted = [None] * len(xs_sorted)
  182. for g_item, bg_item in zip(g, bg):
  183. idx = xs_sorted.index(bg_item[0])
  184. bg_sorted[idx] = bg_item
  185. g_sorted[idx] = g_item
  186. return g_sorted, bg_sorted
  187. def flatten(sorted_groups, sorted_bbox_groups):
  188. idxs = []
  189. bboxes = []
  190. for group, bbox_group in zip(sorted_groups, sorted_bbox_groups):
  191. for g, bg in zip(group, bbox_group):
  192. idxs.append(g)
  193. bboxes.append(bg)
  194. return idxs, bboxes
  195. def sort_bbox(end2end_xywh_bboxes, no_match_end2end_indexes):
  196. """
  197. This function will group the render end2end bboxes in row.
  198. :param end2end_xywh_bboxes:
  199. :param no_match_end2end_indexes:
  200. :return:
  201. """
  202. groups = []
  203. bbox_groups = []
  204. for index, end2end_xywh_bbox in zip(no_match_end2end_indexes, end2end_xywh_bboxes):
  205. this_bbox = end2end_xywh_bbox
  206. if len(groups) == 0:
  207. groups.append([index])
  208. bbox_groups.append([this_bbox])
  209. else:
  210. flag = False
  211. for g, bg in zip(groups, bbox_groups):
  212. # this_bbox is belong to bg's row or not
  213. if is_abs_lower_than_threshold(this_bbox, bg[0]):
  214. g.append(index)
  215. bg.append(this_bbox)
  216. flag = True
  217. break
  218. if not flag:
  219. # this_bbox is not belong to bg's row, create a row.
  220. groups.append([index])
  221. bbox_groups.append([this_bbox])
  222. # sorted bboxes in a group
  223. tmp_groups, tmp_bbox_groups = [], []
  224. for g, bg in zip(groups, bbox_groups):
  225. g_sorted, bg_sorted = sort_line_bbox(g, bg)
  226. tmp_groups.append(g_sorted)
  227. tmp_bbox_groups.append(bg_sorted)
  228. # sorted groups, sort by coord y's value.
  229. sorted_groups = [None] * len(tmp_groups)
  230. sorted_bbox_groups = [None] * len(tmp_bbox_groups)
  231. ys = [bg[0][1] for bg in tmp_bbox_groups]
  232. sorted_ys = sorted(ys)
  233. for g, bg in zip(tmp_groups, tmp_bbox_groups):
  234. idx = sorted_ys.index(bg[0][1])
  235. sorted_groups[idx] = g
  236. sorted_bbox_groups[idx] = bg
  237. # flatten, get final result
  238. end2end_sorted_idx_list, end2end_sorted_bbox_list = flatten(
  239. sorted_groups, sorted_bbox_groups
  240. )
  241. return (
  242. end2end_sorted_idx_list,
  243. end2end_sorted_bbox_list,
  244. sorted_groups,
  245. sorted_bbox_groups,
  246. )
  247. def get_bboxes_list(end2end_result, structure_master_result):
  248. """
  249. This function is use to convert end2end results and structure master results to
  250. List of xyxy bbox format and List of xywh bbox format
  251. :param end2end_result: bbox's format is xyxy
  252. :param structure_master_result: bbox's format is xywh
  253. :return: 4 kind list of bbox ()
  254. """
  255. # end2end
  256. end2end_xyxy_list = []
  257. end2end_xywh_list = []
  258. for end2end_item in end2end_result:
  259. src_bbox = end2end_item["bbox"]
  260. end2end_xyxy_list.append(src_bbox)
  261. xywh_bbox = xyxy2xywh(src_bbox)
  262. end2end_xywh_list.append(xywh_bbox)
  263. end2end_xyxy_bboxes = np.array(end2end_xyxy_list)
  264. end2end_xywh_bboxes = np.array(end2end_xywh_list)
  265. # structure master
  266. src_bboxes = structure_master_result["bbox"]
  267. src_bboxes = remove_empty_bboxes(src_bboxes)
  268. structure_master_xyxy_bboxes = src_bboxes
  269. xywh_bbox = xyxy2xywh(src_bboxes)
  270. structure_master_xywh_bboxes = xywh_bbox
  271. return (
  272. end2end_xyxy_bboxes,
  273. end2end_xywh_bboxes,
  274. structure_master_xywh_bboxes,
  275. structure_master_xyxy_bboxes,
  276. )
  277. def center_rule_match(end2end_xywh_bboxes, structure_master_xyxy_bboxes):
  278. """
  279. Judge end2end Bbox's center point is inside structure master Bbox or not,
  280. if end2end Bbox's center is in structure master Bbox, get matching pair.
  281. :param end2end_xywh_bboxes:
  282. :param structure_master_xyxy_bboxes:
  283. :return: match pairs list, e.g. [[0,1], [1,2], ...]
  284. """
  285. match_pairs_list = []
  286. for i, end2end_xywh in enumerate(end2end_xywh_bboxes):
  287. for j, master_xyxy in enumerate(structure_master_xyxy_bboxes):
  288. x_end2end, y_end2end = end2end_xywh[0], end2end_xywh[1]
  289. x_master1, y_master1, x_master2, y_master2 = (
  290. master_xyxy[0],
  291. master_xyxy[1],
  292. master_xyxy[2],
  293. master_xyxy[3],
  294. )
  295. center_point_end2end = (x_end2end, y_end2end)
  296. corner_point_master = ((x_master1, y_master1), (x_master2, y_master2))
  297. if is_inside(center_point_end2end, corner_point_master):
  298. match_pairs_list.append([i, j])
  299. return match_pairs_list
  300. def iou_rule_match(
  301. end2end_xyxy_bboxes, end2end_xyxy_indexes, structure_master_xyxy_bboxes
  302. ):
  303. """
  304. Use iou to find matching list.
  305. choose max iou value bbox as match pair.
  306. :param end2end_xyxy_bboxes:
  307. :param end2end_xyxy_indexes: original end2end indexes.
  308. :param structure_master_xyxy_bboxes:
  309. :return: match pairs list, e.g. [[0,1], [1,2], ...]
  310. """
  311. match_pair_list = []
  312. for end2end_xyxy_index, end2end_xyxy in zip(
  313. end2end_xyxy_indexes, end2end_xyxy_bboxes
  314. ):
  315. max_iou = 0
  316. max_match = [None, None]
  317. for j, master_xyxy in enumerate(structure_master_xyxy_bboxes):
  318. end2end_4xy = convert_coord(end2end_xyxy)
  319. master_4xy = convert_coord(master_xyxy)
  320. iou = cal_iou(end2end_4xy, master_4xy)
  321. if iou > max_iou:
  322. max_match[0], max_match[1] = end2end_xyxy_index, j
  323. max_iou = iou
  324. if max_match[0] is None:
  325. # no match
  326. continue
  327. match_pair_list.append(max_match)
  328. return match_pair_list
  329. def distance_rule_match(end2end_indexes, end2end_bboxes, master_indexes, master_bboxes):
  330. """
  331. Get matching between no-match end2end bboxes and no-match master bboxes.
  332. Use min distance to match.
  333. This rule will only run (no-match end2end nums > 0) and (no-match master nums > 0)
  334. It will Return master_bboxes_nums match-pairs.
  335. :param end2end_indexes:
  336. :param end2end_bboxes:
  337. :param master_indexes:
  338. :param master_bboxes:
  339. :return: match_pairs list, e.g. [[0,1], [1,2], ...]
  340. """
  341. min_match_list = []
  342. for j, master_bbox in zip(master_indexes, master_bboxes):
  343. min_distance = np.inf
  344. min_match = [0, 0] # i, j
  345. for i, end2end_bbox in zip(end2end_indexes, end2end_bboxes):
  346. x_end2end, y_end2end = end2end_bbox[0], end2end_bbox[1]
  347. x_master, y_master = master_bbox[0], master_bbox[1]
  348. end2end_point = (x_end2end, y_end2end)
  349. master_point = (x_master, y_master)
  350. dist = cal_distance(master_point, end2end_point)
  351. if dist < min_distance:
  352. min_match[0], min_match[1] = i, j
  353. min_distance = dist
  354. min_match_list.append(min_match)
  355. return min_match_list
  356. def extra_match(no_match_end2end_indexes, master_bbox_nums):
  357. """
  358. This function will create some virtual master bboxes,
  359. and get match with the no match end2end indexes.
  360. :param no_match_end2end_indexes:
  361. :param master_bbox_nums:
  362. :return:
  363. """
  364. end_nums = len(no_match_end2end_indexes) + master_bbox_nums
  365. extra_match_list = []
  366. for i in range(master_bbox_nums, end_nums):
  367. end2end_index = no_match_end2end_indexes[i - master_bbox_nums]
  368. extra_match_list.append([end2end_index, i])
  369. return extra_match_list
  370. def get_match_dict(match_list):
  371. """
  372. Convert match_list to a dict, where key is master bbox's index, value is end2end bbox index.
  373. :param match_list:
  374. :return:
  375. """
  376. match_dict = dict()
  377. for match_pair in match_list:
  378. end2end_index, master_index = match_pair[0], match_pair[1]
  379. if master_index not in match_dict.keys():
  380. match_dict[master_index] = [end2end_index]
  381. else:
  382. match_dict[master_index].append(end2end_index)
  383. return match_dict
  384. def deal_successive_space(text):
  385. """
  386. deal successive space character for text
  387. 1. Replace ' '*3 with '<space>' which is real space is text
  388. 2. Remove ' ', which is split token, not true space
  389. 3. Replace '<space>' with ' ', to get real text
  390. :param text:
  391. :return:
  392. """
  393. text = text.replace(" " * 3, "<space>")
  394. text = text.replace(" ", "")
  395. text = text.replace("<space>", " ")
  396. return text
  397. def reduce_repeat_bb(text_list, break_token):
  398. """
  399. convert ['<b>Local</b>', '<b>government</b>', '<b>unit</b>'] to ['<b>Local government unit</b>']
  400. PS: maybe style <i>Local</i> is also exist, too. it can be processed like this.
  401. :param text_list:
  402. :param break_token:
  403. :return:
  404. """
  405. count = 0
  406. for text in text_list:
  407. if text.startswith("<b>"):
  408. count += 1
  409. if count == len(text_list):
  410. new_text_list = []
  411. for text in text_list:
  412. text = text.replace("<b>", "").replace("</b>", "")
  413. new_text_list.append(text)
  414. return ["<b>" + break_token.join(new_text_list) + "</b>"]
  415. else:
  416. return text_list
  417. def get_match_text_dict(match_dict, end2end_info, break_token=" "):
  418. match_text_dict = dict()
  419. for master_index, end2end_index_list in match_dict.items():
  420. text_list = [
  421. end2end_info[end2end_index]["text"] for end2end_index in end2end_index_list
  422. ]
  423. text_list = reduce_repeat_bb(text_list, break_token)
  424. text = break_token.join(text_list)
  425. match_text_dict[master_index] = text
  426. return match_text_dict
  427. def merge_span_token(master_token_list):
  428. """
  429. Merge the span style token (row span or col span).
  430. :param master_token_list:
  431. :return:
  432. """
  433. new_master_token_list = []
  434. pointer = 0
  435. if master_token_list[-1] != "</tbody>":
  436. master_token_list.append("</tbody>")
  437. while master_token_list[pointer] != "</tbody>":
  438. try:
  439. if master_token_list[pointer] == "<td":
  440. if master_token_list[pointer + 1].startswith(
  441. " colspan="
  442. ) or master_token_list[pointer + 1].startswith(" rowspan="):
  443. """
  444. example:
  445. pattern <td colspan="3">
  446. '<td' + 'colspan=" "' + '>' + '</td>'
  447. """
  448. tmp = "".join(master_token_list[pointer : pointer + 3 + 1])
  449. pointer += 4
  450. new_master_token_list.append(tmp)
  451. elif master_token_list[pointer + 2].startswith(
  452. " colspan="
  453. ) or master_token_list[pointer + 2].startswith(" rowspan="):
  454. """
  455. example:
  456. pattern <td rowspan="2" colspan="3">
  457. '<td' + 'rowspan=" "' + 'colspan=" "' + '>' + '</td>'
  458. """
  459. tmp = "".join(master_token_list[pointer : pointer + 4 + 1])
  460. pointer += 5
  461. new_master_token_list.append(tmp)
  462. else:
  463. new_master_token_list.append(master_token_list[pointer])
  464. pointer += 1
  465. else:
  466. new_master_token_list.append(master_token_list[pointer])
  467. pointer += 1
  468. except:
  469. print("Break in merge...")
  470. break
  471. new_master_token_list.append("</tbody>")
  472. return new_master_token_list
  473. def deal_eb_token(master_token):
  474. """
  475. post process with <eb></eb>, <eb1></eb1>, ...
  476. emptyBboxTokenDict = {
  477. "[]": '<eb></eb>',
  478. "[' ']": '<eb1></eb1>',
  479. "['<b>', ' ', '</b>']": '<eb2></eb2>',
  480. "['\\u2028', '\\u2028']": '<eb3></eb3>',
  481. "['<sup>', ' ', '</sup>']": '<eb4></eb4>',
  482. "['<b>', '</b>']": '<eb5></eb5>',
  483. "['<i>', ' ', '</i>']": '<eb6></eb6>',
  484. "['<b>', '<i>', '</i>', '</b>']": '<eb7></eb7>',
  485. "['<b>', '<i>', ' ', '</i>', '</b>']": '<eb8></eb8>',
  486. "['<i>', '</i>']": '<eb9></eb9>',
  487. "['<b>', ' ', '\\u2028', ' ', '\\u2028', ' ', '</b>']": '<eb10></eb10>',
  488. }
  489. :param master_token:
  490. :return:
  491. """
  492. master_token = master_token.replace("<eb></eb>", "<td></td>")
  493. master_token = master_token.replace("<eb1></eb1>", "<td> </td>")
  494. master_token = master_token.replace("<eb2></eb2>", "<td><b> </b></td>")
  495. master_token = master_token.replace("<eb3></eb3>", "<td>\u2028\u2028</td>")
  496. master_token = master_token.replace("<eb4></eb4>", "<td><sup> </sup></td>")
  497. master_token = master_token.replace("<eb5></eb5>", "<td><b></b></td>")
  498. master_token = master_token.replace("<eb6></eb6>", "<td><i> </i></td>")
  499. master_token = master_token.replace("<eb7></eb7>", "<td><b><i></i></b></td>")
  500. master_token = master_token.replace("<eb8></eb8>", "<td><b><i> </i></b></td>")
  501. master_token = master_token.replace("<eb9></eb9>", "<td><i></i></td>")
  502. master_token = master_token.replace(
  503. "<eb10></eb10>", "<td><b> \u2028 \u2028 </b></td>"
  504. )
  505. return master_token
  506. def insert_text_to_token(master_token_list, match_text_dict):
  507. """
  508. Insert OCR text result to structure token.
  509. :param master_token_list:
  510. :param match_text_dict:
  511. :return:
  512. """
  513. master_token_list = merge_span_token(master_token_list)
  514. merged_result_list = []
  515. text_count = 0
  516. for master_token in master_token_list:
  517. if master_token.startswith("<td"):
  518. if text_count > len(match_text_dict) - 1:
  519. text_count += 1
  520. continue
  521. elif text_count not in match_text_dict.keys():
  522. text_count += 1
  523. continue
  524. else:
  525. master_token = master_token.replace(
  526. "><", ">{}<".format(match_text_dict[text_count])
  527. )
  528. text_count += 1
  529. master_token = deal_eb_token(master_token)
  530. merged_result_list.append(master_token)
  531. return "".join(merged_result_list)
  532. def deal_isolate_span(thead_part):
  533. """
  534. Deal with isolate span cases in this function.
  535. It causes by wrong prediction in structure recognition model.
  536. eg. predict <td rowspan="2"></td> to <td></td> rowspan="2"></b></td>.
  537. :param thead_part:
  538. :return:
  539. """
  540. # 1. find out isolate span tokens.
  541. isolate_pattern = (
  542. '<td></td> rowspan="(\d)+" colspan="(\d)+"></b></td>|'
  543. '<td></td> colspan="(\d)+" rowspan="(\d)+"></b></td>|'
  544. '<td></td> rowspan="(\d)+"></b></td>|'
  545. '<td></td> colspan="(\d)+"></b></td>'
  546. )
  547. isolate_iter = re.finditer(isolate_pattern, thead_part)
  548. isolate_list = [i.group() for i in isolate_iter]
  549. # 2. find out span number, by step 1 results.
  550. span_pattern = (
  551. ' rowspan="(\d)+" colspan="(\d)+"|'
  552. ' colspan="(\d)+" rowspan="(\d)+"|'
  553. ' rowspan="(\d)+"|'
  554. ' colspan="(\d)+"'
  555. )
  556. corrected_list = []
  557. for isolate_item in isolate_list:
  558. span_part = re.search(span_pattern, isolate_item)
  559. spanStr_in_isolateItem = span_part.group()
  560. # 3. merge the span number into the span token format string.
  561. if spanStr_in_isolateItem is not None:
  562. corrected_item = "<td{}></td>".format(spanStr_in_isolateItem)
  563. corrected_list.append(corrected_item)
  564. else:
  565. corrected_list.append(None)
  566. # 4. replace original isolated token.
  567. for corrected_item, isolate_item in zip(corrected_list, isolate_list):
  568. if corrected_item is not None:
  569. thead_part = thead_part.replace(isolate_item, corrected_item)
  570. else:
  571. pass
  572. return thead_part
  573. def deal_duplicate_bb(thead_part):
  574. """
  575. Deal duplicate <b> or </b> after replace.
  576. Keep one <b></b> in a <td></td> token.
  577. :param thead_part:
  578. :return:
  579. """
  580. # 1. find out <td></td> in <thead></thead>.
  581. td_pattern = (
  582. '<td rowspan="(\d)+" colspan="(\d)+">(.+?)</td>|'
  583. '<td colspan="(\d)+" rowspan="(\d)+">(.+?)</td>|'
  584. '<td rowspan="(\d)+">(.+?)</td>|'
  585. '<td colspan="(\d)+">(.+?)</td>|'
  586. "<td>(.*?)</td>"
  587. )
  588. td_iter = re.finditer(td_pattern, thead_part)
  589. td_list = [t.group() for t in td_iter]
  590. # 2. is multiply <b></b> in <td></td> or not?
  591. new_td_list = []
  592. for td_item in td_list:
  593. if td_item.count("<b>") > 1 or td_item.count("</b>") > 1:
  594. # multiply <b></b> in <td></td> case.
  595. # 1. remove all <b></b>
  596. td_item = td_item.replace("<b>", "").replace("</b>", "")
  597. # 2. replace <tb> -> <tb><b>, </tb> -> </b></tb>.
  598. td_item = td_item.replace("<td>", "<td><b>").replace("</td>", "</b></td>")
  599. new_td_list.append(td_item)
  600. else:
  601. new_td_list.append(td_item)
  602. # 3. replace original thead part.
  603. for td_item, new_td_item in zip(td_list, new_td_list):
  604. thead_part = thead_part.replace(td_item, new_td_item)
  605. return thead_part
  606. def deal_bb(result_token):
  607. """
  608. In our opinion, <b></b> always occurs in <thead></thead> text's context.
  609. This function will find out all tokens in <thead></thead> and insert <b></b> by manual.
  610. :param result_token:
  611. :return:
  612. """
  613. # find out <thead></thead> parts.
  614. thead_pattern = "<thead>(.*?)</thead>"
  615. if re.search(thead_pattern, result_token) is None:
  616. return result_token
  617. thead_part = re.search(thead_pattern, result_token).group()
  618. origin_thead_part = copy.deepcopy(thead_part)
  619. # check "rowspan" or "colspan" occur in <thead></thead> parts or not .
  620. span_pattern = '<td rowspan="(\d)+" colspan="(\d)+">|<td colspan="(\d)+" rowspan="(\d)+">|<td rowspan="(\d)+">|<td colspan="(\d)+">'
  621. span_iter = re.finditer(span_pattern, thead_part)
  622. span_list = [s.group() for s in span_iter]
  623. has_span_in_head = True if len(span_list) > 0 else False
  624. if not has_span_in_head:
  625. # <thead></thead> not include "rowspan" or "colspan" branch 1.
  626. # 1. replace <td> to <td><b>, and </td> to </b></td>
  627. # 2. it is possible to predict text include <b> or </b> by Text-line recognition,
  628. # so we replace <b><b> to <b>, and </b></b> to </b>
  629. thead_part = (
  630. thead_part.replace("<td>", "<td><b>")
  631. .replace("</td>", "</b></td>")
  632. .replace("<b><b>", "<b>")
  633. .replace("</b></b>", "</b>")
  634. )
  635. else:
  636. # <thead></thead> include "rowspan" or "colspan" branch 2.
  637. # Firstly, we deal rowspan or colspan cases.
  638. # 1. replace > to ><b>
  639. # 2. replace </td> to </b></td>
  640. # 3. it is possible to predict text include <b> or </b> by Text-line recognition,
  641. # so we replace <b><b> to <b>, and </b><b> to </b>
  642. # Secondly, deal ordinary cases like branch 1
  643. # replace ">" to "<b>"
  644. replaced_span_list = []
  645. for sp in span_list:
  646. replaced_span_list.append(sp.replace(">", "><b>"))
  647. for sp, rsp in zip(span_list, replaced_span_list):
  648. thead_part = thead_part.replace(sp, rsp)
  649. # replace "</td>" to "</b></td>"
  650. thead_part = thead_part.replace("</td>", "</b></td>")
  651. # remove duplicated <b> by re.sub
  652. mb_pattern = "(<b>)+"
  653. single_b_string = "<b>"
  654. thead_part = re.sub(mb_pattern, single_b_string, thead_part)
  655. mgb_pattern = "(</b>)+"
  656. single_gb_string = "</b>"
  657. thead_part = re.sub(mgb_pattern, single_gb_string, thead_part)
  658. # ordinary cases like branch 1
  659. thead_part = thead_part.replace("<td>", "<td><b>").replace("<b><b>", "<b>")
  660. # convert <tb><b></b></tb> back to <tb></tb>, empty cell has no <b></b>.
  661. # but space cell(<tb> </tb>) is suitable for <td><b> </b></td>
  662. thead_part = thead_part.replace("<td><b></b></td>", "<td></td>")
  663. # deal with duplicated <b></b>
  664. thead_part = deal_duplicate_bb(thead_part)
  665. # deal with isolate span tokens, which causes by wrong predict by structure prediction.
  666. # eg.PMC5994107_011_00.png
  667. thead_part = deal_isolate_span(thead_part)
  668. # replace original result with new thead part.
  669. result_token = result_token.replace(origin_thead_part, thead_part)
  670. return result_token
  671. class Matcher:
  672. def __init__(self, end2end_file, structure_master_file):
  673. """
  674. This class process the end2end results and structure recognition results.
  675. :param end2end_file: end2end results predict by end2end inference.
  676. :param structure_master_file: structure recognition results predict by structure master inference.
  677. """
  678. self.end2end_file = end2end_file
  679. self.structure_master_file = structure_master_file
  680. self.end2end_results = pickle_load(end2end_file, prefix="end2end")
  681. self.structure_master_results = pickle_load(
  682. structure_master_file, prefix="structure"
  683. )
  def match(self):
      """Match end2end (OCR) boxes to structure-master cell boxes, per file.

      For every file present in both ``self.end2end_results`` and
      ``self.structure_master_results``, the boxes are matched by
      ``_match_single_file`` and the result is formatted by ``_format``.
      Files that have no structure-master result are skipped.

      Match process (see ``_match_single_file``):
        pre-process: convert end2end and structure master results to
        xyxy / xywh ndarray format.
        1. center rule (described upstream as "pseBbox is inside masterBbox").
        2. IoU rule between pseBbox and masterBbox.
        3. minimum distance between center points.

      :return: dict mapping file name to the match-result dict returned by
          ``_format``.
      """
      match_results = dict()
      for file_name, end2end_result in self.end2end_results.items():
          if file_name not in self.structure_master_results:
              continue
          match_result_dict = self._match_single_file(
              end2end_result, self.structure_master_results[file_name]
          )
          match_results[file_name] = self._format(match_result_dict, file_name)
      return match_results

  @staticmethod
  def _match_single_file(end2end_result, structure_master_result):
      """Run the three matching rules plus virtual extra matching for one file.

      Each rule only sees the boxes left unmatched by the earlier rules
      (computed with ``find_no_match``). End2end boxes still unmatched after
      rule 3 are sorted into rows (``sort_bbox``) and paired with virtual
      master cells by ``extra_match``, which receives the number of real
      master boxes (presumably so virtual cells are indexed after them).

      :param end2end_result: OCR result of one file (a value of
          ``self.end2end_results``).
      :param structure_master_result: structure prediction of one file
          (a value of ``self.structure_master_results``).
      :return: dict with keys ``match_list`` (rules 1-3 only),
          ``match_list_add_extra_match`` (deep copy of ``match_list`` plus the
          virtual matches), ``sorted_groups`` and ``sorted_bboxes_groups``
          (both empty lists when no end2end box was left over).
      """
      (
          end2end_xyxy_bboxes,
          end2end_xywh_bboxes,
          structure_master_xywh_bboxes,
          structure_master_xyxy_bboxes,
      ) = get_bboxes_list(end2end_result, structure_master_result)
      num_end2end = len(end2end_xywh_bboxes)
      num_master = len(structure_master_xywh_bboxes)

      # Rule 1: center rule.
      match_list = []
      match_list.extend(
          center_rule_match(end2end_xywh_bboxes, structure_master_xyxy_bboxes)
      )

      # Rule 2: IoU rule, applied to the end2end boxes rule 1 left unmatched.
      # len() checks (rather than truthiness) stay valid if the helpers
      # return numpy arrays.
      no_match_end2end = find_no_match(match_list, num_end2end, type="end2end")
      if len(no_match_end2end) > 0:
          match_list.extend(
              iou_rule_match(
                  end2end_xyxy_bboxes[no_match_end2end],
                  no_match_end2end,
                  structure_master_xyxy_bboxes,
              )
          )

      # Rule 3: distance rule, between the end2end boxes and the master boxes
      # that are both still unmatched after rules 1 and 2.
      no_match_end2end = find_no_match(match_list, num_end2end, type="end2end")
      no_match_master = find_no_match(match_list, num_master, type="master")
      if len(no_match_master) > 0 and len(no_match_end2end) > 0:
          match_list.extend(
              distance_rule_match(
                  no_match_end2end,
                  end2end_xywh_bboxes[no_match_end2end],
                  no_match_master,
                  structure_master_xywh_bboxes[no_match_master],
              )
          )

      # Extra matching: the structure token sequence can be cut at the
      # model's max length, so there may be more end2end boxes than master
      # cells. Leftover end2end boxes are grouped into rows and matched with
      # virtual master cells; ``_format`` later appends the corresponding
      # virtual tokens (per the upstream note, this increases the TEDS score).
      no_match_end2end = find_no_match(match_list, num_end2end, type="end2end")
      extra_match_list = []
      sorted_groups = []
      sorted_bboxes_groups = []
      if len(no_match_end2end) > 0:
          (
              end2end_sorted_indexes_list,
              _end2end_sorted_bboxes_list,
              sorted_groups,
              sorted_bboxes_groups,
          ) = sort_bbox(end2end_xywh_bboxes[no_match_end2end], no_match_end2end)
          extra_match_list = extra_match(end2end_sorted_indexes_list, num_master)
      # Deep copy so the extra (virtual) matches never leak into match_list.
      match_list_add_extra_match = copy.deepcopy(match_list)
      match_list_add_extra_match.extend(extra_match_list)

      return {
          "match_list": match_list,
          "match_list_add_extra_match": match_list_add_extra_match,
          "sorted_groups": sorted_groups,
          "sorted_bboxes_groups": sorted_bboxes_groups,
      }
  def _format(self, match_result, file_name):
      """Extend the master tokens with virtual rows and store the token list.

      ``match_result["sorted_groups"]`` holds the rows of end2end boxes that
      ``match`` paired with virtual master cells. Each row becomes ``<tr>``,
      followed by one ``<td></td>`` per box, followed by ``</tr>``.

      The master tokens are ``self.structure_master_results[file_name]["text"]``
      split on ``","``. How the virtual rows are added depends on the last
      real token:

      * ``</tbody>``: the prediction is complete (not cut by max length).
        Upstream found that inserting virtual rows here lowers TEDS on the
        validation set, so the tokens are left unchanged.
      * ``<td></td>`` or ``</td>``: the sequence was cut inside a row. The
        row is closed with ``</tr>``, then the virtual rows and
        ``</tbody>`` are appended.
      * anything else: the virtual rows and ``</tbody>`` are appended.

      :param match_result: match-result dict built by ``match``; must contain
          ``"sorted_groups"``. It is updated in place: the token list is
          stored under ``"matched_master_token_list"`` via ``setdefault``,
          so an existing entry is kept.
      :param file_name: key into ``self.structure_master_results``.
      :return: the same ``match_result`` dict.
      """
      master_token_list = self.structure_master_results[file_name]["text"].split(",")

      # Create the virtual master tokens, one <tr> per sorted row.
      virtual_master_token_list = []
      for line_group in match_result["sorted_groups"]:
          virtual_master_token_list.append("<tr>")
          virtual_master_token_list.extend(["<td></td>"] * len(line_group))
          virtual_master_token_list.append("</tr>")

      last_token = master_token_list[-1]
      if last_token != "</tbody>":
          # A sequence ending in a cell (either the single-token "<td></td>"
          # or the closing "</td>" of a multi-token cell) leaves its row
          # open; close it before appending new rows.
          if last_token in ("<td></td>", "</td>"):
              master_token_list.append("</tr>")
          master_token_list.extend(virtual_master_token_list)
          master_token_list.append("</tbody>")

      match_result.setdefault("matched_master_token_list", master_token_list)
      return match_result
  def get_merge_result(self, match_results, break_token=" "):
      """Merge the OCR text into the structure tokens to get final results.

      For each file in ``match_results``, the following steps run in order:

      1. ``get_match_dict`` converts the extended match list
         (``"match_list_add_extra_match"``).
      2. ``get_match_text_dict`` gathers the OCR texts from
         ``self.end2end_results[file_name]``.
      3. ``insert_text_to_token`` inserts them into
         ``"matched_master_token_list"``.
      4. ``deal_bb`` post-processes the result.

      :param match_results: dict returned by ``match`` (file name ->
          formatted match-result dict).
      :param break_token: separator passed to ``get_match_text_dict``, used
          when one master bbox is matched to several end2end bboxes.
          Defaults to a single space (the previously hard-coded value).
      :return: dict mapping file name to the output of ``deal_bb``
          (presumably an HTML fragment string; callers concatenate it with
          strings).
      """
      merged_results = dict()
      for file_name, match_info in match_results.items():
          end2end_info = self.end2end_results[file_name]
          master_token_list = match_info["matched_master_token_list"]
          match_list = match_info["match_list_add_extra_match"]
          match_dict = get_match_dict(match_list)
          match_text_dict = get_match_text_dict(match_dict, end2end_info, break_token)
          merged_result = insert_text_to_token(master_token_list, match_text_dict)
          merged_results[file_name] = deal_bb(merged_result)
      return merged_results
  class TableMasterMatcher(Matcher):
      """Matcher that works on in-memory predictions of a single image.

      ``Matcher.__init__`` is deliberately not called. Instead, every call
      repopulates ``end2end_results`` and ``structure_master_results`` (the
      attributes ``match`` and ``get_merge_result`` read) from its arguments.
      """

      def __init__(self):
          """Create the matcher without running ``Matcher.__init__``."""

      def __call__(self, structure_res, dt_boxes, rec_res, img_name=1):
          """Match OCR results to a predicted table structure and return HTML.

          :param structure_res: pair ``(pred_structures, pred_bboxes)``. The
              first three and last three structure tokens are dropped before
              matching (presumably the ``<html>``/``<body>``/``<table>``
              wrappers re-added around the output below).
          :param dt_boxes: detected text boxes, each converted with
              ``np.array``.
          :param rec_res: recognition results zipped with ``dt_boxes``; only
              element ``[0]`` (the text) of each entry is used.
          :param img_name: key under which this image's results are stored.
          :return: the merged result wrapped in
              ``<html><body><table>...</table></body></html>``.
          """
          self.end2end_results = {
              img_name: [
                  dict(bbox=np.array(box), text=recognition[0])
                  for box, recognition in zip(dt_boxes, rec_res)
              ]
          }

          pred_structures, pred_bboxes = structure_res
          structure_text = ",".join(pred_structures[3:-3])
          self.structure_master_results = {
              img_name: {"text": structure_text, "bbox": pred_bboxes}
          }

          merged = self.get_merge_result(self.match())
          return "".join(
              ("<html><body><table>", merged[img_name], "</table></body></html>")
          )