eval_table.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import sys
  16. __dir__ = os.path.dirname(os.path.abspath(__file__))
  17. sys.path.append(__dir__)
  18. sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../..")))
  19. import cv2
  20. import pickle
  21. import paddle
  22. from tqdm import tqdm
  23. from ppstructure.table.table_metric import TEDS
  24. from ppstructure.table.predict_table import TableSystem
  25. from ppstructure.utility import init_args
  26. from ppocr.utils.logging import get_logger
  27. logger = get_logger()
  28. def parse_args():
  29. parser = init_args()
  30. parser.add_argument("--gt_path", type=str)
  31. return parser.parse_args()
  32. def load_txt(txt_path):
  33. pred_html_dict = {}
  34. if not os.path.exists(txt_path):
  35. return pred_html_dict
  36. with open(txt_path, encoding="utf-8") as f:
  37. lines = f.readlines()
  38. for line in lines:
  39. line = line.strip().split("\t")
  40. img_name, pred_html = line
  41. pred_html_dict[img_name] = pred_html
  42. return pred_html_dict
  43. def load_result(path):
  44. data = {}
  45. if os.path.exists(path):
  46. data = pickle.load(open(path, "rb"))
  47. return data
  48. def save_result(path, data):
  49. old_data = load_result(path)
  50. old_data.update(data)
  51. with open(path, "wb") as f:
  52. pickle.dump(old_data, f)
  53. def main(gt_path, img_root, args):
  54. os.makedirs(args.output, exist_ok=True)
  55. # init TableSystem
  56. text_sys = TableSystem(args)
  57. # load gt and preds html result
  58. gt_html_dict = load_txt(gt_path)
  59. ocr_result = load_result(os.path.join(args.output, "ocr.pickle"))
  60. structure_result = load_result(os.path.join(args.output, "structure.pickle"))
  61. pred_htmls = []
  62. gt_htmls = []
  63. for img_name, gt_html in tqdm(gt_html_dict.items()):
  64. img = cv2.imread(os.path.join(img_root, img_name))
  65. # run ocr and save result
  66. if img_name not in ocr_result:
  67. dt_boxes, rec_res, _, _ = text_sys._ocr(img)
  68. ocr_result[img_name] = [dt_boxes, rec_res]
  69. save_result(os.path.join(args.output, "ocr.pickle"), ocr_result)
  70. # run structure and save result
  71. if img_name not in structure_result:
  72. structure_res, _ = text_sys._structure(img)
  73. structure_result[img_name] = structure_res
  74. save_result(os.path.join(args.output, "structure.pickle"), structure_result)
  75. dt_boxes, rec_res = ocr_result[img_name]
  76. structure_res = structure_result[img_name]
  77. # match ocr and structure
  78. pred_html = text_sys.match(structure_res, dt_boxes, rec_res)
  79. pred_htmls.append(pred_html)
  80. gt_htmls.append(gt_html)
  81. # compute teds
  82. teds = TEDS(n_jobs=16)
  83. scores = teds.batch_evaluate_html(gt_htmls, pred_htmls)
  84. logger.info("teds: {}".format(sum(scores) / len(scores)))
  85. if __name__ == "__main__":
  86. args = parse_args()
  87. main(args.gt_path, args.image_dir, args)