| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249 |
- # Copyright 2020 IBM
- # Author: peter.zhong@au1.ibm.com
- #
- # This is free software; you can redistribute it and/or modify
- # it under the terms of the Apache 2.0 License.
- #
- # This software is distributed in the hope that it will be useful,
- # but WITHOUT ANY WARRANTY; without even the implied warranty of
- # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- # Apache 2.0 License for more details.
- from rapidfuzz.distance import Levenshtein
- from apted import APTED, Config
- from apted.helpers import Tree
- from collections import deque
- from .parallel import parallel_process
- from tqdm import tqdm
- from paddle.utils import try_import
class TableTree(Tree):
    """Table node in the tree representation handed to APTED.

    Every node stores its HTML ``tag``. ``td`` nodes additionally carry
    their ``colspan``/``rowspan`` values and their cell ``content``.
    """

    def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
        self.children = list(children)
        self.content = content
        self.rowspan = rowspan
        self.colspan = colspan
        self.tag = tag

    def bracket(self):
        """Return the subtree rooted at this node in brackets notation."""
        if self.tag != "td":
            head = '"tag": %s' % self.tag
        else:
            fields = (self.tag, self.colspan, self.rowspan, self.content)
            head = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % fields
        # Children are rendered recursively and appended right after the
        # node's own attributes, all inside one pair of braces.
        nested = "".join(child.bracket() for child in self.children)
        return "{" + head + nested + "}"
class CustomConfig(Config):
    """APTED configuration whose rename cost compares table-node attributes."""

    def rename(self, node1, node2):
        """Return the cost of renaming ``node1`` into ``node2``.

        The cost is 1.0 when the tag, colspan or rowspan differ. For two
        matching ``td`` nodes of which at least one has content, the cost is
        the normalized Levenshtein distance between the contents. In every
        other case the cost is 0.0.
        """
        attributes_match = (
            node1.tag == node2.tag
            and node1.colspan == node2.colspan
            and node1.rowspan == node2.rowspan
        )
        if not attributes_match:
            return 1.0
        if node1.tag == "td" and (node1.content or node2.content):
            return Levenshtein.normalized_distance(node1.content, node2.content)
        return 0.0
class CustomConfig_del_short(Config):
    """APTED configuration that treats very short cell contents as equal."""

    def rename(self, node1, node2):
        """Return the cost of renaming ``node1`` into ``node2``.

        The cost is 1.0 when the tag, colspan or rowspan differ. For matching
        ``td`` nodes with any content, each content shorter than three tokens
        is replaced by the placeholder ``["####"]`` before the normalized
        Levenshtein distance is computed. Otherwise the cost is 0.0.
        """
        attributes_match = (
            node1.tag == node2.tag
            and node1.colspan == node2.colspan
            and node1.rowspan == node2.rowspan
        )
        if not attributes_match:
            return 1.0
        if node1.tag != "td" or not (node1.content or node2.content):
            return 0.0

        # Contents below the length threshold are swapped for one common
        # placeholder, so two short cells always compare as identical.
        left = ["####"] if len(node1.content) < 3 else node1.content
        right = ["####"] if len(node2.content) < 3 else node2.content
        return Levenshtein.normalized_distance(left, right)
class CustomConfig_del_block(Config):
    """APTED configuration that ignores blank tokens in cell contents."""

    def rename(self, node1, node2):
        """Return the cost of renaming ``node1`` into ``node2``.

        The cost is 1.0 when the tag, colspan or rowspan differ. For matching
        ``td`` nodes with any content, every ``" "`` token is dropped from
        both contents before the normalized Levenshtein distance is computed.
        Otherwise the cost is 0.0.

        The nodes themselves are left untouched: the comparison works on
        filtered copies of their contents.
        """
        if (
            (node1.tag != node2.tag)
            or (node1.colspan != node2.colspan)
            or (node1.rowspan != node2.rowspan)
        ):
            return 1.0
        if node1.tag == "td" and (node1.content or node2.content):
            # Build filtered copies instead of popping from ``node.content``
            # in place: the node objects belong to the trees being compared,
            # so mutating them here would silently alter the caller's data.
            node1_content = [token for token in node1.content if token != " "]
            node2_content = [token for token in node2.content if token != " "]
            return Levenshtein.normalized_distance(node1_content, node2_content)
        return 0.0
class TEDS(object):
    """Tree Edit Distance based Similarity"""

    def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
        """Configure the metric.

        @params structure_only: when true, cell contents are ignored and only
            the table structure (tags, colspan, rowspan) is compared
        @params n_jobs: number of jobs handed to ``parallel_process``; with 1
            the samples are evaluated sequentially in this process
        @params ignore_nodes: optional iterable of tag names that are stripped
            from both tables before comparison
        """
        assert isinstance(n_jobs, int) and (
            n_jobs >= 1
        ), "n_jobs must be an integer greater than or equal to 1"
        self.structure_only = structure_only
        self.n_jobs = n_jobs
        self.ignore_nodes = ignore_nodes
        self.__tokens__ = []

    def tokenize(self, node):
        """Tokenizes table cells

        Appends to ``self.__tokens__`` the opening tag of ``node``, the
        characters of its text, the tokens of all its children, its closing
        tag (skipped for ``unk`` nodes) and the characters of its tail
        (skipped for ``td`` nodes).
        """
        self.__tokens__.append("<%s>" % node.tag)
        if node.text is not None:
            self.__tokens__ += list(node.text)
        # Iterate the element directly; ``getchildren()`` is deprecated.
        for child in node:
            self.tokenize(child)
        if node.tag != "unk":
            self.__tokens__.append("</%s>" % node.tag)
        if node.tag != "td" and node.tail is not None:
            self.__tokens__ += list(node.tail)

    def load_html_tree(self, node, parent=None):
        """Converts HTML tree to the format required by apted

        Returns the converted root when ``parent`` is None; for recursive
        calls the new node is attached to ``parent`` and None is returned.
        """
        if node.tag == "td":
            if self.structure_only:
                cell = []
            else:
                self.__tokens__ = []
                self.tokenize(node)
                # Drop the enclosing "<td>" / "</td>" tokens of the cell.
                cell = self.__tokens__[1:-1]
            new_node = TableTree(
                node.tag,
                int(node.attrib.get("colspan", "1")),
                int(node.attrib.get("rowspan", "1")),
                cell,
            )
        else:
            new_node = TableTree(node.tag, None, None, None)
        if parent is not None:
            parent.children.append(new_node)
        # A cell is a leaf of the converted tree: its inner markup has
        # already been flattened into the token list above.
        if node.tag != "td":
            for child in node:
                self.load_html_tree(child, new_node)
        if parent is None:
            return new_node
        return None

    def evaluate(self, pred, true):
        """Computes TEDS score between the prediction and the ground truth of a
        given sample

        Returns 0.0 when either input is empty or when either document has no
        ``body/table`` element; otherwise returns
        ``1 - edit_distance / max(node count of pred, node count of true)``.
        """
        # Nothing to compare: answer before touching the lxml dependency.
        if (not pred) or (not true):
            return 0.0

        try_import("lxml")
        from lxml import etree, html

        parser = html.HTMLParser(remove_comments=True, encoding="utf-8")
        pred = html.fromstring(pred, parser=parser)
        true = html.fromstring(true, parser=parser)
        pred_tables = pred.xpath("body/table")
        true_tables = true.xpath("body/table")
        if not (pred_tables and true_tables):
            return 0.0

        pred = pred_tables[0]
        true = true_tables[0]
        if self.ignore_nodes:
            etree.strip_tags(pred, *self.ignore_nodes)
            etree.strip_tags(true, *self.ignore_nodes)
        n_nodes_pred = len(pred.xpath(".//*"))
        n_nodes_true = len(true.xpath(".//*"))
        n_nodes = max(n_nodes_pred, n_nodes_true)
        if n_nodes == 0:
            # Both tables have no descendant elements, so the two trees are
            # the same single <table> node; avoid dividing by zero.
            return 1.0
        tree_pred = self.load_html_tree(pred)
        tree_true = self.load_html_tree(true)
        distance = APTED(tree_pred, tree_true, CustomConfig()).compute_edit_distance()
        return 1.0 - (float(distance) / n_nodes)

    def batch_evaluate(self, pred_json, true_json):
        """Computes TEDS score between the prediction and the ground truth of
        a batch of samples
        @params pred_json: {'FILENAME': 'HTML CODE', ...}
        @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
        @output: {'FILENAME': 'TEDS SCORE', ...}
        """
        samples = true_json.keys()
        if self.n_jobs == 1:
            # A file missing from the predictions is scored against "",
            # which ``evaluate`` maps to 0.0.
            scores = [
                self.evaluate(pred_json.get(filename, ""), true_json[filename]["html"])
                for filename in tqdm(samples)
            ]
        else:
            inputs = [
                {
                    "pred": pred_json.get(filename, ""),
                    "true": true_json[filename]["html"],
                }
                for filename in samples
            ]
            scores = parallel_process(
                inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1
            )
        scores = dict(zip(samples, scores))
        return scores

    def batch_evaluate_html(self, pred_htmls, true_htmls):
        """Computes TEDS score between the prediction and the ground truth of
        a batch of samples

        ``pred_htmls`` and ``true_htmls`` are paired positionally with
        ``zip``; the result is a list of scores in that order.
        """
        if self.n_jobs == 1:
            scores = [
                self.evaluate(pred_html, true_html)
                for (pred_html, true_html) in zip(pred_htmls, true_htmls)
            ]
        else:
            inputs = [
                {"pred": pred_html, "true": true_html}
                for (pred_html, true_html) in zip(pred_htmls, true_htmls)
            ]
            scores = parallel_process(
                inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1
            )
        return scores
if __name__ == "__main__":
    import json
    import pprint

    # Read the sample predictions and ground truth from the working directory.
    with open("sample_pred.json") as pred_file:
        pred_json = json.load(pred_file)
    with open("sample_gt.json") as gt_file:
        true_json = json.load(gt_file)

    # Score every ground-truth sample with four jobs and pretty-print the
    # resulting {filename: score} mapping.
    scores = TEDS(n_jobs=4).batch_evaluate(pred_json, true_json)
    pprint.PrettyPrinter().pprint(scores)
|