| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- from typing import Dict
- import numpy as np
- from modelscope.metainfo import Metrics
- from modelscope.outputs import OutputKeys
- from modelscope.utils.registry import default_group
- from .base import Metric
- from .builder import METRICS, MetricKeys
- @METRICS.register_module(
- group_key=default_group, module_name=Metrics.multi_average_precision)
- class AveragePrecisionMetric(Metric):
- """The metric computation class for multi average precision classes.
- This metric class calculates multi average precision for the whole input batches.
- """
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.preds = []
- self.labels = []
- self.thresh = kwargs.get('threshold', 0.5)
- def add(self, outputs: Dict, inputs: Dict):
- label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS
- ground_truths = inputs[label_name]
- eval_results = outputs[label_name]
- for key in [
- OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES,
- OutputKeys.LABELS, OutputKeys.SCORES
- ]:
- if key in outputs and outputs[key] is not None:
- eval_results = outputs[key]
- break
- assert type(ground_truths) == type(eval_results)
- for truth in ground_truths:
- self.labels.append(truth)
- for result in eval_results:
- if isinstance(truth, str):
- self.preds.append(result.strip().replace(' ', ''))
- else:
- self.preds.append(result)
- def evaluate(self):
- assert len(self.preds) == len(self.labels)
- scores = self._calculate_ap_score(self.preds, self.labels, self.thresh)
- return {MetricKeys.mAP: scores.mean().item()}
- def merge(self, other: 'AveragePrecisionMetric'):
- self.preds.extend(other.preds)
- self.labels.extend(other.labels)
- def __getstate__(self):
- return self.preds, self.labels, self.thresh
- def __setstate__(self, state):
- self.__init__()
- self.preds, self.labels, self.thresh = state
- def _calculate_ap_score(self, preds, labels, thresh=0.5):
- hyps = np.array(preds)
- refs = np.array(labels)
- a = np.where(hyps[:, :2] < refs[:, :2], refs[:, :2], hyps[:, :2])
- b = np.where(hyps[:, 2:] < refs[:, 2:], hyps[:, 2:], refs[:, 2:])
- interacts = np.concatenate([a, b], axis=1)
- area_predictions = (hyps[:, 2] - hyps[:, 0]) * (
- hyps[:, 3] - hyps[:, 1])
- area_targets = (refs[:, 2] - refs[:, 0]) * (refs[:, 3] - refs[:, 1])
- interacts_w = interacts[:, 2] - interacts[:, 0]
- interacts_h = interacts[:, 3] - interacts[:, 1]
- area_interacts = interacts_w * interacts_h
- ious = area_interacts / (
- area_predictions + area_targets - area_interacts + 1e-6)
- return (ious >= thresh) & (interacts_w > 0) & (interacts_h > 0)
|