# Copyright (c) Alibaba, Inc. and its affiliates.

import sys
from contextlib import contextmanager
from typing import Dict, Iterable, List, Tuple

from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu
from rouge import Rouge

from modelscope.metainfo import Metrics
from modelscope.metrics.base import Metric
from modelscope.metrics.builder import METRICS, MetricKeys
from modelscope.utils.chinese_utils import rebuild_chinese_str
from modelscope.utils.registry import default_group


@METRICS.register_module(
    group_key=default_group, module_name=Metrics.text_gen_metric)
class TextGenerationMetric(Metric):
    """The metric computation class for text generation classes.

    This metric class calculates F1 of the rouge scores (ROUGE-1, ROUGE-L)
    and the corpus-level BLEU-1 / BLEU-4 scores for the whole evaluation
    dataset.

    Args:
        target_text: The key of the target text column in the `inputs` arg.
        pred_text: The key of the predicted text column in the `outputs` arg.
    """

    def __init__(self, target_text='tgts', pred_text='preds'):
        # Accumulated predictions / references over all `add` calls.
        self.preds: List[str] = []
        self.tgts: List[str] = []
        self.rouge = Rouge()
        self.target_text = target_text
        self.pred_text = pred_text

    def add(self, outputs: Dict[str, List[str]], inputs: Dict[str,
                                                              List[str]]):
        """Accumulate one batch of predictions and ground truths.

        Args:
            outputs: Model outputs; predictions are read from
                `outputs[self.pred_text]`.
            inputs: Model inputs; references are read from
                `inputs[self.target_text]`.
        """
        ground_truths = inputs[self.target_text]
        eval_results = outputs[self.pred_text]
        # Every string goes through `rebuild_chinese_str` before being stored.
        self.tgts.extend(rebuild_chinese_str(truth) for truth in ground_truths)
        self.preds.extend(
            rebuild_chinese_str(result) for result in eval_results)

    def _check(self, pred: str, tgt: str) -> bool:
        """Return True when both strings still have content once spaces and
        periods are removed, i.e. the pair can be scored by ROUGE."""

        def remove_useless(string: str) -> str:
            return string.replace(' ', '').replace('.', '')

        return len(remove_useless(pred)) != 0 and len(
            remove_useless(tgt)) != 0

    def evaluate(self):
        """Compute ROUGE-1/ROUGE-L F1 and BLEU-1/BLEU-4 over accumulated data.

        Pairs rejected by `_check` are excluded from ROUGE scoring but still
        counted in the denominator of the mean, so they contribute a score
        of 0. BLEU is computed on all accumulated pairs.

        Returns:
            A dict keyed by `MetricKeys.ROUGE_1`, `MetricKeys.ROUGE_L`,
            `MetricKeys.BLEU_1` and `MetricKeys.BLEU_4`.

        Raises:
            AssertionError: If no predictions have been added.
        """
        assert self.preds, 'preds in TextGenerationMetric must not be empty!'
        valid_pairs = [(pred, tgt)
                       for pred, tgt in zip(self.preds, self.tgts)
                       if self._check(pred, tgt)]
        total = len(self.preds)

        def mean(values: Iterable[float]) -> float:
            # Divide by the number of *all* predictions: filtered-out pairs
            # are treated as scoring 0.
            return sum(values) / total

        if valid_pairs:
            preds, tgts = zip(*valid_pairs)
            with extend_recursion_limit(preds, tgts):
                rouge_scores = self.rouge.get_scores(hyps=preds, refs=tgts)
        else:
            # Every pair was filtered out; unpacking `zip(*[])` would raise,
            # so report ROUGE as 0 instead of crashing.
            rouge_scores = []
        rouge_1 = mean(score['rouge-1']['f'] for score in rouge_scores)
        rouge_l = mean(score['rouge-l']['f'] for score in rouge_scores)

        pred_list = [each.strip().split(' ') for each in self.preds]
        # corpus_bleu expects a list of reference lists per hypothesis.
        tgt_list = [[each.strip().split(' ')] for each in self.tgts]
        bleu_1 = corpus_bleu(
            tgt_list,
            pred_list,
            weights=(1, 0, 0, 0),
            smoothing_function=SmoothingFunction().method3)
        bleu_4 = corpus_bleu(
            tgt_list,
            pred_list,
            smoothing_function=SmoothingFunction().method3)
        return {
            MetricKeys.ROUGE_1: rouge_1,
            MetricKeys.ROUGE_L: rouge_l,
            MetricKeys.BLEU_1: bleu_1,
            MetricKeys.BLEU_4: bleu_4
        }

    def merge(self, other: 'TextGenerationMetric'):
        """Append the predictions and targets accumulated by `other`."""
        self.preds.extend(other.preds)
        self.tgts.extend(other.tgts)

    def __getstate__(self):
        # Only the accumulated texts are serialized.
        return self.preds, self.tgts

    def __setstate__(self, state):
        # NOTE: `__init__()` is called with its defaults, so custom
        # `target_text` / `pred_text` keys are reset to 'tgts' / 'preds'.
        self.__init__()
        self.preds, self.tgts = state


@contextmanager
def extend_recursion_limit(preds: Tuple[str, ...], tgts: Tuple[str, ...]):
    """Temporarily raise the recursion limit for the enclosed block.

    The limit is raised to `max(len(pred)) * max(len(tgt))` when that exceeds
    the current limit, and the original limit is restored on exit, including
    when the enclosed block raises.

    Args:
        preds: Predicted strings.
        tgts: Target strings.
    """
    origin_limit = sys.getrecursionlimit()
    # `default=0` keeps empty inputs from raising; the limit is then untouched.
    longest_pred = max((len(pred) for pred in preds), default=0)
    longest_tgt = max((len(tgt) for tgt in tgts), default=0)
    new_limit = longest_pred * longest_tgt
    if new_limit > origin_limit:
        sys.setrecursionlimit(new_limit)
    try:
        yield
    finally:
        sys.setrecursionlimit(origin_limit)