map_metric.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Dict
  3. import numpy as np
  4. from modelscope.metainfo import Metrics
  5. from modelscope.outputs import OutputKeys
  6. from modelscope.utils.registry import default_group
  7. from .base import Metric
  8. from .builder import METRICS, MetricKeys
  9. @METRICS.register_module(
  10. group_key=default_group, module_name=Metrics.multi_average_precision)
  11. class AveragePrecisionMetric(Metric):
  12. """The metric computation class for multi average precision classes.
  13. This metric class calculates multi average precision for the whole input batches.
  14. """
  15. def __init__(self, *args, **kwargs):
  16. super().__init__(*args, **kwargs)
  17. self.preds = []
  18. self.labels = []
  19. self.thresh = kwargs.get('threshold', 0.5)
  20. def add(self, outputs: Dict, inputs: Dict):
  21. label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS
  22. ground_truths = inputs[label_name]
  23. eval_results = outputs[label_name]
  24. for key in [
  25. OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES,
  26. OutputKeys.LABELS, OutputKeys.SCORES
  27. ]:
  28. if key in outputs and outputs[key] is not None:
  29. eval_results = outputs[key]
  30. break
  31. assert type(ground_truths) == type(eval_results)
  32. for truth in ground_truths:
  33. self.labels.append(truth)
  34. for result in eval_results:
  35. if isinstance(truth, str):
  36. self.preds.append(result.strip().replace(' ', ''))
  37. else:
  38. self.preds.append(result)
  39. def evaluate(self):
  40. assert len(self.preds) == len(self.labels)
  41. scores = self._calculate_ap_score(self.preds, self.labels, self.thresh)
  42. return {MetricKeys.mAP: scores.mean().item()}
  43. def merge(self, other: 'AveragePrecisionMetric'):
  44. self.preds.extend(other.preds)
  45. self.labels.extend(other.labels)
  46. def __getstate__(self):
  47. return self.preds, self.labels, self.thresh
  48. def __setstate__(self, state):
  49. self.__init__()
  50. self.preds, self.labels, self.thresh = state
  51. def _calculate_ap_score(self, preds, labels, thresh=0.5):
  52. hyps = np.array(preds)
  53. refs = np.array(labels)
  54. a = np.where(hyps[:, :2] < refs[:, :2], refs[:, :2], hyps[:, :2])
  55. b = np.where(hyps[:, 2:] < refs[:, 2:], hyps[:, 2:], refs[:, 2:])
  56. interacts = np.concatenate([a, b], axis=1)
  57. area_predictions = (hyps[:, 2] - hyps[:, 0]) * (
  58. hyps[:, 3] - hyps[:, 1])
  59. area_targets = (refs[:, 2] - refs[:, 0]) * (refs[:, 3] - refs[:, 1])
  60. interacts_w = interacts[:, 2] - interacts[:, 0]
  61. interacts_h = interacts[:, 3] - interacts[:, 1]
  62. area_interacts = interacts_w * interacts_h
  63. ious = area_interacts / (
  64. area_predictions + area_targets - area_interacts + 1e-6)
  65. return (ious >= thresh) & (interacts_w > 0) & (interacts_h > 0)