token_classification_metric.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import importlib
  3. from typing import Dict, List, Optional, Union
  4. import numpy as np
  5. from modelscope.outputs import OutputKeys
  6. from ..metainfo import Metrics
  7. from ..utils.registry import default_group
  8. from ..utils.tensor_utils import torch_nested_detach, torch_nested_numpify
  9. from .base import Metric
  10. from .builder import METRICS, MetricKeys
  11. @METRICS.register_module(
  12. group_key=default_group, module_name=Metrics.token_cls_metric)
  13. class TokenClassificationMetric(Metric):
  14. """The metric computation class for token-classification task.
  15. This metric class uses seqeval to calculate the scores.
  16. Args:
  17. label_name(str, `optional`): The key of label column in the 'inputs' arg.
  18. logit_name(str, `optional`): The key of logits column in the 'inputs' arg.
  19. return_entity_level_metrics (bool, `optional`):
  20. Whether to return every label's detail metrics, default False.
  21. label2id(dict, `optional`): The label2id information to get the token labels.
  22. """
  23. def __init__(self,
  24. label_name=OutputKeys.LABELS,
  25. logit_name=OutputKeys.LOGITS,
  26. return_entity_level_metrics=False,
  27. label2id=None,
  28. *args,
  29. **kwargs):
  30. super().__init__(*args, **kwargs)
  31. self.return_entity_level_metrics = return_entity_level_metrics
  32. self.preds = []
  33. self.labels = []
  34. self.label2id = label2id
  35. self.label_name = label_name
  36. self.logit_name = logit_name
  37. def add(self, outputs: Dict, inputs: Dict):
  38. ground_truths = inputs[self.label_name]
  39. eval_results = outputs[self.logit_name]
  40. self.preds.append(
  41. torch_nested_numpify(torch_nested_detach(eval_results)))
  42. self.labels.append(
  43. torch_nested_numpify(torch_nested_detach(ground_truths)))
  44. def evaluate(self):
  45. label2id = self.label2id
  46. if label2id is None:
  47. assert hasattr(self, 'trainer')
  48. label2id = self.trainer.label2id
  49. self.id2label = {id: label for label, id in label2id.items()}
  50. self.preds = np.concatenate(self.preds, axis=0)
  51. self.labels = np.concatenate(self.labels, axis=0)
  52. predictions = np.argmax(self.preds, axis=-1)
  53. true_predictions = [[
  54. self.id2label[p] for (p, lb) in zip(prediction, label)
  55. if lb != -100
  56. ] for prediction, label in zip(predictions, self.labels)]
  57. true_labels = [[
  58. self.id2label[lb] for (p, lb) in zip(prediction, label)
  59. if lb != -100
  60. ] for prediction, label in zip(predictions, self.labels)]
  61. results = self._compute(
  62. predictions=true_predictions, references=true_labels)
  63. if self.return_entity_level_metrics:
  64. final_results = {}
  65. for key, value in results.items():
  66. if isinstance(value, dict):
  67. for n, v in value.items():
  68. final_results[f'{key}_{n}'] = v
  69. else:
  70. final_results[key] = value
  71. return final_results
  72. else:
  73. return {
  74. MetricKeys.PRECISION: results[MetricKeys.PRECISION],
  75. MetricKeys.RECALL: results[MetricKeys.RECALL],
  76. MetricKeys.F1: results[MetricKeys.F1],
  77. MetricKeys.ACCURACY: results[MetricKeys.ACCURACY],
  78. }
  79. def merge(self, other: 'TokenClassificationMetric'):
  80. self.preds.extend(other.preds)
  81. self.labels.extend(other.labels)
  82. def __getstate__(self):
  83. return (self.return_entity_level_metrics, self.preds, self.labels,
  84. self.label2id, self.label_name, self.logit_name)
  85. def __setstate__(self, state):
  86. self.__init__()
  87. (self.return_entity_level_metrics, self.preds, self.labels,
  88. self.label2id, self.label_name, self.logit_name) = state
  89. @staticmethod
  90. def _compute(
  91. predictions,
  92. references,
  93. suffix: bool = False,
  94. scheme: Optional[str] = None,
  95. mode: Optional[str] = None,
  96. sample_weight: Optional[List[int]] = None,
  97. zero_division: Union[str, int] = 'warn',
  98. ):
  99. from seqeval.metrics import accuracy_score, classification_report
  100. if scheme is not None:
  101. try:
  102. scheme_module = importlib.import_module('seqeval.scheme')
  103. scheme = getattr(scheme_module, scheme)
  104. except AttributeError:
  105. raise ValueError(
  106. f'Scheme should be one of [IOB1, IOB2, IOE1, IOE2, IOBES, BILOU], got {scheme}'
  107. )
  108. report = classification_report(
  109. y_true=references,
  110. y_pred=predictions,
  111. suffix=suffix,
  112. output_dict=True,
  113. scheme=scheme,
  114. mode=mode,
  115. sample_weight=sample_weight,
  116. zero_division=zero_division,
  117. )
  118. report.pop('macro avg')
  119. report.pop('weighted avg')
  120. overall_score = report.pop('micro avg')
  121. scores = {
  122. type_name: {
  123. MetricKeys.PRECISION: score['precision'],
  124. MetricKeys.RECALL: score['recall'],
  125. MetricKeys.F1: score['f1-score'],
  126. 'number': score['support'],
  127. }
  128. for type_name, score in report.items()
  129. }
  130. scores[MetricKeys.PRECISION] = overall_score['precision']
  131. scores[MetricKeys.RECALL] = overall_score['recall']
  132. scores[MetricKeys.F1] = overall_score['f1-score']
  133. scores[MetricKeys.ACCURACY] = accuracy_score(
  134. y_true=references, y_pred=predictions)
  135. return scores