trans_funsd_label.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import json
  15. import os
  16. import sys
  17. import cv2
  18. import numpy as np
  19. from copy import deepcopy
  20. def trans_poly_to_bbox(poly):
  21. x1 = np.min([p[0] for p in poly])
  22. x2 = np.max([p[0] for p in poly])
  23. y1 = np.min([p[1] for p in poly])
  24. y2 = np.max([p[1] for p in poly])
  25. return [x1, y1, x2, y2]
  26. def get_outer_poly(bbox_list):
  27. x1 = min([bbox[0] for bbox in bbox_list])
  28. y1 = min([bbox[1] for bbox in bbox_list])
  29. x2 = max([bbox[2] for bbox in bbox_list])
  30. y2 = max([bbox[3] for bbox in bbox_list])
  31. return [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
  32. def load_funsd_label(image_dir, anno_dir):
  33. imgs = os.listdir(image_dir)
  34. annos = os.listdir(anno_dir)
  35. imgs = [img.replace(".png", "") for img in imgs]
  36. annos = [anno.replace(".json", "") for anno in annos]
  37. fn_info_map = dict()
  38. for anno_fn in annos:
  39. res = []
  40. with open(os.path.join(anno_dir, anno_fn + ".json"), "r") as fin:
  41. infos = json.load(fin)
  42. infos = infos["form"]
  43. old_id2new_id_map = dict()
  44. global_new_id = 0
  45. for info in infos:
  46. if info["text"] is None:
  47. continue
  48. words = info["words"]
  49. if len(words) <= 0:
  50. continue
  51. word_idx = 1
  52. curr_bboxes = [words[0]["box"]]
  53. curr_texts = [words[0]["text"]]
  54. while word_idx < len(words):
  55. # switch to a new link
  56. if words[word_idx]["box"][0] + 10 <= words[word_idx - 1]["box"][2]:
  57. if len("".join(curr_texts[0])) > 0:
  58. res.append(
  59. {
  60. "transcription": " ".join(curr_texts),
  61. "label": info["label"],
  62. "points": get_outer_poly(curr_bboxes),
  63. "linking": info["linking"],
  64. "id": global_new_id,
  65. }
  66. )
  67. if info["id"] not in old_id2new_id_map:
  68. old_id2new_id_map[info["id"]] = []
  69. old_id2new_id_map[info["id"]].append(global_new_id)
  70. global_new_id += 1
  71. curr_bboxes = [words[word_idx]["box"]]
  72. curr_texts = [words[word_idx]["text"]]
  73. else:
  74. curr_bboxes.append(words[word_idx]["box"])
  75. curr_texts.append(words[word_idx]["text"])
  76. word_idx += 1
  77. if len("".join(curr_texts[0])) > 0:
  78. res.append(
  79. {
  80. "transcription": " ".join(curr_texts),
  81. "label": info["label"],
  82. "points": get_outer_poly(curr_bboxes),
  83. "linking": info["linking"],
  84. "id": global_new_id,
  85. }
  86. )
  87. if info["id"] not in old_id2new_id_map:
  88. old_id2new_id_map[info["id"]] = []
  89. old_id2new_id_map[info["id"]].append(global_new_id)
  90. global_new_id += 1
  91. res = sorted(res, key=lambda r: (r["points"][0][1], r["points"][0][0]))
  92. for i in range(len(res) - 1):
  93. for j in range(i, 0, -1):
  94. if abs(
  95. res[j + 1]["points"][0][1] - res[j]["points"][0][1]
  96. ) < 20 and (res[j + 1]["points"][0][0] < res[j]["points"][0][0]):
  97. tmp = deepcopy(res[j])
  98. res[j] = deepcopy(res[j + 1])
  99. res[j + 1] = deepcopy(tmp)
  100. else:
  101. break
  102. # re-generate unique ids
  103. for idx, r in enumerate(res):
  104. new_links = []
  105. for link in r["linking"]:
  106. # illegal links will be removed
  107. if (
  108. link[0] not in old_id2new_id_map
  109. or link[1] not in old_id2new_id_map
  110. ):
  111. continue
  112. for src in old_id2new_id_map[link[0]]:
  113. for dst in old_id2new_id_map[link[1]]:
  114. new_links.append([src, dst])
  115. res[idx]["linking"] = deepcopy(new_links)
  116. fn_info_map[anno_fn] = res
  117. return fn_info_map
  118. def main():
  119. test_image_dir = "train_data/FUNSD/testing_data/images/"
  120. test_anno_dir = "train_data/FUNSD/testing_data/annotations/"
  121. test_output_dir = "train_data/FUNSD/test.json"
  122. fn_info_map = load_funsd_label(test_image_dir, test_anno_dir)
  123. with open(test_output_dir, "w") as fout:
  124. for fn in fn_info_map:
  125. fout.write(
  126. fn
  127. + ".png"
  128. + "\t"
  129. + json.dumps(fn_info_map[fn], ensure_ascii=False)
  130. + "\n"
  131. )
  132. train_image_dir = "train_data/FUNSD/training_data/images/"
  133. train_anno_dir = "train_data/FUNSD/training_data/annotations/"
  134. train_output_dir = "train_data/FUNSD/train.json"
  135. fn_info_map = load_funsd_label(train_image_dir, train_anno_dir)
  136. with open(train_output_dir, "w") as fout:
  137. for fn in fn_info_map:
  138. fout.write(
  139. fn
  140. + ".png"
  141. + "\t"
  142. + json.dumps(fn_info_map[fn], ensure_ascii=False)
  143. + "\n"
  144. )
  145. print("====ok====")
  146. return
  147. if __name__ == "__main__":
  148. main()