pubtab_dataset.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import os
  16. import random
  17. from paddle.io import Dataset
  18. import json
  19. from copy import deepcopy
  20. from .imaug import transform, create_operators
  21. class PubTabDataSet(Dataset):
  22. def __init__(self, config, mode, logger, seed=None):
  23. super(PubTabDataSet, self).__init__()
  24. self.logger = logger
  25. global_config = config["Global"]
  26. dataset_config = config[mode]["dataset"]
  27. loader_config = config[mode]["loader"]
  28. label_file_list = dataset_config.pop("label_file_list")
  29. data_source_num = len(label_file_list)
  30. ratio_list = dataset_config.get("ratio_list", [1.0])
  31. if isinstance(ratio_list, (float, int)):
  32. ratio_list = [float(ratio_list)] * int(data_source_num)
  33. assert (
  34. len(ratio_list) == data_source_num
  35. ), "The length of ratio_list should be the same as the file_list."
  36. self.data_dir = dataset_config["data_dir"]
  37. self.do_shuffle = loader_config["shuffle"]
  38. self.seed = seed
  39. self.mode = mode.lower()
  40. logger.info("Initialize indexes of datasets:%s" % label_file_list)
  41. self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
  42. # self.check(config['Global']['max_text_length'])
  43. if mode.lower() == "train" and self.do_shuffle:
  44. self.shuffle_data_random()
  45. self.ops = create_operators(dataset_config["transforms"], global_config)
  46. self.need_reset = True in [x < 1 for x in ratio_list]
  47. def get_image_info_list(self, file_list, ratio_list):
  48. if isinstance(file_list, str):
  49. file_list = [file_list]
  50. data_lines = []
  51. for idx, file in enumerate(file_list):
  52. with open(file, "rb") as f:
  53. lines = f.readlines()
  54. if self.mode == "train" or ratio_list[idx] < 1.0:
  55. random.seed(self.seed)
  56. lines = random.sample(lines, round(len(lines) * ratio_list[idx]))
  57. data_lines.extend(lines)
  58. return data_lines
  59. def check(self, max_text_length):
  60. data_lines = []
  61. for line in self.data_lines:
  62. data_line = line.decode("utf-8").strip("\n")
  63. info = json.loads(data_line)
  64. file_name = info["filename"]
  65. cells = info["html"]["cells"].copy()
  66. structure = info["html"]["structure"]["tokens"].copy()
  67. img_path = os.path.join(self.data_dir, file_name)
  68. if not os.path.exists(img_path):
  69. self.logger.warning("{} does not exist!".format(img_path))
  70. continue
  71. if len(structure) == 0 or len(structure) > max_text_length:
  72. continue
  73. # data = {'img_path': img_path, 'cells': cells, 'structure':structure,'file_name':file_name}
  74. data_lines.append(line)
  75. self.data_lines = data_lines
  76. def shuffle_data_random(self):
  77. if self.do_shuffle:
  78. random.seed(self.seed)
  79. random.shuffle(self.data_lines)
  80. return
  81. def __getitem__(self, idx):
  82. try:
  83. data_line = self.data_lines[idx]
  84. data_line = data_line.decode("utf-8").strip("\n")
  85. info = json.loads(data_line)
  86. file_name = info["filename"]
  87. cells = info["html"]["cells"].copy()
  88. structure = info["html"]["structure"]["tokens"].copy()
  89. img_path = os.path.join(self.data_dir, file_name)
  90. if not os.path.exists(img_path):
  91. raise Exception("{} does not exist!".format(img_path))
  92. data = {
  93. "img_path": img_path,
  94. "cells": cells,
  95. "structure": structure,
  96. "file_name": file_name,
  97. }
  98. with open(data["img_path"], "rb") as f:
  99. img = f.read()
  100. data["image"] = img
  101. outs = transform(data, self.ops)
  102. except:
  103. import traceback
  104. err = traceback.format_exc()
  105. self.logger.error(
  106. "When parsing line {}, error happened with msg: {}".format(
  107. data_line, err
  108. )
  109. )
  110. outs = None
  111. if outs is None:
  112. rnd_idx = (
  113. np.random.randint(self.__len__())
  114. if self.mode == "train"
  115. else (idx + 1) % self.__len__()
  116. )
  117. return self.__getitem__(rnd_idx)
  118. return outs
  119. def __len__(self):
  120. return len(self.data_lines)