action_detection_evaluator.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import copy
  3. import logging
  4. import os.path as osp
  5. from collections import OrderedDict
  6. import numpy as np
  7. import pandas as pd
  8. from detectron2.evaluation import DatasetEvaluator
  9. from detectron2.evaluation.pascal_voc_evaluation import voc_ap
  10. from detectron2.structures.boxes import Boxes, pairwise_iou
  11. from detectron2.utils import comm
  12. from scipy import interpolate
  13. class DetEvaluator(DatasetEvaluator):
  14. def __init__(self, class_names, output_dir, distributed=False):
  15. self.num_classes = len(class_names)
  16. self.class_names = class_names
  17. self.output_dir = output_dir
  18. self.distributed = distributed
  19. self.predictions = []
  20. self.gts = []
  21. def reset(self):
  22. self.predictions.clear()
  23. self.gts.clear()
  24. def process(self, input, output):
  25. """
  26. :param input: dataloader
  27. :param output: model(input)
  28. :return:
  29. """
  30. gt_instances = [x['instances'].to('cpu') for x in input]
  31. pred_instances = [x['instances'].to('cpu') for x in output]
  32. self.gts.extend(gt_instances)
  33. self.predictions.extend(pred_instances)
  34. def get_instance_by_class(self, instances, c):
  35. instances = copy.deepcopy(instances)
  36. name = 'gt_classes' if instances.has('gt_classes') else 'pred_classes'
  37. idxs = np.where(instances.get(name).numpy() == c)[0].tolist()
  38. data = {}
  39. for k, v in instances.get_fields().items():
  40. data[k] = [v[i] for i in idxs]
  41. return data
  42. def evaluate(self):
  43. if self.distributed:
  44. comm.synchronize()
  45. self.predictions = sum(comm.gather(self.predictions, dst=0), [])
  46. self.gts = sum(comm.gather(self.gts, dst=0), [])
  47. if not comm.is_main_process():
  48. return
  49. logger = logging.getLogger('detectron2.human.' + __name__)
  50. logger.info(', '.join([f'{a}' for a in self.class_names]))
  51. maps = []
  52. precisions = []
  53. recalls = []
  54. for iou_th in [0.3, 0.5, 0.7]:
  55. aps, prs, ths = self.calc_map(iou_th)
  56. map = np.nanmean([x for x in aps if x > 0.01])
  57. maps.append(map)
  58. logger.info(f'iou_th:{iou_th},' + 'Aps:'
  59. + ','.join([f'{ap:.2f}'
  60. for ap in aps]) + f', {map:.3f}')
  61. precision, recall = zip(*prs)
  62. logger.info('precision:'
  63. + ', '.join([f'{p:.2f}' for p in precision]))
  64. logger.info('recall: ' + ', '.join([f'{p:.2f}' for p in recall]))
  65. logger.info('score th: ' + ', '.join([f'{p:.2f}' for p in ths]))
  66. logger.info(f'mean-precision:{np.nanmean(precision):.3f}')
  67. logger.info(f'mean-recall:{np.nanmean(recall):.3f}')
  68. precisions.append(np.nanmean(precision))
  69. recalls.append(np.nanmean(recall))
  70. res = OrderedDict({
  71. 'det': {
  72. 'mAP': np.nanmean(maps),
  73. 'precision': np.nanmean(precisions),
  74. 'recall': np.nanmean(recalls)
  75. }
  76. })
  77. return res
  78. def calc_map(self, iou_th):
  79. aps = []
  80. prs = []
  81. ths = []
  82. # 对每个类别
  83. interpolate_precs = []
  84. for c in range(self.num_classes):
  85. ap, recalls, precisions, scores = self.det_eval(iou_th, c)
  86. if iou_th == 0.3:
  87. p1 = interpolate_precision(recalls, precisions)
  88. interpolate_precs.append(p1)
  89. recalls = np.concatenate(([0.0], recalls, [1.0]))
  90. precisions = np.concatenate(([0.0], precisions, [0.0]))
  91. scores = np.concatenate(([1.0], scores, [0.0]))
  92. t = precisions + recalls
  93. t[t == 0] = 1e-5
  94. f_score = 2 * precisions * recalls / t
  95. f_score[np.isnan(f_score)] = 0
  96. idx = np.argmax(f_score)
  97. # print(iou_th,c,np.argmax(f_score),np.argmax(t))
  98. precision_recall = (precisions[idx], recalls[idx])
  99. prs.append(precision_recall)
  100. aps.append(ap)
  101. ths.append(scores[idx])
  102. if iou_th == 0.3:
  103. interpolate_precs = np.stack(interpolate_precs, axis=1)
  104. df = pd.DataFrame(data=interpolate_precs)
  105. df.to_csv(
  106. osp.join(self.output_dir, 'pr_data.csv'),
  107. index=False,
  108. columns=None)
  109. return aps, prs, ths
  110. def det_eval(self, iou_th, class_id):
  111. c = class_id
  112. class_res_gt = {}
  113. npos = 0
  114. # 对每个样本
  115. for i, (gt, pred) in enumerate(zip(self.gts, self.predictions)):
  116. gt_classes = gt.gt_classes.tolist()
  117. pred_classes = pred.pred_classes.tolist()
  118. if c not in gt_classes + pred_classes:
  119. continue
  120. pred_data = self.get_instance_by_class(pred, c)
  121. gt_data = self.get_instance_by_class(gt, c)
  122. res = {}
  123. if c in gt_classes:
  124. res.update({
  125. 'gt_bbox': Boxes.cat(gt_data['gt_boxes']),
  126. 'det': [False] * len(gt_data['gt_classes'])
  127. })
  128. if c in pred_classes:
  129. res.update({'pred_bbox': Boxes.cat(pred_data['pred_boxes'])})
  130. res.update(
  131. {'pred_score': [s.item() for s in pred_data['scores']]})
  132. class_res_gt[i] = res
  133. npos += len(gt_data['gt_classes'])
  134. all_preds = []
  135. for img_id, res in class_res_gt.items():
  136. if 'pred_bbox' in res:
  137. for i in range(len(res['pred_bbox'])):
  138. bbox = res['pred_bbox'][i]
  139. score = res['pred_score'][i]
  140. all_preds.append([img_id, bbox, score])
  141. sorted_preds = list(
  142. sorted(all_preds, key=lambda x: x[2], reverse=True))
  143. scores = [s[-1] for s in sorted_preds]
  144. nd = len(sorted_preds)
  145. tp = np.zeros(nd)
  146. fp = np.zeros(nd)
  147. for d in range(nd):
  148. img_id, pred_bbox, score = sorted_preds[d]
  149. R = class_res_gt[sorted_preds[d][0]]
  150. ovmax = -np.inf
  151. if 'gt_bbox' in R:
  152. gt_bbox = R['gt_bbox']
  153. IoUs = pairwise_iou(pred_bbox, gt_bbox).numpy()
  154. ovmax = IoUs[0].max()
  155. jmax = np.argmax(IoUs[0]) # hit该图像的第几个gt
  156. if ovmax > iou_th:
  157. if not R['det'][jmax]: # 该gt还没有预测过
  158. tp[d] = 1.0
  159. R['det'][jmax] = True
  160. else: # 重复预测
  161. fp[d] = 1.0
  162. else:
  163. fp[d] = 1.0
  164. fp = np.cumsum(fp)
  165. tp = np.cumsum(tp)
  166. rec = tp / float(npos)
  167. # avoid divide by zero in case the first detection matches a difficult
  168. # ground truth
  169. prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
  170. ap = voc_ap(rec, prec, False)
  171. return ap, rec, prec, scores
  172. def interpolate_precision(rec, prec):
  173. rec = np.concatenate(([0.0], rec, [1.0, 1.1]))
  174. prec = np.concatenate(([1.0], prec, [0.0]))
  175. for i in range(prec.size - 1, 0, -1):
  176. prec[i - 1] = np.maximum(prec[i - 1], prec[i])
  177. i = np.where(rec[1:] != rec[:-1])[0] # 从recall改变的地方取值
  178. rec, prec = rec[i], prec[i]
  179. f = interpolate.interp1d(rec, prec)
  180. r1 = np.linspace(0, 1, 101)
  181. p1 = f(r1)
  182. return p1