bleu.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py
  17. """
  18. import re
  19. import math
  20. import collections
  21. from functools import lru_cache
  22. def _get_ngrams(segment, max_order):
  23. """Extracts all n-grams upto a given maximum order from an input segment.
  24. Args:
  25. segment: text segment from which n-grams will be extracted.
  26. max_order: maximum length in tokens of the n-grams returned by this
  27. methods.
  28. Returns:
  29. The Counter containing all n-grams upto max_order in segment
  30. with a count of how many times each n-gram occurred.
  31. """
  32. ngram_counts = collections.Counter()
  33. for order in range(1, max_order + 1):
  34. for i in range(0, len(segment) - order + 1):
  35. ngram = tuple(segment[i : i + order])
  36. ngram_counts[ngram] += 1
  37. return ngram_counts
  38. def compute_bleu(reference_corpus, translation_corpus, max_order=4, smooth=False):
  39. """Computes BLEU score of translated segments against one or more references.
  40. Args:
  41. reference_corpus: list of lists of references for each translation. Each
  42. reference should be tokenized into a list of tokens.
  43. translation_corpus: list of translations to score. Each translation
  44. should be tokenized into a list of tokens.
  45. max_order: Maximum n-gram order to use when computing BLEU score.
  46. smooth: Whether or not to apply Lin et al. 2004 smoothing.
  47. Returns:
  48. 3-Tuple with the BLEU score, n-gram precisions, geometric mean of n-gram
  49. precisions and brevity penalty.
  50. """
  51. matches_by_order = [0] * max_order
  52. possible_matches_by_order = [0] * max_order
  53. reference_length = 0
  54. translation_length = 0
  55. for references, translation in zip(reference_corpus, translation_corpus):
  56. reference_length += min(len(r) for r in references)
  57. translation_length += len(translation)
  58. merged_ref_ngram_counts = collections.Counter()
  59. for reference in references:
  60. merged_ref_ngram_counts |= _get_ngrams(reference, max_order)
  61. translation_ngram_counts = _get_ngrams(translation, max_order)
  62. overlap = translation_ngram_counts & merged_ref_ngram_counts
  63. for ngram in overlap:
  64. matches_by_order[len(ngram) - 1] += overlap[ngram]
  65. for order in range(1, max_order + 1):
  66. possible_matches = len(translation) - order + 1
  67. if possible_matches > 0:
  68. possible_matches_by_order[order - 1] += possible_matches
  69. precisions = [0] * max_order
  70. for i in range(0, max_order):
  71. if smooth:
  72. precisions[i] = (matches_by_order[i] + 1.0) / (
  73. possible_matches_by_order[i] + 1.0
  74. )
  75. else:
  76. if possible_matches_by_order[i] > 0:
  77. precisions[i] = (
  78. float(matches_by_order[i]) / possible_matches_by_order[i]
  79. )
  80. else:
  81. precisions[i] = 0.0
  82. if min(precisions) > 0:
  83. p_log_sum = sum((1.0 / max_order) * math.log(p) for p in precisions)
  84. geo_mean = math.exp(p_log_sum)
  85. else:
  86. geo_mean = 0
  87. if float(translation_length) == 0 or float(reference_length) == 0:
  88. ratio = 1e-5
  89. else:
  90. ratio = float(translation_length) / reference_length
  91. if ratio > 1.0:
  92. bp = 1.0
  93. else:
  94. bp = math.exp(1 - 1.0 / ratio)
  95. bleu = geo_mean * bp
  96. return (bleu, precisions, bp, ratio, translation_length, reference_length)
  97. class BaseTokenizer:
  98. """A base dummy tokenizer to derive from."""
  99. def signature(self):
  100. """
  101. Returns a signature for the tokenizer.
  102. :return: signature string
  103. """
  104. return "none"
  105. def __call__(self, line):
  106. """
  107. Tokenizes an input line with the tokenizer.
  108. :param line: a segment to tokenize
  109. :return: the tokenized line
  110. """
  111. return line
  112. class TokenizerRegexp(BaseTokenizer):
  113. def signature(self):
  114. return "re"
  115. def __init__(self):
  116. self._re = [
  117. # language-dependent part (assuming Western languages)
  118. (re.compile(r"([\{-\~\[-\` -\&\(-\+\:-\@\/])"), r" \1 "),
  119. # tokenize period and comma unless preceded by a digit
  120. (re.compile(r"([^0-9])([\.,])"), r"\1 \2 "),
  121. # tokenize period and comma unless followed by a digit
  122. (re.compile(r"([\.,])([^0-9])"), r" \1 \2"),
  123. # tokenize dash when preceded by a digit
  124. (re.compile(r"([0-9])(-)"), r"\1 \2 "),
  125. # one space only between words
  126. # NOTE: Doing this in Python (below) is faster
  127. # (re.compile(r'\s+'), r' '),
  128. ]
  129. @lru_cache(maxsize=2**16)
  130. def __call__(self, line):
  131. """Common post-processing tokenizer for `13a` and `zh` tokenizers.
  132. :param line: a segment to tokenize
  133. :return: the tokenized line
  134. """
  135. for _re, repl in self._re:
  136. line = _re.sub(repl, line)
  137. # no leading or trailing spaces, single space within words
  138. # return ' '.join(line.split())
  139. # This line is changed with regards to the original tokenizer (seen above) to return individual words
  140. return line.split()
  141. class Tokenizer13a(BaseTokenizer):
  142. def signature(self):
  143. return "13a"
  144. def __init__(self):
  145. self._post_tokenizer = TokenizerRegexp()
  146. @lru_cache(maxsize=2**16)
  147. def __call__(self, line):
  148. """Tokenizes an input line using a relatively minimal tokenization
  149. that is however equivalent to mteval-v13a, used by WMT.
  150. :param line: a segment to tokenize
  151. :return: the tokenized line
  152. """
  153. # language-independent part:
  154. line = line.replace("<skipped>", "")
  155. line = line.replace("-\n", "")
  156. line = line.replace("\n", " ")
  157. if "&" in line:
  158. line = line.replace("&quot;", '"')
  159. line = line.replace("&amp;", "&")
  160. line = line.replace("&lt;", "<")
  161. line = line.replace("&gt;", ">")
  162. return self._post_tokenizer(f" {line} ")
  163. def compute_bleu_score(
  164. predictions, references, tokenizer=Tokenizer13a(), max_order=4, smooth=False
  165. ):
  166. # if only one reference is provided make sure we still use list of lists
  167. if isinstance(references[0], str):
  168. references = [[ref] for ref in references]
  169. references = [[tokenizer(r) for r in ref] for ref in references]
  170. predictions = [tokenizer(p) for p in predictions]
  171. score = compute_bleu(
  172. reference_corpus=references,
  173. translation_corpus=predictions,
  174. max_order=max_order,
  175. smooth=smooth,
  176. )
  177. (bleu, precisions, bp, ratio, translation_length, reference_length) = score
  178. return bleu
  179. def cal_distance(word1, word2):
  180. m = len(word1)
  181. n = len(word2)
  182. if m * n == 0:
  183. return m + n
  184. dp = [[0] * (n + 1) for _ in range(m + 1)]
  185. for i in range(m + 1):
  186. dp[i][0] = i
  187. for j in range(n + 1):
  188. dp[0][j] = j
  189. for i in range(1, m + 1):
  190. for j in range(1, n + 1):
  191. a = dp[i - 1][j] + 1
  192. b = dp[i][j - 1] + 1
  193. c = dp[i - 1][j - 1]
  194. if word1[i - 1] != word2[j - 1]:
  195. c += 1
  196. dp[i][j] = min(a, b, c)
  197. return dp[m][n]
  198. def compute_edit_distance(prediction, label):
  199. prediction = prediction.strip().split(" ")
  200. label = label.strip().split(" ")
  201. distance = cal_distance(prediction, label)
  202. return distance