text_ranking_metric.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Dict, List
  3. import numpy as np
  4. from modelscope.metainfo import Metrics
  5. from modelscope.metrics.base import Metric
  6. from modelscope.metrics.builder import METRICS, MetricKeys
  7. from modelscope.utils.registry import default_group
  8. @METRICS.register_module(
  9. group_key=default_group, module_name=Metrics.text_ranking_metric)
  10. class TextRankingMetric(Metric):
  11. """The metric computation class for text ranking classes.
  12. This metric class calculates mrr and ndcg metric for the whole evaluation dataset.
  13. Args:
  14. target_text: The key of the target text column in the `inputs` arg.
  15. pred_text: The key of the predicted text column in the `outputs` arg.
  16. """
  17. def __init__(self, mrr_k: int = 1, ndcg_k: int = 1):
  18. self.labels: List = []
  19. self.qids: List = []
  20. self.logits: List = []
  21. self.mrr_k: int = mrr_k
  22. self.ndcg_k: int = ndcg_k
  23. def add(self, outputs: Dict[str, List], inputs: Dict[str, List]):
  24. self.labels.extend(inputs.pop('labels').detach().cpu().numpy())
  25. self.qids.extend(inputs.pop('qid').detach().cpu().numpy())
  26. logits = outputs['logits'].squeeze(-1).detach().cpu().numpy()
  27. logits = self._sigmoid(logits).tolist()
  28. self.logits.extend(logits)
  29. def evaluate(self):
  30. rank_result = {}
  31. for qid, score, label in zip(self.qids, self.logits, self.labels):
  32. if qid not in rank_result:
  33. rank_result[qid] = []
  34. rank_result[qid].append((score, label))
  35. for qid in rank_result:
  36. rank_result[qid] = sorted(rank_result[qid], key=lambda x: x[0])
  37. return {
  38. MetricKeys.MRR: self._compute_mrr(rank_result),
  39. MetricKeys.NDCG: self._compute_ndcg(rank_result)
  40. }
  41. @staticmethod
  42. def _sigmoid(logits):
  43. return np.exp(logits) / (1 + np.exp(logits))
  44. def _compute_mrr(self, result):
  45. mrr = 0
  46. for res in result.values():
  47. sorted_res = sorted(res, key=lambda x: x[0], reverse=True)
  48. ar = 0
  49. for index, ele in enumerate(sorted_res[:self.mrr_k]):
  50. if str(ele[1]) == '1':
  51. ar = 1.0 / (index + 1)
  52. break
  53. mrr += ar
  54. return mrr / len(result)
  55. def _compute_ndcg(self, result):
  56. ndcg = 0
  57. from sklearn.metrics import ndcg_score
  58. for res in result.values():
  59. sorted_res = sorted(res, key=lambda x: [0], reverse=True)
  60. labels = np.array([[ele[1] for ele in sorted_res]])
  61. scores = np.array([[ele[0] for ele in sorted_res]])
  62. ndcg += float(ndcg_score(labels, scores, k=self.ndcg_k))
  63. return ndcg / len(result)
  64. def merge(self, other: 'TextRankingMetric'):
  65. self.labels.extend(other.labels)
  66. self.qids.extend(other.qids)
  67. self.logits.extend(other.logits)
  68. def __getstate__(self):
  69. return self.labels, self.qids, self.logits, self.mrr_k, self.ndcg_k
  70. def __setstate__(self, state):
  71. self.__init__()
  72. self.labels, self.qids, self.logits, self.mrr_k, self.ndcg_k = state