matcher.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. from ppstructure.table.table_master_match import deal_eb_token, deal_bb
  16. import html
  17. def distance(box_1, box_2):
  18. x1, y1, x2, y2 = box_1
  19. x3, y3, x4, y4 = box_2
  20. dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
  21. dis_2 = abs(x3 - x1) + abs(y3 - y1)
  22. dis_3 = abs(x4 - x2) + abs(y4 - y2)
  23. return dis + min(dis_2, dis_3)
  24. def compute_iou(rec1, rec2):
  25. """
  26. computing IoU
  27. :param rec1: (y0, x0, y1, x1), which reflects
  28. (top, left, bottom, right)
  29. :param rec2: (y0, x0, y1, x1)
  30. :return: scala value of IoU
  31. """
  32. # computing area of each rectangles
  33. S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
  34. S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
  35. # computing the sum_area
  36. sum_area = S_rec1 + S_rec2
  37. # find the each edge of intersect rectangle
  38. left_line = max(rec1[1], rec2[1])
  39. right_line = min(rec1[3], rec2[3])
  40. top_line = max(rec1[0], rec2[0])
  41. bottom_line = min(rec1[2], rec2[2])
  42. # judge if there is an intersect
  43. if left_line >= right_line or top_line >= bottom_line:
  44. return 0.0
  45. else:
  46. intersect = (right_line - left_line) * (bottom_line - top_line)
  47. return (intersect / (sum_area - intersect)) * 1.0
  48. class TableMatch:
  49. def __init__(self, filter_ocr_result=False, use_master=False):
  50. self.filter_ocr_result = filter_ocr_result
  51. self.use_master = use_master
  52. def __call__(self, structure_res, dt_boxes, rec_res):
  53. pred_structures, pred_bboxes = structure_res
  54. if self.filter_ocr_result:
  55. dt_boxes, rec_res = self._filter_ocr_result(pred_bboxes, dt_boxes, rec_res)
  56. matched_index = self.match_result(dt_boxes, pred_bboxes)
  57. if self.use_master:
  58. pred_html, pred = self.get_pred_html_master(
  59. pred_structures, matched_index, rec_res
  60. )
  61. else:
  62. pred_html, pred = self.get_pred_html(
  63. pred_structures, matched_index, rec_res
  64. )
  65. return pred_html
  66. def match_result(self, dt_boxes, pred_bboxes):
  67. matched = {}
  68. for i, gt_box in enumerate(dt_boxes):
  69. distances = []
  70. for j, pred_box in enumerate(pred_bboxes):
  71. if len(pred_box) == 8:
  72. pred_box = [
  73. np.min(pred_box[0::2]),
  74. np.min(pred_box[1::2]),
  75. np.max(pred_box[0::2]),
  76. np.max(pred_box[1::2]),
  77. ]
  78. distances.append(
  79. (distance(gt_box, pred_box), 1.0 - compute_iou(gt_box, pred_box))
  80. ) # compute iou and l1 distance
  81. sorted_distances = distances.copy()
  82. # select det box by iou and l1 distance
  83. sorted_distances = sorted(
  84. sorted_distances, key=lambda item: (item[1], item[0])
  85. )
  86. if distances.index(sorted_distances[0]) not in matched.keys():
  87. matched[distances.index(sorted_distances[0])] = [i]
  88. else:
  89. matched[distances.index(sorted_distances[0])].append(i)
  90. return matched
  91. def get_pred_html(self, pred_structures, matched_index, ocr_contents):
  92. end_html = []
  93. td_index = 0
  94. for tag in pred_structures:
  95. if "</td>" in tag:
  96. if "<td></td>" == tag:
  97. end_html.extend("<td>")
  98. if td_index in matched_index.keys():
  99. b_with = False
  100. if (
  101. "<b>" in ocr_contents[matched_index[td_index][0]]
  102. and len(matched_index[td_index]) > 1
  103. ):
  104. b_with = True
  105. end_html.extend("<b>")
  106. for i, td_index_index in enumerate(matched_index[td_index]):
  107. content = ocr_contents[td_index_index][0]
  108. if len(matched_index[td_index]) > 1:
  109. if len(content) == 0:
  110. continue
  111. if content[0] == " ":
  112. content = content[1:]
  113. if "<b>" in content:
  114. content = content[3:]
  115. if "</b>" in content:
  116. content = content[:-4]
  117. if len(content) == 0:
  118. continue
  119. if (
  120. i != len(matched_index[td_index]) - 1
  121. and " " != content[-1]
  122. ):
  123. content += " "
  124. # escape content
  125. content = html.escape(content)
  126. end_html.extend(content)
  127. if b_with:
  128. end_html.extend("</b>")
  129. if "<td></td>" == tag:
  130. end_html.append("</td>")
  131. else:
  132. end_html.append(tag)
  133. td_index += 1
  134. else:
  135. end_html.append(tag)
  136. return "".join(end_html), end_html
  137. def get_pred_html_master(self, pred_structures, matched_index, ocr_contents):
  138. end_html = []
  139. td_index = 0
  140. for token in pred_structures:
  141. if "</td>" in token:
  142. txt = ""
  143. b_with = False
  144. if td_index in matched_index.keys():
  145. if (
  146. "<b>" in ocr_contents[matched_index[td_index][0]]
  147. and len(matched_index[td_index]) > 1
  148. ):
  149. b_with = True
  150. for i, td_index_index in enumerate(matched_index[td_index]):
  151. content = ocr_contents[td_index_index][0]
  152. if len(matched_index[td_index]) > 1:
  153. if len(content) == 0:
  154. continue
  155. if content[0] == " ":
  156. content = content[1:]
  157. if "<b>" in content:
  158. content = content[3:]
  159. if "</b>" in content:
  160. content = content[:-4]
  161. if len(content) == 0:
  162. continue
  163. if (
  164. i != len(matched_index[td_index]) - 1
  165. and " " != content[-1]
  166. ):
  167. content += " "
  168. txt += content
  169. if b_with:
  170. txt = "<b>{}</b>".format(txt)
  171. if "<td></td>" == token:
  172. token = "<td>{}</td>".format(txt)
  173. else:
  174. token = "{}</td>".format(txt)
  175. td_index += 1
  176. token = deal_eb_token(token)
  177. end_html.append(token)
  178. html = "".join(end_html)
  179. html = deal_bb(html)
  180. return html, end_html
  181. def _filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res):
  182. y1 = pred_bboxes[:, 1::2].min()
  183. new_dt_boxes = []
  184. new_rec_res = []
  185. for box, rec in zip(dt_boxes, rec_res):
  186. if np.max(box[1::2]) < y1:
  187. continue
  188. new_dt_boxes.append(box)
  189. new_rec_res.append(rec)
  190. return new_dt_boxes, new_rec_res