det_metric.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. __all__ = ["DetMetric", "DetFCEMetric"]
  18. from .eval_det_iou import DetectionIoUEvaluator
  19. class DetMetric(object):
  20. def __init__(self, main_indicator="hmean", **kwargs):
  21. self.evaluator = DetectionIoUEvaluator()
  22. self.main_indicator = main_indicator
  23. self.reset()
  24. def __call__(self, preds, batch, **kwargs):
  25. """
  26. batch: a list produced by dataloaders.
  27. image: np.ndarray of shape (N, C, H, W).
  28. ratio_list: np.ndarray of shape(N,2)
  29. polygons: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
  30. ignore_tags: np.ndarray of shape (N, K), indicates whether a region is ignorable or not.
  31. preds: a list of dict produced by post process
  32. points: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
  33. """
  34. gt_polyons_batch = batch[2]
  35. ignore_tags_batch = batch[3]
  36. for pred, gt_polyons, ignore_tags in zip(
  37. preds, gt_polyons_batch, ignore_tags_batch
  38. ):
  39. # prepare gt
  40. gt_info_list = [
  41. {"points": gt_polyon, "text": "", "ignore": ignore_tag}
  42. for gt_polyon, ignore_tag in zip(gt_polyons, ignore_tags)
  43. ]
  44. # prepare det
  45. det_info_list = [
  46. {"points": det_polyon, "text": ""} for det_polyon in pred["points"]
  47. ]
  48. result = self.evaluator.evaluate_image(gt_info_list, det_info_list)
  49. self.results.append(result)
  50. def get_metric(self):
  51. """
  52. return metrics {
  53. 'precision': 0,
  54. 'recall': 0,
  55. 'hmean': 0
  56. }
  57. """
  58. metrics = self.evaluator.combine_results(self.results)
  59. self.reset()
  60. return metrics
  61. def reset(self):
  62. self.results = [] # clear results
  63. class DetFCEMetric(object):
  64. def __init__(self, main_indicator="hmean", **kwargs):
  65. self.evaluator = DetectionIoUEvaluator()
  66. self.main_indicator = main_indicator
  67. self.reset()
  68. def __call__(self, preds, batch, **kwargs):
  69. """
  70. batch: a list produced by dataloaders.
  71. image: np.ndarray of shape (N, C, H, W).
  72. ratio_list: np.ndarray of shape(N,2)
  73. polygons: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
  74. ignore_tags: np.ndarray of shape (N, K), indicates whether a region is ignorable or not.
  75. preds: a list of dict produced by post process
  76. points: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
  77. """
  78. gt_polyons_batch = batch[2]
  79. ignore_tags_batch = batch[3]
  80. for pred, gt_polyons, ignore_tags in zip(
  81. preds, gt_polyons_batch, ignore_tags_batch
  82. ):
  83. # prepare gt
  84. gt_info_list = [
  85. {"points": gt_polyon, "text": "", "ignore": ignore_tag}
  86. for gt_polyon, ignore_tag in zip(gt_polyons, ignore_tags)
  87. ]
  88. # prepare det
  89. det_info_list = [
  90. {"points": det_polyon, "text": "", "score": score}
  91. for det_polyon, score in zip(pred["points"], pred["scores"])
  92. ]
  93. for score_thr in self.results.keys():
  94. det_info_list_thr = [
  95. det_info
  96. for det_info in det_info_list
  97. if det_info["score"] >= score_thr
  98. ]
  99. result = self.evaluator.evaluate_image(gt_info_list, det_info_list_thr)
  100. self.results[score_thr].append(result)
  101. def get_metric(self):
  102. """
  103. return metrics {'heman':0,
  104. 'thr 0.3':'precision: 0 recall: 0 hmean: 0',
  105. 'thr 0.4':'precision: 0 recall: 0 hmean: 0',
  106. 'thr 0.5':'precision: 0 recall: 0 hmean: 0',
  107. 'thr 0.6':'precision: 0 recall: 0 hmean: 0',
  108. 'thr 0.7':'precision: 0 recall: 0 hmean: 0',
  109. 'thr 0.8':'precision: 0 recall: 0 hmean: 0',
  110. 'thr 0.9':'precision: 0 recall: 0 hmean: 0',
  111. }
  112. """
  113. metrics = {}
  114. hmean = 0
  115. for score_thr in self.results.keys():
  116. metric = self.evaluator.combine_results(self.results[score_thr])
  117. # for key, value in metric.items():
  118. # metrics['{}_{}'.format(key, score_thr)] = value
  119. metric_str = "precision:{:.5f} recall:{:.5f} hmean:{:.5f}".format(
  120. metric["precision"], metric["recall"], metric["hmean"]
  121. )
  122. metrics["thr {}".format(score_thr)] = metric_str
  123. hmean = max(hmean, metric["hmean"])
  124. metrics["hmean"] = hmean
  125. self.reset()
  126. return metrics
  127. def reset(self):
  128. self.results = {
  129. 0.3: [],
  130. 0.4: [],
  131. 0.5: [],
  132. 0.6: [],
  133. 0.7: [],
  134. 0.8: [],
  135. 0.9: [],
  136. } # clear results