bleu_metric.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. from itertools import zip_longest
  2. from typing import Dict
  3. import sacrebleu
  4. from modelscope.metainfo import Metrics
  5. from modelscope.utils.registry import default_group
  6. from .base import Metric
  7. from .builder import METRICS, MetricKeys
# Presumably the maximum n-gram order for BLEU-4 (TODO confirm). Note that
# BleuMetric below does not read this constant; it may be imported by other
# code, so verify before removing it.
EVAL_BLEU_ORDER: int = 4
  9. @METRICS.register_module(group_key=default_group, module_name=Metrics.BLEU)
  10. class BleuMetric(Metric):
  11. """The metric computation bleu for text generation classes.
  12. This metric class calculates accuracy for the whole input batches.
  13. """
  14. def __init__(self, *args, **kwargs):
  15. super().__init__(*args, **kwargs)
  16. self.eval_tokenized_bleu = kwargs.get('eval_tokenized_bleu', False)
  17. self.hyp_name = kwargs.get('hyp_name', 'hyp')
  18. self.ref_name = kwargs.get('ref_name', 'ref')
  19. self.refs = list()
  20. self.hyps = list()
  21. def add(self, outputs: Dict, inputs: Dict):
  22. self.refs.extend(inputs[self.ref_name])
  23. self.hyps.extend(outputs[self.hyp_name])
  24. def evaluate(self):
  25. if self.eval_tokenized_bleu:
  26. bleu = sacrebleu.corpus_bleu(
  27. self.hyps, list(zip_longest(*self.refs)), tokenize='none')
  28. else:
  29. bleu = sacrebleu.corpus_bleu(self.hyps,
  30. list(zip_longest(*self.refs)))
  31. return {
  32. MetricKeys.BLEU_4: bleu.score,
  33. }
  34. def merge(self, other: 'BleuMetric'):
  35. self.refs.extend(other.refs)
  36. self.hyps.extend(other.hyps)
  37. def __getstate__(self):
  38. return self.eval_tokenized_bleu, self.hyp_name, self.ref_name, self.refs, self.hyps
  39. def __setstate__(self, state):
  40. self.eval_tokenized_bleu, self.hyp_name, self.ref_name, self.refs, self.hyps = state