| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- from typing import Dict
- import numpy as np
- from modelscope.metainfo import Metrics
- from modelscope.outputs import OutputKeys
- from modelscope.utils.chinese_utils import remove_space_between_chinese_chars
- from modelscope.utils.registry import default_group
- from modelscope.utils.tensor_utils import torch_nested_numpify
- from .base import Metric
- from .builder import METRICS, MetricKeys
@METRICS.register_module(group_key=default_group, module_name=Metrics.accuracy)
class AccuracyMetric(Metric):
    """The metric computation class for classification tasks.

    Predictions and labels are accumulated over every batch passed to
    :meth:`add`; :meth:`evaluate` then reports the accuracy over all of the
    accumulated samples.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Parallel lists: preds[i] is compared against labels[i] in evaluate().
        self.preds = []
        self.labels = []

    def add(self, outputs: Dict, inputs: Dict):
        """Accumulate one batch of predictions and ground truths.

        Args:
            outputs: Model outputs. The first of the caption / text / boxes /
                label / labels / scores keys that is present and not ``None``
                is taken as the prediction batch.
            inputs: Model inputs. Ground truths are read from
                ``OutputKeys.LABEL`` when that key is present, otherwise from
                ``OutputKeys.LABELS``.

        Raises:
            KeyError: If ``inputs`` holds neither label key.
            AssertionError: If labels and predictions are containers of
                different types, or if the labels are strings and a
                prediction is not.
        """
        label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS
        ground_truths = inputs[label_name]

        eval_results = None
        for key in (OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES,
                    OutputKeys.LABEL, OutputKeys.LABELS, OutputKeys.SCORES):
            if key in outputs and outputs[key] is not None:
                eval_results = outputs[key]
                break
        assert type(ground_truths) is type(eval_results), (
            'labels and predictions must be of the same type, got '
            f'{type(ground_truths).__name__} and {type(eval_results).__name__}')

        ground_truths = torch_nested_numpify(ground_truths)
        eval_results = torch_nested_numpify(eval_results)

        # Decide text vs. non-text handling from the labels explicitly rather
        # than through a loop variable leaked out of the label loop: the
        # leaked name is undefined when the label batch is empty. As before,
        # the last label of the batch decides (batches are assumed to be
        # homogeneous).
        labels_are_text = False
        for truth in ground_truths:
            self.labels.append(truth)
            labels_are_text = isinstance(truth, str)

        for result in eval_results:
            if labels_are_text:
                # A prediction may arrive wrapped in a list; its first entry
                # is the one that gets scored.
                if isinstance(result, list):
                    result = result[0]
                assert isinstance(result, str), 'both truth and pred are str'
                result = remove_space_between_chinese_chars(result)
            self.preds.append(result)

    def evaluate(self):
        """Compute the accuracy over everything accumulated so far.

        Returns:
            A dict mapping ``MetricKeys.ACCURACY`` to the mean of the
            ``pred == label`` comparisons as a Python scalar. With no
            accumulated samples the mean of an empty array is NaN.

        Raises:
            AssertionError: If the numbers of predictions and labels differ.
        """
        assert len(self.preds) == len(self.labels), (
            f'got {len(self.preds)} predictions for {len(self.labels)} labels')
        matches = np.asarray(
            [pred == ref for pred, ref in zip(self.preds, self.labels)])
        return {MetricKeys.ACCURACY: matches.mean().item()}

    def merge(self, other: 'AccuracyMetric'):
        """Append the samples accumulated by ``other`` to this metric."""
        self.preds.extend(other.preds)
        self.labels.extend(other.labels)

    def __getstate__(self):
        # Only the accumulated samples are serialised.
        return self.preds, self.labels

    def __setstate__(self, state):
        # Re-run __init__ so base-class attributes exist before the saved
        # samples are restored.
        self.__init__()
        self.preds, self.labels = state
|