table_metric.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. # Copyright 2020 IBM
  2. # Author: peter.zhong@au1.ibm.com
  3. #
  4. # This is free software; you can redistribute it and/or modify
  5. # it under the terms of the Apache 2.0 License.
  6. #
  7. # This software is distributed in the hope that it will be useful,
  8. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. # Apache 2.0 License for more details.
  11. from rapidfuzz.distance import Levenshtein
  12. from apted import APTED, Config
  13. from apted.helpers import Tree
  14. from collections import deque
  15. from .parallel import parallel_process
  16. from tqdm import tqdm
  17. from paddle.utils import try_import
  18. class TableTree(Tree):
  19. def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
  20. self.tag = tag
  21. self.colspan = colspan
  22. self.rowspan = rowspan
  23. self.content = content
  24. self.children = list(children)
  25. def bracket(self):
  26. """Show tree using brackets notation"""
  27. if self.tag == "td":
  28. result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % (
  29. self.tag,
  30. self.colspan,
  31. self.rowspan,
  32. self.content,
  33. )
  34. else:
  35. result = '"tag": %s' % self.tag
  36. for child in self.children:
  37. result += child.bracket()
  38. return "{{{}}}".format(result)
  39. class CustomConfig(Config):
  40. def rename(self, node1, node2):
  41. """Compares attributes of trees"""
  42. # print(node1.tag)
  43. if (
  44. (node1.tag != node2.tag)
  45. or (node1.colspan != node2.colspan)
  46. or (node1.rowspan != node2.rowspan)
  47. ):
  48. return 1.0
  49. if node1.tag == "td":
  50. if node1.content or node2.content:
  51. # print(node1.content, )
  52. return Levenshtein.normalized_distance(node1.content, node2.content)
  53. return 0.0
  54. class CustomConfig_del_short(Config):
  55. def rename(self, node1, node2):
  56. """Compares attributes of trees"""
  57. if (
  58. (node1.tag != node2.tag)
  59. or (node1.colspan != node2.colspan)
  60. or (node1.rowspan != node2.rowspan)
  61. ):
  62. return 1.0
  63. if node1.tag == "td":
  64. if node1.content or node2.content:
  65. # print('before')
  66. # print(node1.content, node2.content)
  67. # print('after')
  68. node1_content = node1.content
  69. node2_content = node2.content
  70. if len(node1_content) < 3:
  71. node1_content = ["####"]
  72. if len(node2_content) < 3:
  73. node2_content = ["####"]
  74. return Levenshtein.normalized_distance(node1_content, node2_content)
  75. return 0.0
  76. class CustomConfig_del_block(Config):
  77. def rename(self, node1, node2):
  78. """Compares attributes of trees"""
  79. if (
  80. (node1.tag != node2.tag)
  81. or (node1.colspan != node2.colspan)
  82. or (node1.rowspan != node2.rowspan)
  83. ):
  84. return 1.0
  85. if node1.tag == "td":
  86. if node1.content or node2.content:
  87. node1_content = node1.content
  88. node2_content = node2.content
  89. while " " in node1_content:
  90. print(node1_content.index(" "))
  91. node1_content.pop(node1_content.index(" "))
  92. while " " in node2_content:
  93. print(node2_content.index(" "))
  94. node2_content.pop(node2_content.index(" "))
  95. return Levenshtein.normalized_distance(node1_content, node2_content)
  96. return 0.0
  97. class TEDS(object):
  98. """Tree Edit Distance basead Similarity"""
  99. def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
  100. assert isinstance(n_jobs, int) and (
  101. n_jobs >= 1
  102. ), "n_jobs must be an integer greater than 1"
  103. self.structure_only = structure_only
  104. self.n_jobs = n_jobs
  105. self.ignore_nodes = ignore_nodes
  106. self.__tokens__ = []
  107. def tokenize(self, node):
  108. """Tokenizes table cells"""
  109. self.__tokens__.append("<%s>" % node.tag)
  110. if node.text is not None:
  111. self.__tokens__ += list(node.text)
  112. for n in node.getchildren():
  113. self.tokenize(n)
  114. if node.tag != "unk":
  115. self.__tokens__.append("</%s>" % node.tag)
  116. if node.tag != "td" and node.tail is not None:
  117. self.__tokens__ += list(node.tail)
  118. def load_html_tree(self, node, parent=None):
  119. """Converts HTML tree to the format required by apted"""
  120. global __tokens__
  121. if node.tag == "td":
  122. if self.structure_only:
  123. cell = []
  124. else:
  125. self.__tokens__ = []
  126. self.tokenize(node)
  127. cell = self.__tokens__[1:-1].copy()
  128. new_node = TableTree(
  129. node.tag,
  130. int(node.attrib.get("colspan", "1")),
  131. int(node.attrib.get("rowspan", "1")),
  132. cell,
  133. *deque(),
  134. )
  135. else:
  136. new_node = TableTree(node.tag, None, None, None, *deque())
  137. if parent is not None:
  138. parent.children.append(new_node)
  139. if node.tag != "td":
  140. for n in node.getchildren():
  141. self.load_html_tree(n, new_node)
  142. if parent is None:
  143. return new_node
  144. def evaluate(self, pred, true):
  145. """Computes TEDS score between the prediction and the ground truth of a
  146. given sample
  147. """
  148. try_import("lxml")
  149. from lxml import etree, html
  150. if (not pred) or (not true):
  151. return 0.0
  152. parser = html.HTMLParser(remove_comments=True, encoding="utf-8")
  153. pred = html.fromstring(pred, parser=parser)
  154. true = html.fromstring(true, parser=parser)
  155. if pred.xpath("body/table") and true.xpath("body/table"):
  156. pred = pred.xpath("body/table")[0]
  157. true = true.xpath("body/table")[0]
  158. if self.ignore_nodes:
  159. etree.strip_tags(pred, *self.ignore_nodes)
  160. etree.strip_tags(true, *self.ignore_nodes)
  161. n_nodes_pred = len(pred.xpath(".//*"))
  162. n_nodes_true = len(true.xpath(".//*"))
  163. n_nodes = max(n_nodes_pred, n_nodes_true)
  164. tree_pred = self.load_html_tree(pred)
  165. tree_true = self.load_html_tree(true)
  166. distance = APTED(
  167. tree_pred, tree_true, CustomConfig()
  168. ).compute_edit_distance()
  169. return 1.0 - (float(distance) / n_nodes)
  170. else:
  171. return 0.0
  172. def batch_evaluate(self, pred_json, true_json):
  173. """Computes TEDS score between the prediction and the ground truth of
  174. a batch of samples
  175. @params pred_json: {'FILENAME': 'HTML CODE', ...}
  176. @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
  177. @output: {'FILENAME': 'TEDS SCORE', ...}
  178. """
  179. samples = true_json.keys()
  180. if self.n_jobs == 1:
  181. scores = [
  182. self.evaluate(pred_json.get(filename, ""), true_json[filename]["html"])
  183. for filename in tqdm(samples)
  184. ]
  185. else:
  186. inputs = [
  187. {
  188. "pred": pred_json.get(filename, ""),
  189. "true": true_json[filename]["html"],
  190. }
  191. for filename in samples
  192. ]
  193. scores = parallel_process(
  194. inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1
  195. )
  196. scores = dict(zip(samples, scores))
  197. return scores
  198. def batch_evaluate_html(self, pred_htmls, true_htmls):
  199. """Computes TEDS score between the prediction and the ground truth of
  200. a batch of samples
  201. """
  202. if self.n_jobs == 1:
  203. scores = [
  204. self.evaluate(pred_html, true_html)
  205. for (pred_html, true_html) in zip(pred_htmls, true_htmls)
  206. ]
  207. else:
  208. inputs = [
  209. {"pred": pred_html, "true": true_html}
  210. for (pred_html, true_html) in zip(pred_htmls, true_htmls)
  211. ]
  212. scores = parallel_process(
  213. inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1
  214. )
  215. return scores
  216. if __name__ == "__main__":
  217. import json
  218. import pprint
  219. with open("sample_pred.json") as fp:
  220. pred_json = json.load(fp)
  221. with open("sample_gt.json") as fp:
  222. true_json = json.load(fp)
  223. teds = TEDS(n_jobs=4)
  224. scores = teds.batch_evaluate(pred_json, true_json)
  225. pp = pprint.PrettyPrinter()
  226. pp.pprint(scores)