sequence_classification_metric.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Dict
  3. import numpy as np
  4. from sklearn.metrics import accuracy_score, f1_score
  5. from modelscope.metainfo import Metrics
  6. from modelscope.outputs import OutputKeys
  7. from modelscope.utils.registry import default_group
  8. from modelscope.utils.tensor_utils import (torch_nested_detach,
  9. torch_nested_numpify)
  10. from .base import Metric
  11. from .builder import METRICS, MetricKeys
  12. @METRICS.register_module(
  13. group_key=default_group, module_name=Metrics.seq_cls_metric)
  14. class SequenceClassificationMetric(Metric):
  15. """The metric computation class for sequence classification tasks.
  16. This metric class calculates accuracy/F1 of all the input batches.
  17. Args:
  18. label_name: The key of label column in the 'inputs' arg.
  19. logit_name: The key of logits column in the 'inputs' arg.
  20. """
  21. def __init__(self,
  22. label_name=OutputKeys.LABELS,
  23. logit_name=OutputKeys.LOGITS,
  24. *args,
  25. **kwargs):
  26. super().__init__(*args, **kwargs)
  27. self.preds = []
  28. self.labels = []
  29. self.label_name = label_name
  30. self.logit_name = logit_name
  31. def add(self, outputs: Dict, inputs: Dict):
  32. ground_truths = inputs[self.label_name]
  33. eval_results = outputs[self.logit_name]
  34. self.preds.append(
  35. torch_nested_numpify(torch_nested_detach(eval_results)))
  36. self.labels.append(
  37. torch_nested_numpify(torch_nested_detach(ground_truths)))
  38. def evaluate(self):
  39. preds = np.concatenate(self.preds, axis=0)
  40. labels = np.concatenate(self.labels, axis=0)
  41. assert len(preds.shape) == 2, 'Only support predictions with shape: (batch_size, num_labels),' \
  42. 'multi-label classification is not supported in this metric class.'
  43. preds_max = np.argmax(preds, axis=1)
  44. if preds.shape[1] > 2:
  45. metrics = {
  46. MetricKeys.ACCURACY: accuracy_score(labels, preds_max),
  47. MetricKeys.Micro_F1:
  48. f1_score(labels, preds_max, average='micro'),
  49. MetricKeys.Macro_F1:
  50. f1_score(labels, preds_max, average='macro'),
  51. }
  52. metrics[MetricKeys.F1] = metrics[MetricKeys.Micro_F1]
  53. return metrics
  54. else:
  55. metrics = {
  56. MetricKeys.ACCURACY:
  57. accuracy_score(labels, preds_max),
  58. MetricKeys.Binary_F1:
  59. f1_score(labels, preds_max, average='binary'),
  60. }
  61. metrics[MetricKeys.F1] = metrics[MetricKeys.Binary_F1]
  62. return metrics
  63. def merge(self, other: 'SequenceClassificationMetric'):
  64. self.preds.extend(other.preds)
  65. self.labels.extend(other.labels)
  66. def __getstate__(self):
  67. return self.preds, self.labels, self.label_name, self.logit_name
  68. def __setstate__(self, state):
  69. self.__init__()
  70. self.preds, self.labels, self.label_name, self.logit_name = state