translation_evaluation_metric.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import importlib
  3. from typing import Dict, List, Union
  4. from pandas import DataFrame
  5. from modelscope.metainfo import Metrics
  6. from modelscope.metrics.base import Metric
  7. from modelscope.metrics.builder import METRICS, MetricKeys
  8. from modelscope.models.nlp.unite.configuration import InputFormat
  9. from modelscope.utils.logger import get_logger
  10. from modelscope.utils.registry import default_group
  11. logger = get_logger()  # module-level logger; the metric class below uses it for info/warning output
  12. @METRICS.register_module(
  13. group_key=default_group, module_name=Metrics.translation_evaluation_metric)
  14. class TranslationEvaluationMetric(Metric):
  15. r"""The metric class for translation evaluation.
  16. """
  17. def __init__(self, gap_threshold: float = 25.0):
  18. r"""Build a translation evaluation metric, following the designed
  19. Kendall's tau correlation from WMT Metrics Shared Task competitions.
  20. Args:
  21. gap_threshold: The score gap denoting the available hypothesis pair.
  22. Returns:
  23. A metric for translation evaluation.
  24. """
  25. self.gap_threshold = gap_threshold
  26. self.lp = list()
  27. self.segment_id = list()
  28. self.raw_score = list()
  29. self.score = list()
  30. self.input_format = list()
  31. def clear(self) -> None:
  32. r"""Clear all the stored variables.
  33. """
  34. self.lp.clear()
  35. self.segment_id.clear()
  36. self.raw_score.clear()
  37. self.input_format.clear()
  38. self.score.clear()
  39. return
  40. def add(self, outputs: Dict[str, List[float]],
  41. inputs: Dict[str, List[Union[float, int]]]) -> None:
  42. r"""Collect the related results for processing.
  43. Args:
  44. outputs: Dict containing 'scores'
  45. inputs: Dict containing 'labels' and 'segment_ids'
  46. """
  47. self.lp += inputs['lp']
  48. self.segment_id += inputs['segment_id']
  49. self.raw_score += inputs['raw_score']
  50. self.input_format += inputs['input_format']
  51. self.score += outputs['score']
  52. return
  53. def evaluate(self) -> Dict[str, Dict[str, float]]:
  54. r"""Compute the Kendall's tau correlation.
  55. Returns:
  56. A dict denoting Kendall's tau correlation.
  57. """
  58. data = {
  59. 'lp': self.lp,
  60. 'segment_id': self.segment_id,
  61. 'raw_score': self.raw_score,
  62. 'input_format': self.input_format,
  63. 'score': self.score
  64. }
  65. data = DataFrame(data=data)
  66. correlation = dict()
  67. for input_format in data.input_format.unique():
  68. logger.info('Evaluation results for %s input format'
  69. % input_format.value)
  70. input_format_data = data[data.input_format == input_format]
  71. temp_correlation = dict()
  72. for lp in sorted(input_format_data.lp.unique()):
  73. sub_data = input_format_data[input_format_data.lp == lp]
  74. temp_correlation[input_format.value + '_'
  75. + lp] = self.compute_kendall_tau(sub_data)
  76. logger.info(
  77. '\t%s: %f' %
  78. (lp,
  79. temp_correlation[input_format.value + '_' + lp] * 100))
  80. avg_correlation = sum(
  81. temp_correlation.values()) / len(temp_correlation)
  82. correlation[input_format.value + '_avg'] = avg_correlation
  83. logger.info('Average evaluation result for %s input format: %f' %
  84. (input_format.value, avg_correlation))
  85. logger.info('')
  86. correlation.update(temp_correlation)
  87. return correlation
  88. def merge(self, other: 'TranslationEvaluationMetric') -> None:
  89. r"""Merge the predictions from other TranslationEvaluationMetric objects.
  90. Args:
  91. other: Another TranslationEvaluationMetric object.
  92. """
  93. self.lp += other.lp
  94. self.segment_id += other.segment_ids
  95. self.raw_score += other.raw_score
  96. self.input_format += other.input_format
  97. self.score += other.score
  98. return
  99. def compute_kendall_tau(self, csv_data: DataFrame) -> float:
  100. r"""Compute kendall's tau correlation.
  101. Args:
  102. csv_data: The pandas dataframe.
  103. Returns:
  104. float: THe kendall's Tau correlation.
  105. """
  106. concor = discor = 0
  107. for segment_id in sorted(csv_data.segment_id.unique()):
  108. group_csv_data = csv_data[csv_data.segment_id == segment_id]
  109. examples = group_csv_data.to_dict('records')
  110. for i in range(0, len(examples)):
  111. for j in range(i + 1, len(examples)):
  112. if self.raw_score[i] - self.raw_score[
  113. j] >= self.gap_threshold:
  114. if self.score[i] > self.score[j]:
  115. concor += 1
  116. elif self.score[i] < self.score[j]:
  117. discor += 1
  118. elif self.raw_score[i] - self.raw_score[
  119. j] <= -self.gap_threshold:
  120. if self.score[i] < self.score[j]:
  121. concor += 1
  122. elif self.score[i] > self.score[j]:
  123. discor += 1
  124. if concor + discor == 0:
  125. logger.warning(
  126. 'We don\'t have available pairs when evaluation. '
  127. 'Marking the kendall tau correlation as the lowest value (-1.0).'
  128. )
  129. return -1.0
  130. else:
  131. return (concor - discor) / (concor + discor)