vqa_token_re_metric.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import numpy as np
  18. import paddle
  19. __all__ = ["VQAReTokenMetric"]
  20. class VQAReTokenMetric(object):
  21. def __init__(self, main_indicator="hmean", **kwargs):
  22. self.main_indicator = main_indicator
  23. self.reset()
  24. def __call__(self, preds, batch, **kwargs):
  25. pred_relations, relations, entities = preds
  26. self.pred_relations_list.extend(pred_relations)
  27. self.relations_list.extend(relations)
  28. self.entities_list.extend(entities)
  29. def get_metric(self):
  30. gt_relations = []
  31. for b in range(len(self.relations_list)):
  32. rel_sent = []
  33. relation_list = self.relations_list[b]
  34. entitie_list = self.entities_list[b]
  35. head_len = relation_list[0, 0]
  36. if head_len > 0:
  37. entitie_start_list = entitie_list[1 : entitie_list[0, 0] + 1, 0]
  38. entitie_end_list = entitie_list[1 : entitie_list[0, 1] + 1, 1]
  39. entitie_label_list = entitie_list[1 : entitie_list[0, 2] + 1, 2]
  40. for head, tail in zip(
  41. relation_list[1 : head_len + 1, 0],
  42. relation_list[1 : head_len + 1, 1],
  43. ):
  44. rel = {}
  45. rel["head_id"] = head
  46. rel["head"] = (entitie_start_list[head], entitie_end_list[head])
  47. rel["head_type"] = entitie_label_list[head]
  48. rel["tail_id"] = tail
  49. rel["tail"] = (entitie_start_list[tail], entitie_end_list[tail])
  50. rel["tail_type"] = entitie_label_list[tail]
  51. rel["type"] = 1
  52. rel_sent.append(rel)
  53. gt_relations.append(rel_sent)
  54. re_metrics = self.re_score(
  55. self.pred_relations_list, gt_relations, mode="boundaries"
  56. )
  57. metrics = {
  58. "precision": re_metrics["ALL"]["p"],
  59. "recall": re_metrics["ALL"]["r"],
  60. "hmean": re_metrics["ALL"]["f1"],
  61. }
  62. self.reset()
  63. return metrics
  64. def reset(self):
  65. self.pred_relations_list = []
  66. self.relations_list = []
  67. self.entities_list = []
  68. def re_score(self, pred_relations, gt_relations, mode="strict"):
  69. """Evaluate RE predictions
  70. Args:
  71. pred_relations (list) : list of list of predicted relations (several relations in each sentence)
  72. gt_relations (list) : list of list of ground truth relations
  73. rel = { "head": (start_idx (inclusive), end_idx (exclusive)),
  74. "tail": (start_idx (inclusive), end_idx (exclusive)),
  75. "head_type": ent_type,
  76. "tail_type": ent_type,
  77. "type": rel_type}
  78. vocab (Vocab) : dataset vocabulary
  79. mode (str) : in 'strict' or 'boundaries'"""
  80. assert mode in ["strict", "boundaries"]
  81. relation_types = [v for v in [0, 1] if not v == 0]
  82. scores = {rel: {"tp": 0, "fp": 0, "fn": 0} for rel in relation_types + ["ALL"]}
  83. # Count GT relations and Predicted relations
  84. n_sents = len(gt_relations)
  85. n_rels = sum([len([rel for rel in sent]) for sent in gt_relations])
  86. n_found = sum([len([rel for rel in sent]) for sent in pred_relations])
  87. # Count TP, FP and FN per type
  88. for pred_sent, gt_sent in zip(pred_relations, gt_relations):
  89. for rel_type in relation_types:
  90. # strict mode takes argument types into account
  91. if mode == "strict":
  92. pred_rels = {
  93. (rel["head"], rel["head_type"], rel["tail"], rel["tail_type"])
  94. for rel in pred_sent
  95. if rel["type"] == rel_type
  96. }
  97. gt_rels = {
  98. (rel["head"], rel["head_type"], rel["tail"], rel["tail_type"])
  99. for rel in gt_sent
  100. if rel["type"] == rel_type
  101. }
  102. # boundaries mode only takes argument spans into account
  103. elif mode == "boundaries":
  104. pred_rels = {
  105. (rel["head"], rel["tail"])
  106. for rel in pred_sent
  107. if rel["type"] == rel_type
  108. }
  109. gt_rels = {
  110. (rel["head"], rel["tail"])
  111. for rel in gt_sent
  112. if rel["type"] == rel_type
  113. }
  114. scores[rel_type]["tp"] += len(pred_rels & gt_rels)
  115. scores[rel_type]["fp"] += len(pred_rels - gt_rels)
  116. scores[rel_type]["fn"] += len(gt_rels - pred_rels)
  117. # Compute per entity Precision / Recall / F1
  118. for rel_type in scores.keys():
  119. if scores[rel_type]["tp"]:
  120. scores[rel_type]["p"] = scores[rel_type]["tp"] / (
  121. scores[rel_type]["fp"] + scores[rel_type]["tp"]
  122. )
  123. scores[rel_type]["r"] = scores[rel_type]["tp"] / (
  124. scores[rel_type]["fn"] + scores[rel_type]["tp"]
  125. )
  126. else:
  127. scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0
  128. if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0:
  129. scores[rel_type]["f1"] = (
  130. 2
  131. * scores[rel_type]["p"]
  132. * scores[rel_type]["r"]
  133. / (scores[rel_type]["p"] + scores[rel_type]["r"])
  134. )
  135. else:
  136. scores[rel_type]["f1"] = 0
  137. # Compute micro F1 Scores
  138. tp = sum([scores[rel_type]["tp"] for rel_type in relation_types])
  139. fp = sum([scores[rel_type]["fp"] for rel_type in relation_types])
  140. fn = sum([scores[rel_type]["fn"] for rel_type in relation_types])
  141. if tp:
  142. precision = tp / (tp + fp)
  143. recall = tp / (tp + fn)
  144. f1 = 2 * precision * recall / (precision + recall)
  145. else:
  146. precision, recall, f1 = 0, 0, 0
  147. scores["ALL"]["p"] = precision
  148. scores["ALL"]["r"] = recall
  149. scores["ALL"]["f1"] = f1
  150. scores["ALL"]["tp"] = tp
  151. scores["ALL"]["fp"] = fp
  152. scores["ALL"]["fn"] = fn
  153. # Compute Macro F1 Scores
  154. scores["ALL"]["Macro_f1"] = np.mean(
  155. [scores[ent_type]["f1"] for ent_type in relation_types]
  156. )
  157. scores["ALL"]["Macro_p"] = np.mean(
  158. [scores[ent_type]["p"] for ent_type in relation_types]
  159. )
  160. scores["ALL"]["Macro_r"] = np.mean(
  161. [scores[ent_type]["r"] for ent_type in relation_types]
  162. )
  163. return scores