eval_with_label_end2end.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import re
  16. import sys
  17. import shapely
  18. from shapely.geometry import Polygon
  19. import numpy as np
  20. from collections import defaultdict
  21. import operator
  22. from rapidfuzz.distance import Levenshtein
  23. import argparse
  24. import json
  25. import copy
  26. def parse_ser_results_fp(fp, fp_type="gt", ignore_background=True):
  27. # img/zh_val_0.jpg {
  28. # "height": 3508,
  29. # "width": 2480,
  30. # "ocr_info": [
  31. # {"text": "Maribyrnong", "label": "other", "bbox": [1958, 144, 2184, 198]},
  32. # {"text": "CITYCOUNCIL", "label": "other", "bbox": [2052, 183, 2171, 214]},
  33. # ]
  34. assert fp_type in ["gt", "pred"]
  35. key = "label" if fp_type == "gt" else "pred"
  36. res_dict = dict()
  37. with open(fp, "r", encoding="utf-8") as fin:
  38. lines = fin.readlines()
  39. for _, line in enumerate(lines):
  40. img_path, info = line.strip().split("\t")
  41. # get key
  42. image_name = os.path.basename(img_path)
  43. res_dict[image_name] = []
  44. # get infos
  45. json_info = json.loads(info)
  46. for single_ocr_info in json_info["ocr_info"]:
  47. label = single_ocr_info[key].upper()
  48. if label in ["O", "OTHERS", "OTHER"]:
  49. label = "O"
  50. if ignore_background and label == "O":
  51. continue
  52. single_ocr_info["label"] = label
  53. res_dict[image_name].append(copy.deepcopy(single_ocr_info))
  54. return res_dict
  55. def polygon_from_str(polygon_points):
  56. """
  57. Create a shapely polygon object from gt or dt line.
  58. """
  59. polygon_points = np.array(polygon_points).reshape(4, 2)
  60. polygon = Polygon(polygon_points).convex_hull
  61. return polygon
  62. def polygon_iou(poly1, poly2):
  63. """
  64. Intersection over union between two shapely polygons.
  65. """
  66. if not poly1.intersects(poly2): # this test is fast and can accelerate calculation
  67. iou = 0
  68. else:
  69. try:
  70. inter_area = poly1.intersection(poly2).area
  71. union_area = poly1.area + poly2.area - inter_area
  72. iou = float(inter_area) / union_area
  73. except shapely.geos.TopologicalError:
  74. # except Exception as e:
  75. # print(e)
  76. print("shapely.geos.TopologicalError occurred, iou set to 0")
  77. iou = 0
  78. return iou
  79. def ed(args, str1, str2):
  80. if args.ignore_space:
  81. str1 = str1.replace(" ", "")
  82. str2 = str2.replace(" ", "")
  83. if args.ignore_case:
  84. str1 = str1.lower()
  85. str2 = str2.lower()
  86. return Levenshtein.distance(str1, str2)
  87. def convert_bbox_to_polygon(bbox):
  88. """
  89. bbox : [x1, y1, x2, y2]
  90. output: [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
  91. """
  92. xmin, ymin, xmax, ymax = bbox
  93. poly = [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]]
  94. return poly
  95. def eval_e2e(args):
  96. # gt
  97. gt_results = parse_ser_results_fp(args.gt_json_path, "gt", args.ignore_background)
  98. # pred
  99. dt_results = parse_ser_results_fp(
  100. args.pred_json_path, "pred", args.ignore_background
  101. )
  102. iou_thresh = args.iou_thres
  103. num_gt_chars = 0
  104. gt_count = 0
  105. dt_count = 0
  106. hit = 0
  107. ed_sum = 0
  108. for img_name in dt_results:
  109. gt_info = gt_results[img_name]
  110. gt_count += len(gt_info)
  111. dt_info = dt_results[img_name]
  112. dt_count += len(dt_info)
  113. dt_match = [False] * len(dt_info)
  114. gt_match = [False] * len(gt_info)
  115. all_ious = defaultdict(tuple)
  116. # gt: {text, label, bbox or poly}
  117. for index_gt, gt in enumerate(gt_info):
  118. if "poly" not in gt:
  119. gt["poly"] = convert_bbox_to_polygon(gt["bbox"])
  120. gt_poly = polygon_from_str(gt["poly"])
  121. for index_dt, dt in enumerate(dt_info):
  122. if "poly" not in dt:
  123. dt["poly"] = convert_bbox_to_polygon(dt["bbox"])
  124. dt_poly = polygon_from_str(dt["poly"])
  125. iou = polygon_iou(dt_poly, gt_poly)
  126. if iou >= iou_thresh:
  127. all_ious[(index_gt, index_dt)] = iou
  128. sorted_ious = sorted(all_ious.items(), key=operator.itemgetter(1), reverse=True)
  129. sorted_gt_dt_pairs = [item[0] for item in sorted_ious]
  130. # matched gt and dt
  131. for gt_dt_pair in sorted_gt_dt_pairs:
  132. index_gt, index_dt = gt_dt_pair
  133. if gt_match[index_gt] == False and dt_match[index_dt] == False:
  134. gt_match[index_gt] = True
  135. dt_match[index_dt] = True
  136. # ocr rec results
  137. gt_text = gt_info[index_gt]["text"]
  138. dt_text = dt_info[index_dt]["text"]
  139. # ser results
  140. gt_label = gt_info[index_gt]["label"]
  141. dt_label = dt_info[index_dt]["pred"]
  142. if True: # ignore_masks[index_gt] == '0':
  143. ed_sum += ed(args, gt_text, dt_text)
  144. num_gt_chars += len(gt_text)
  145. if gt_text == dt_text:
  146. if args.ignore_ser_prediction or gt_label == dt_label:
  147. hit += 1
  148. # unmatched dt
  149. for tindex, dt_match_flag in enumerate(dt_match):
  150. if dt_match_flag == False:
  151. dt_text = dt_info[tindex]["text"]
  152. gt_text = ""
  153. ed_sum += ed(args, dt_text, gt_text)
  154. # unmatched gt
  155. for tindex, gt_match_flag in enumerate(gt_match):
  156. if gt_match_flag == False:
  157. dt_text = ""
  158. gt_text = gt_info[tindex]["text"]
  159. ed_sum += ed(args, gt_text, dt_text)
  160. num_gt_chars += len(gt_text)
  161. eps = 1e-9
  162. print("config: ", args)
  163. print("hit, dt_count, gt_count", hit, dt_count, gt_count)
  164. precision = hit / (dt_count + eps)
  165. recall = hit / (gt_count + eps)
  166. fmeasure = 2.0 * precision * recall / (precision + recall + eps)
  167. avg_edit_dist_img = ed_sum / len(gt_results)
  168. avg_edit_dist_field = ed_sum / (gt_count + eps)
  169. character_acc = 1 - ed_sum / (num_gt_chars + eps)
  170. print("character_acc: %.2f" % (character_acc * 100) + "%")
  171. print("avg_edit_dist_field: %.2f" % (avg_edit_dist_field))
  172. print("avg_edit_dist_img: %.2f" % (avg_edit_dist_img))
  173. print("precision: %.2f" % (precision * 100) + "%")
  174. print("recall: %.2f" % (recall * 100) + "%")
  175. print("fmeasure: %.2f" % (fmeasure * 100) + "%")
  176. return
  177. def parse_args():
  178. """ """
  179. def str2bool(v):
  180. return v.lower() in ("true", "t", "1")
  181. parser = argparse.ArgumentParser()
  182. ## Required parameters
  183. parser.add_argument(
  184. "--gt_json_path",
  185. default=None,
  186. type=str,
  187. required=True,
  188. )
  189. parser.add_argument(
  190. "--pred_json_path",
  191. default=None,
  192. type=str,
  193. required=True,
  194. )
  195. parser.add_argument("--iou_thres", default=0.5, type=float)
  196. parser.add_argument(
  197. "--ignore_case",
  198. default=False,
  199. type=str2bool,
  200. help="whether to do lower case for the strs",
  201. )
  202. parser.add_argument(
  203. "--ignore_space", default=True, type=str2bool, help="whether to ignore space"
  204. )
  205. parser.add_argument(
  206. "--ignore_background",
  207. default=True,
  208. type=str2bool,
  209. help="whether to ignore other label",
  210. )
  211. parser.add_argument(
  212. "--ignore_ser_prediction",
  213. default=False,
  214. type=str2bool,
  215. help="whether to ignore ocr pred results",
  216. )
  217. args = parser.parse_args()
  218. return args
  219. if __name__ == "__main__":
  220. args = parse_args()
  221. eval_e2e(args)