text_generation_metric.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import sys
  3. from contextlib import contextmanager
  4. from typing import Dict, Iterable, List, Tuple
  5. from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu
  6. from rouge import Rouge
  7. from modelscope.metainfo import Metrics
  8. from modelscope.metrics.base import Metric
  9. from modelscope.metrics.builder import METRICS, MetricKeys
  10. from modelscope.utils.chinese_utils import rebuild_chinese_str
  11. from modelscope.utils.registry import default_group
  12. @METRICS.register_module(
  13. group_key=default_group, module_name=Metrics.text_gen_metric)
  14. class TextGenerationMetric(Metric):
  15. """The metric computation class for text generation classes.
  16. This metric class calculates F1 of the rouge scores for the whole evaluation dataset.
  17. Args:
  18. target_text: The key of the target text column in the `inputs` arg.
  19. pred_text: The key of the predicted text column in the `outputs` arg.
  20. """
  21. def __init__(self, target_text='tgts', pred_text='preds'):
  22. self.preds: List[str] = []
  23. self.tgts: List[str] = []
  24. self.rouge = Rouge()
  25. self.target_text = target_text
  26. self.pred_text = pred_text
  27. def add(self, outputs: Dict[str, List[str]], inputs: Dict[str, List[str]]):
  28. ground_truths = inputs[self.target_text]
  29. eval_results = outputs[self.pred_text]
  30. for truth in ground_truths:
  31. self.tgts.append(rebuild_chinese_str(truth))
  32. for result in eval_results:
  33. self.preds.append(rebuild_chinese_str(result))
  34. def _check(self, pred: str, tgt: str) -> bool:
  35. def remove_useless(string: str) -> str:
  36. return string.replace(' ', '').replace('.', '')
  37. return len(remove_useless(pred)) != 0 and len(remove_useless(tgt)) != 0
  38. def evaluate(self):
  39. assert self.preds, 'preds in TextGenerationMetric must not be empty!'
  40. tmp = [(pred, tgt) for pred, tgt in zip(self.preds, self.tgts)
  41. if self._check(pred, tgt)]
  42. preds, tgts = zip(*tmp)
  43. def mean(iter: Iterable) -> float:
  44. return sum(iter) / len(self.preds)
  45. with extend_recursion_limit(preds, tgts):
  46. rouge_scores = self.rouge.get_scores(hyps=preds, refs=tgts)
  47. rouge_1 = mean(map(lambda score: score['rouge-1']['f'], rouge_scores))
  48. rouge_l = mean(map(lambda score: score['rouge-l']['f'], rouge_scores))
  49. pred_list = [each.strip().split(' ') for each in self.preds]
  50. tgt_list = [[each.strip().split(' ')] for each in self.tgts]
  51. bleu_1 = corpus_bleu(
  52. tgt_list,
  53. pred_list,
  54. weights=(1, 0, 0, 0),
  55. smoothing_function=SmoothingFunction().method3)
  56. bleu_4 = corpus_bleu(
  57. tgt_list,
  58. pred_list,
  59. smoothing_function=SmoothingFunction().method3)
  60. return {
  61. MetricKeys.ROUGE_1: rouge_1,
  62. MetricKeys.ROUGE_L: rouge_l,
  63. MetricKeys.BLEU_1: bleu_1,
  64. MetricKeys.BLEU_4: bleu_4
  65. }
  66. def merge(self, other: 'TextGenerationMetric'):
  67. self.preds.extend(other.preds)
  68. self.tgts.extend(other.tgts)
  69. def __getstate__(self):
  70. return self.preds, self.tgts
  71. def __setstate__(self, state):
  72. self.__init__()
  73. self.preds, self.tgts = state
  74. @contextmanager
  75. def extend_recursion_limit(preds: Tuple[str], tgts: Tuple[str]):
  76. origin_limit = sys.getrecursionlimit()
  77. new_limit = max(len(pred)
  78. for pred in preds) * max(len(tgt) for tgt in tgts)
  79. if new_limit > origin_limit:
  80. sys.setrecursionlimit(new_limit)
  81. yield
  82. sys.setrecursionlimit(origin_limit)