ned_metric.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Dict
  3. import numpy as np
  4. from modelscope.metainfo import Metrics
  5. from modelscope.outputs import OutputKeys
  6. from modelscope.utils.registry import default_group
  7. from .base import Metric
  8. from .builder import METRICS, MetricKeys
  9. @METRICS.register_module(group_key=default_group, module_name=Metrics.NED)
  10. class NedMetric(Metric):
  11. """The ned metric computation class for classification classes.
  12. This metric class calculates the levenshtein distance between sentences for the whole input batches.
  13. """
  14. def __init__(self, *args, **kwargs):
  15. super().__init__(*args, **kwargs)
  16. self.preds = []
  17. self.labels = []
  18. def add(self, outputs: Dict, inputs: Dict):
  19. label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS
  20. ground_truths = inputs[label_name]
  21. eval_results = outputs[label_name]
  22. for key in [
  23. OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES,
  24. OutputKeys.LABELS, OutputKeys.SCORES
  25. ]:
  26. if key in outputs and outputs[key] is not None:
  27. eval_results = outputs[key]
  28. break
  29. assert type(ground_truths) == type(eval_results)
  30. if isinstance(ground_truths, list):
  31. self.preds.extend(eval_results)
  32. self.labels.extend(ground_truths)
  33. elif isinstance(ground_truths, np.ndarray):
  34. self.preds.extend(eval_results.tolist())
  35. self.labels.extend(ground_truths.tolist())
  36. else:
  37. raise Exception('only support list or np.ndarray')
  38. def evaluate(self):
  39. assert len(self.preds) == len(self.labels)
  40. return {
  41. MetricKeys.NED: (np.asarray([
  42. 1.0 - NedMetric._distance(pred, ref)
  43. for pred, ref in zip(self.preds, self.labels)
  44. ])).mean().item()
  45. }
  46. def merge(self, other: 'NedMetric'):
  47. self.preds.extend(other.preds)
  48. self.labels.extend(other.labels)
  49. def __getstate__(self):
  50. return self.preds, self.labels
  51. def __setstate__(self, state):
  52. self.__init__()
  53. self.preds, self.labels = state
  54. @staticmethod
  55. def _distance(pred, ref):
  56. if pred is None or ref is None:
  57. raise TypeError('Argument (pred or ref) is NoneType.')
  58. if pred == ref:
  59. return 0.0
  60. if len(pred) == 0:
  61. return len(ref)
  62. if len(ref) == 0:
  63. return len(pred)
  64. m_len = max(len(pred), len(ref))
  65. if m_len == 0:
  66. return 0.0
  67. def levenshtein(s0, s1):
  68. v0 = [0] * (len(s1) + 1)
  69. v1 = [0] * (len(s1) + 1)
  70. for i in range(len(v0)):
  71. v0[i] = i
  72. for i in range(len(s0)):
  73. v1[0] = i + 1
  74. for j in range(len(s1)):
  75. cost = 1
  76. if s0[i] == s1[j]:
  77. cost = 0
  78. v1[j + 1] = min(v1[j] + 1, v0[j + 1] + 1, v0[j] + cost)
  79. v0, v1 = v1, v0
  80. return v0[len(s1)]
  81. return levenshtein(pred, ref) / m_len