rec_metric.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from rapidfuzz.distance import Levenshtein
  15. from difflib import SequenceMatcher
  16. import numpy as np
  17. import string
  18. from .bleu import compute_bleu_score, compute_edit_distance
  19. class RecMetric(object):
  20. def __init__(
  21. self, main_indicator="acc", is_filter=False, ignore_space=True, **kwargs
  22. ):
  23. self.main_indicator = main_indicator
  24. self.is_filter = is_filter
  25. self.ignore_space = ignore_space
  26. self.eps = 1e-5
  27. self.reset()
  28. def _normalize_text(self, text):
  29. text = "".join(
  30. filter(lambda x: x in (string.digits + string.ascii_letters), text)
  31. )
  32. return text.lower()
  33. def __call__(self, pred_label, *args, **kwargs):
  34. preds, labels = pred_label
  35. correct_num = 0
  36. all_num = 0
  37. norm_edit_dis = 0.0
  38. for (pred, pred_conf), (target, _) in zip(preds, labels):
  39. if self.ignore_space:
  40. pred = pred.replace(" ", "")
  41. target = target.replace(" ", "")
  42. if self.is_filter:
  43. pred = self._normalize_text(pred)
  44. target = self._normalize_text(target)
  45. norm_edit_dis += Levenshtein.normalized_distance(pred, target)
  46. if pred == target:
  47. correct_num += 1
  48. all_num += 1
  49. self.correct_num += correct_num
  50. self.all_num += all_num
  51. self.norm_edit_dis += norm_edit_dis
  52. return {
  53. "acc": correct_num / (all_num + self.eps),
  54. "norm_edit_dis": 1 - norm_edit_dis / (all_num + self.eps),
  55. }
  56. def get_metric(self):
  57. """
  58. return metrics {
  59. 'acc': 0,
  60. 'norm_edit_dis': 0,
  61. }
  62. """
  63. acc = 1.0 * self.correct_num / (self.all_num + self.eps)
  64. norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + self.eps)
  65. self.reset()
  66. return {"acc": acc, "norm_edit_dis": norm_edit_dis}
  67. def reset(self):
  68. self.correct_num = 0
  69. self.all_num = 0
  70. self.norm_edit_dis = 0
  71. class CNTMetric(object):
  72. def __init__(self, main_indicator="acc", **kwargs):
  73. self.main_indicator = main_indicator
  74. self.eps = 1e-5
  75. self.reset()
  76. def __call__(self, pred_label, *args, **kwargs):
  77. preds, labels = pred_label
  78. correct_num = 0
  79. all_num = 0
  80. for pred, target in zip(preds, labels):
  81. if pred == target:
  82. correct_num += 1
  83. all_num += 1
  84. self.correct_num += correct_num
  85. self.all_num += all_num
  86. return {
  87. "acc": correct_num / (all_num + self.eps),
  88. }
  89. def get_metric(self):
  90. """
  91. return metrics {
  92. 'acc': 0,
  93. }
  94. """
  95. acc = 1.0 * self.correct_num / (self.all_num + self.eps)
  96. self.reset()
  97. return {"acc": acc}
  98. def reset(self):
  99. self.correct_num = 0
  100. self.all_num = 0
  101. class CANMetric(object):
  102. def __init__(self, main_indicator="exp_rate", **kwargs):
  103. self.main_indicator = main_indicator
  104. self.word_right = []
  105. self.exp_right = []
  106. self.word_total_length = 0
  107. self.exp_total_num = 0
  108. self.word_rate = 0
  109. self.exp_rate = 0
  110. self.reset()
  111. self.epoch_reset()
  112. def __call__(self, preds, batch, **kwargs):
  113. for k, v in kwargs.items():
  114. epoch_reset = v
  115. if epoch_reset:
  116. self.epoch_reset()
  117. word_probs = preds
  118. word_label, word_label_mask = batch
  119. line_right = 0
  120. if word_probs is not None:
  121. word_pred = word_probs.argmax(2)
  122. word_pred = word_pred.cpu().detach().numpy()
  123. word_scores = [
  124. SequenceMatcher(
  125. None, s1[: int(np.sum(s3))], s2[: int(np.sum(s3))], autojunk=False
  126. ).ratio()
  127. * (len(s1[: int(np.sum(s3))]) + len(s2[: int(np.sum(s3))]))
  128. / len(s1[: int(np.sum(s3))])
  129. / 2
  130. for s1, s2, s3 in zip(word_label, word_pred, word_label_mask)
  131. ]
  132. batch_size = len(word_scores)
  133. for i in range(batch_size):
  134. if word_scores[i] == 1:
  135. line_right += 1
  136. self.word_rate = np.mean(word_scores) # float
  137. self.exp_rate = line_right / batch_size # float
  138. exp_length, word_length = word_label.shape[:2]
  139. self.word_right.append(self.word_rate * word_length)
  140. self.exp_right.append(self.exp_rate * exp_length)
  141. self.word_total_length = self.word_total_length + word_length
  142. self.exp_total_num = self.exp_total_num + exp_length
  143. def get_metric(self):
  144. """
  145. return {
  146. 'word_rate': 0,
  147. "exp_rate": 0,
  148. }
  149. """
  150. cur_word_rate = sum(self.word_right) / self.word_total_length
  151. cur_exp_rate = sum(self.exp_right) / self.exp_total_num
  152. self.reset()
  153. return {"word_rate": cur_word_rate, "exp_rate": cur_exp_rate}
  154. def reset(self):
  155. self.word_rate = 0
  156. self.exp_rate = 0
  157. def epoch_reset(self):
  158. self.word_right = []
  159. self.exp_right = []
  160. self.word_total_length = 0
  161. self.exp_total_num = 0
  162. class LaTeXOCRMetric(object):
  163. def __init__(self, main_indicator="exp_rate", cal_bleu_score=False, **kwargs):
  164. self.main_indicator = main_indicator
  165. self.cal_bleu_score = cal_bleu_score
  166. self.edit_right = []
  167. self.exp_right = []
  168. self.bleu_right = []
  169. self.e1_right = []
  170. self.e2_right = []
  171. self.e3_right = []
  172. self.exp_total_num = 0
  173. self.edit_dist = 0
  174. self.exp_rate = 0
  175. if self.cal_bleu_score:
  176. self.bleu_score = 0
  177. self.e1 = 0
  178. self.e2 = 0
  179. self.e3 = 0
  180. self.reset()
  181. self.epoch_reset()
  182. def __call__(self, preds, batch, **kwargs):
  183. for k, v in kwargs.items():
  184. epoch_reset = v
  185. if epoch_reset:
  186. self.epoch_reset()
  187. word_pred = preds
  188. word_label = batch
  189. line_right, e1, e2, e3 = 0, 0, 0, 0
  190. bleu_list, lev_dist = [], []
  191. for labels, prediction in zip(word_label, word_pred):
  192. if prediction == labels:
  193. line_right += 1
  194. distance = compute_edit_distance(prediction, labels)
  195. bleu_list.append(compute_bleu_score([prediction], [labels]))
  196. lev_dist.append(Levenshtein.normalized_distance(prediction, labels))
  197. if distance <= 1:
  198. e1 += 1
  199. if distance <= 2:
  200. e2 += 1
  201. if distance <= 3:
  202. e3 += 1
  203. batch_size = len(lev_dist)
  204. self.edit_dist = sum(lev_dist) # float
  205. self.exp_rate = line_right # float
  206. if self.cal_bleu_score:
  207. self.bleu_score = sum(bleu_list)
  208. self.bleu_right.append(self.bleu_score)
  209. self.e1 = e1
  210. self.e2 = e2
  211. self.e3 = e3
  212. exp_length = len(word_label)
  213. self.edit_right.append(self.edit_dist)
  214. self.exp_right.append(self.exp_rate)
  215. self.e1_right.append(self.e1)
  216. self.e2_right.append(self.e2)
  217. self.e3_right.append(self.e3)
  218. self.exp_total_num = self.exp_total_num + exp_length
  219. def get_metric(self):
  220. """
  221. return {
  222. 'edit distance': 0,
  223. "bleu_score": 0,
  224. "exp_rate": 0,
  225. }
  226. """
  227. cur_edit_distance = sum(self.edit_right) / self.exp_total_num
  228. cur_exp_rate = sum(self.exp_right) / self.exp_total_num
  229. if self.cal_bleu_score:
  230. cur_bleu_score = sum(self.bleu_right) / self.exp_total_num
  231. cur_exp_1 = sum(self.e1_right) / self.exp_total_num
  232. cur_exp_2 = sum(self.e2_right) / self.exp_total_num
  233. cur_exp_3 = sum(self.e3_right) / self.exp_total_num
  234. self.reset()
  235. if self.cal_bleu_score:
  236. return {
  237. "bleu_score": cur_bleu_score,
  238. "edit distance": cur_edit_distance,
  239. "exp_rate": cur_exp_rate,
  240. "exp_rate<=1 ": cur_exp_1,
  241. "exp_rate<=2 ": cur_exp_2,
  242. "exp_rate<=3 ": cur_exp_3,
  243. }
  244. else:
  245. return {
  246. "edit distance": cur_edit_distance,
  247. "exp_rate": cur_exp_rate,
  248. "exp_rate<=1 ": cur_exp_1,
  249. "exp_rate<=2 ": cur_exp_2,
  250. "exp_rate<=3 ": cur_exp_3,
  251. }
  252. def reset(self):
  253. self.edit_dist = 0
  254. self.exp_rate = 0
  255. if self.cal_bleu_score:
  256. self.bleu_score = 0
  257. self.e1 = 0
  258. self.e2 = 0
  259. self.e3 = 0
  260. def epoch_reset(self):
  261. self.edit_right = []
  262. self.exp_right = []
  263. if self.cal_bleu_score:
  264. self.bleu_right = []
  265. self.e1_right = []
  266. self.e2_right = []
  267. self.e3_right = []
  268. self.editdistance_total_length = 0
  269. self.exp_total_num = 0