accuracy_metric.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Dict
  3. import numpy as np
  4. from modelscope.metainfo import Metrics
  5. from modelscope.outputs import OutputKeys
  6. from modelscope.utils.chinese_utils import remove_space_between_chinese_chars
  7. from modelscope.utils.registry import default_group
  8. from modelscope.utils.tensor_utils import torch_nested_numpify
  9. from .base import Metric
  10. from .builder import METRICS, MetricKeys
  11. @METRICS.register_module(group_key=default_group, module_name=Metrics.accuracy)
  12. class AccuracyMetric(Metric):
  13. """The metric computation class for classification classes.
  14. This metric class calculates accuracy for the whole input batches.
  15. """
  16. def __init__(self, *args, **kwargs):
  17. super().__init__(*args, **kwargs)
  18. self.preds = []
  19. self.labels = []
  20. def add(self, outputs: Dict, inputs: Dict):
  21. label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS
  22. ground_truths = inputs[label_name]
  23. eval_results = None
  24. for key in [
  25. OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES,
  26. OutputKeys.LABEL, OutputKeys.LABELS, OutputKeys.SCORES
  27. ]:
  28. if key in outputs and outputs[key] is not None:
  29. eval_results = outputs[key]
  30. break
  31. assert type(ground_truths) == type(eval_results)
  32. ground_truths = torch_nested_numpify(ground_truths)
  33. for truth in ground_truths:
  34. self.labels.append(truth)
  35. eval_results = torch_nested_numpify(eval_results)
  36. for result in eval_results:
  37. if isinstance(truth, str):
  38. if isinstance(result, list):
  39. result = result[0]
  40. assert isinstance(result, str), 'both truth and pred are str'
  41. self.preds.append(remove_space_between_chinese_chars(result))
  42. else:
  43. self.preds.append(result)
  44. def evaluate(self):
  45. assert len(self.preds) == len(self.labels)
  46. return {
  47. MetricKeys.ACCURACY: (np.asarray([
  48. pred == ref for pred, ref in zip(self.preds, self.labels)
  49. ])).mean().item()
  50. }
  51. def merge(self, other: 'AccuracyMetric'):
  52. self.preds.extend(other.preds)
  53. self.labels.extend(other.labels)
  54. def __getstate__(self):
  55. return self.preds, self.labels
  56. def __setstate__(self, state):
  57. self.__init__()
  58. self.preds, self.labels = state