table_metric.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. from ppocr.metrics.det_metric import DetMetric
  16. class TableStructureMetric(object):
  17. def __init__(self, main_indicator="acc", eps=1e-6, del_thead_tbody=False, **kwargs):
  18. self.main_indicator = main_indicator
  19. self.eps = eps
  20. self.del_thead_tbody = del_thead_tbody
  21. self.reset()
  22. def __call__(self, pred_label, batch=None, *args, **kwargs):
  23. preds, labels = pred_label
  24. pred_structure_batch_list = preds["structure_batch_list"]
  25. gt_structure_batch_list = labels["structure_batch_list"]
  26. correct_num = 0
  27. all_num = 0
  28. for (pred, pred_conf), target in zip(
  29. pred_structure_batch_list, gt_structure_batch_list
  30. ):
  31. pred_str = "".join(pred)
  32. target_str = "".join(target)
  33. if self.del_thead_tbody:
  34. pred_str = (
  35. pred_str.replace("<thead>", "")
  36. .replace("</thead>", "")
  37. .replace("<tbody>", "")
  38. .replace("</tbody>", "")
  39. )
  40. target_str = (
  41. target_str.replace("<thead>", "")
  42. .replace("</thead>", "")
  43. .replace("<tbody>", "")
  44. .replace("</tbody>", "")
  45. )
  46. if pred_str == target_str:
  47. correct_num += 1
  48. all_num += 1
  49. self.correct_num += correct_num
  50. self.all_num += all_num
  51. def get_metric(self):
  52. """
  53. return metrics {
  54. 'acc': 0,
  55. }
  56. """
  57. acc = 1.0 * self.correct_num / (self.all_num + self.eps)
  58. self.reset()
  59. return {"acc": acc}
  60. def reset(self):
  61. self.correct_num = 0
  62. self.all_num = 0
  63. self.len_acc_num = 0
  64. self.token_nums = 0
  65. self.anys_dict = dict()
  66. class TableMetric(object):
  67. def __init__(
  68. self,
  69. main_indicator="acc",
  70. compute_bbox_metric=False,
  71. box_format="xyxy",
  72. del_thead_tbody=False,
  73. **kwargs,
  74. ):
  75. """
  76. @param sub_metrics: configs of sub_metric
  77. @param main_matric: main_matric for save best_model
  78. @param kwargs:
  79. """
  80. self.structure_metric = TableStructureMetric(del_thead_tbody=del_thead_tbody)
  81. self.bbox_metric = DetMetric() if compute_bbox_metric else None
  82. self.main_indicator = main_indicator
  83. self.box_format = box_format
  84. self.reset()
  85. def __call__(self, pred_label, batch=None, *args, **kwargs):
  86. self.structure_metric(pred_label)
  87. if self.bbox_metric is not None:
  88. self.bbox_metric(*self.prepare_bbox_metric_input(pred_label))
  89. def prepare_bbox_metric_input(self, pred_label):
  90. pred_bbox_batch_list = []
  91. gt_ignore_tags_batch_list = []
  92. gt_bbox_batch_list = []
  93. preds, labels = pred_label
  94. batch_num = len(preds["bbox_batch_list"])
  95. for batch_idx in range(batch_num):
  96. # pred
  97. pred_bbox_list = [
  98. self.format_box(pred_box)
  99. for pred_box in preds["bbox_batch_list"][batch_idx]
  100. ]
  101. pred_bbox_batch_list.append({"points": pred_bbox_list})
  102. # gt
  103. gt_bbox_list = []
  104. gt_ignore_tags_list = []
  105. for gt_box in labels["bbox_batch_list"][batch_idx]:
  106. gt_bbox_list.append(self.format_box(gt_box))
  107. gt_ignore_tags_list.append(0)
  108. gt_bbox_batch_list.append(gt_bbox_list)
  109. gt_ignore_tags_batch_list.append(gt_ignore_tags_list)
  110. return [
  111. pred_bbox_batch_list,
  112. [0, 0, gt_bbox_batch_list, gt_ignore_tags_batch_list],
  113. ]
  114. def get_metric(self):
  115. structure_metric = self.structure_metric.get_metric()
  116. if self.bbox_metric is None:
  117. return structure_metric
  118. bbox_metric = self.bbox_metric.get_metric()
  119. if self.main_indicator == self.bbox_metric.main_indicator:
  120. output = bbox_metric
  121. for sub_key in structure_metric:
  122. output["structure_metric_{}".format(sub_key)] = structure_metric[
  123. sub_key
  124. ]
  125. else:
  126. output = structure_metric
  127. for sub_key in bbox_metric:
  128. output["bbox_metric_{}".format(sub_key)] = bbox_metric[sub_key]
  129. return output
  130. def reset(self):
  131. self.structure_metric.reset()
  132. if self.bbox_metric is not None:
  133. self.bbox_metric.reset()
  134. def format_box(self, box):
  135. if self.box_format == "xyxy":
  136. x1, y1, x2, y2 = box
  137. box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
  138. elif self.box_format == "xywh":
  139. x, y, w, h = box
  140. x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
  141. box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
  142. elif self.box_format == "xyxyxyxy":
  143. x1, y1, x2, y2, x3, y3, x4, y4 = box
  144. box = [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
  145. return box