ocr_recognition_metric.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. from typing import Dict
  2. import edit_distance as ed
  3. import numpy as np
  4. import torch
  5. import torch.nn.functional as F
  6. from modelscope.metainfo import Metrics
  7. from modelscope.utils.registry import default_group
  8. from .base import Metric
  9. from .builder import METRICS, MetricKeys
  10. def cal_distance(label_list, pre_list):
  11. y = ed.SequenceMatcher(a=label_list, b=pre_list)
  12. yy = y.get_opcodes()
  13. insert = 0
  14. delete = 0
  15. replace = 0
  16. for item in yy:
  17. if item[0] == 'insert':
  18. insert += item[-1] - item[-2]
  19. if item[0] == 'delete':
  20. delete += item[2] - item[1]
  21. if item[0] == 'replace':
  22. replace += item[-1] - item[-2]
  23. distance = insert + delete + replace
  24. return distance, (delete, replace, insert)
  25. @METRICS.register_module(
  26. group_key=default_group, module_name=Metrics.ocr_recognition_metric)
  27. class OCRRecognitionMetric(Metric):
  28. """The metric computation class for ocr recognition.
  29. """
  30. def __init__(self, *args, **kwargs):
  31. self.preds = []
  32. self.targets = []
  33. self.loss_sum = 0.
  34. self.nsample = 0
  35. self.iter_sum = 0
  36. def add(self, outputs: Dict, inputs: Dict):
  37. pred = outputs['preds']
  38. loss = outputs['loss']
  39. target = inputs['labels']
  40. self.preds.extend(pred)
  41. self.targets.extend(target)
  42. self.loss_sum += loss.data.cpu().numpy()
  43. self.nsample += len(pred)
  44. self.iter_sum += 1
  45. def evaluate(self):
  46. total_chars = 0
  47. total_distance = 0
  48. total_fullmatch = 0
  49. for (pred, target) in zip(self.preds, self.targets):
  50. distance, _ = cal_distance(target, pred)
  51. total_chars += len(target)
  52. total_distance += distance
  53. total_fullmatch += (target == pred)
  54. accuracy = float(total_fullmatch) / self.nsample
  55. AR = 1 - float(total_distance) / total_chars
  56. average_loss = self.loss_sum / self.iter_sum if self.iter_sum > 0 else 0
  57. return {
  58. MetricKeys.ACCURACY: accuracy,
  59. MetricKeys.AR: AR,
  60. MetricKeys.AVERAGE_LOSS: average_loss
  61. }
  62. def merge(self, other: 'OCRRecognitionMetric'):
  63. pass
  64. def __getstate__(self):
  65. pass
  66. def __setstate__(self, state):
  67. pass