lmdb_dataset.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import numpy as np
  15. import io
  16. import os
  17. from paddle.io import Dataset
  18. import lmdb
  19. import cv2
  20. import string
  21. import pickle
  22. from PIL import Image
  23. from .imaug import transform, create_operators
  24. class LMDBDataSet(Dataset):
  25. def __init__(self, config, mode, logger, seed=None):
  26. super(LMDBDataSet, self).__init__()
  27. global_config = config["Global"]
  28. dataset_config = config[mode]["dataset"]
  29. loader_config = config[mode]["loader"]
  30. batch_size = loader_config["batch_size_per_card"]
  31. data_dir = dataset_config["data_dir"]
  32. self.do_shuffle = loader_config["shuffle"]
  33. self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir)
  34. logger.info("Initialize indexes of datasets:%s" % data_dir)
  35. self.data_idx_order_list = self.dataset_traversal()
  36. if self.do_shuffle:
  37. np.random.shuffle(self.data_idx_order_list)
  38. self.ops = create_operators(dataset_config["transforms"], global_config)
  39. self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx", 1)
  40. ratio_list = dataset_config.get("ratio_list", [1.0])
  41. self.need_reset = True in [x < 1 for x in ratio_list]
  42. def load_hierarchical_lmdb_dataset(self, data_dir):
  43. lmdb_sets = {}
  44. dataset_idx = 0
  45. for dirpath, dirnames, filenames in os.walk(data_dir + "/"):
  46. if not dirnames:
  47. env = lmdb.open(
  48. dirpath,
  49. max_readers=32,
  50. readonly=True,
  51. lock=False,
  52. readahead=False,
  53. meminit=False,
  54. )
  55. txn = env.begin(write=False)
  56. num_samples = int(txn.get("num-samples".encode()))
  57. lmdb_sets[dataset_idx] = {
  58. "dirpath": dirpath,
  59. "env": env,
  60. "txn": txn,
  61. "num_samples": num_samples,
  62. }
  63. dataset_idx += 1
  64. return lmdb_sets
  65. def dataset_traversal(self):
  66. lmdb_num = len(self.lmdb_sets)
  67. total_sample_num = 0
  68. for lno in range(lmdb_num):
  69. total_sample_num += self.lmdb_sets[lno]["num_samples"]
  70. data_idx_order_list = np.zeros((total_sample_num, 2))
  71. beg_idx = 0
  72. for lno in range(lmdb_num):
  73. tmp_sample_num = self.lmdb_sets[lno]["num_samples"]
  74. end_idx = beg_idx + tmp_sample_num
  75. data_idx_order_list[beg_idx:end_idx, 0] = lno
  76. data_idx_order_list[beg_idx:end_idx, 1] = list(range(tmp_sample_num))
  77. data_idx_order_list[beg_idx:end_idx, 1] += 1
  78. beg_idx = beg_idx + tmp_sample_num
  79. return data_idx_order_list
  80. def get_img_data(self, value):
  81. """get_img_data"""
  82. if not value:
  83. return None
  84. imgdata = np.frombuffer(value, dtype="uint8")
  85. if imgdata is None:
  86. return None
  87. imgori = cv2.imdecode(imgdata, 1)
  88. if imgori is None:
  89. return None
  90. return imgori
  91. def get_ext_data(self):
  92. ext_data_num = 0
  93. for op in self.ops:
  94. if hasattr(op, "ext_data_num"):
  95. ext_data_num = getattr(op, "ext_data_num")
  96. break
  97. load_data_ops = self.ops[: self.ext_op_transform_idx]
  98. ext_data = []
  99. while len(ext_data) < ext_data_num:
  100. lmdb_idx, file_idx = self.data_idx_order_list[np.random.randint(len(self))]
  101. lmdb_idx = int(lmdb_idx)
  102. file_idx = int(file_idx)
  103. sample_info = self.get_lmdb_sample_info(
  104. self.lmdb_sets[lmdb_idx]["txn"], file_idx
  105. )
  106. if sample_info is None:
  107. continue
  108. img, label = sample_info
  109. data = {"image": img, "label": label}
  110. data = transform(data, load_data_ops)
  111. if data is None:
  112. continue
  113. ext_data.append(data)
  114. return ext_data
  115. def get_lmdb_sample_info(self, txn, index):
  116. label_key = "label-%09d".encode() % index
  117. label = txn.get(label_key)
  118. if label is None:
  119. return None
  120. label = label.decode("utf-8")
  121. img_key = "image-%09d".encode() % index
  122. imgbuf = txn.get(img_key)
  123. return imgbuf, label
  124. def __getitem__(self, idx):
  125. lmdb_idx, file_idx = self.data_idx_order_list[idx]
  126. lmdb_idx = int(lmdb_idx)
  127. file_idx = int(file_idx)
  128. sample_info = self.get_lmdb_sample_info(
  129. self.lmdb_sets[lmdb_idx]["txn"], file_idx
  130. )
  131. if sample_info is None:
  132. return self.__getitem__(np.random.randint(self.__len__()))
  133. img, label = sample_info
  134. data = {"image": img, "label": label}
  135. data["ext_data"] = self.get_ext_data()
  136. outs = transform(data, self.ops)
  137. if outs is None:
  138. return self.__getitem__(np.random.randint(self.__len__()))
  139. return outs
  140. def __len__(self):
  141. return self.data_idx_order_list.shape[0]
  142. class LMDBDataSetSR(LMDBDataSet):
  143. def buf2PIL(self, txn, key, type="RGB"):
  144. imgbuf = txn.get(key)
  145. buf = io.BytesIO()
  146. buf.write(imgbuf)
  147. buf.seek(0)
  148. im = Image.open(buf).convert(type)
  149. return im
  150. def str_filt(self, str_, voc_type):
  151. alpha_dict = {
  152. "digit": string.digits,
  153. "lower": string.digits + string.ascii_lowercase,
  154. "upper": string.digits + string.ascii_letters,
  155. "all": string.digits + string.ascii_letters + string.punctuation,
  156. }
  157. if voc_type == "lower":
  158. str_ = str_.lower()
  159. for char in str_:
  160. if char not in alpha_dict[voc_type]:
  161. str_ = str_.replace(char, "")
  162. return str_
  163. def get_lmdb_sample_info(self, txn, index):
  164. self.voc_type = "upper"
  165. self.max_len = 100
  166. self.test = False
  167. label_key = b"label-%09d" % index
  168. word = str(txn.get(label_key).decode())
  169. img_HR_key = b"image_hr-%09d" % index # 128*32
  170. img_lr_key = b"image_lr-%09d" % index # 64*16
  171. try:
  172. img_HR = self.buf2PIL(txn, img_HR_key, "RGB")
  173. img_lr = self.buf2PIL(txn, img_lr_key, "RGB")
  174. except IOError or len(word) > self.max_len:
  175. return self[index + 1]
  176. label_str = self.str_filt(word, self.voc_type)
  177. return img_HR, img_lr, label_str
  178. def __getitem__(self, idx):
  179. lmdb_idx, file_idx = self.data_idx_order_list[idx]
  180. lmdb_idx = int(lmdb_idx)
  181. file_idx = int(file_idx)
  182. sample_info = self.get_lmdb_sample_info(
  183. self.lmdb_sets[lmdb_idx]["txn"], file_idx
  184. )
  185. if sample_info is None:
  186. return self.__getitem__(np.random.randint(self.__len__()))
  187. img_HR, img_lr, label_str = sample_info
  188. data = {"image_hr": img_HR, "image_lr": img_lr, "label": label_str}
  189. outs = transform(data, self.ops)
  190. if outs is None:
  191. return self.__getitem__(np.random.randint(self.__len__()))
  192. return outs
  193. class LMDBDataSetTableMaster(LMDBDataSet):
  194. def load_hierarchical_lmdb_dataset(self, data_dir):
  195. lmdb_sets = {}
  196. dataset_idx = 0
  197. env = lmdb.open(
  198. data_dir,
  199. max_readers=32,
  200. readonly=True,
  201. lock=False,
  202. readahead=False,
  203. meminit=False,
  204. )
  205. txn = env.begin(write=False)
  206. num_samples = int(pickle.loads(txn.get(b"__len__")))
  207. lmdb_sets[dataset_idx] = {
  208. "dirpath": data_dir,
  209. "env": env,
  210. "txn": txn,
  211. "num_samples": num_samples,
  212. }
  213. return lmdb_sets
  214. def get_img_data(self, value):
  215. """get_img_data"""
  216. if not value:
  217. return None
  218. imgdata = np.frombuffer(value, dtype="uint8")
  219. if imgdata is None:
  220. return None
  221. imgori = cv2.imdecode(imgdata, 1)
  222. if imgori is None:
  223. return None
  224. return imgori
  225. def get_lmdb_sample_info(self, txn, index):
  226. def convert_bbox(bbox_str_list):
  227. bbox_list = []
  228. for bbox_str in bbox_str_list:
  229. bbox_list.append(int(bbox_str))
  230. return bbox_list
  231. try:
  232. data = pickle.loads(txn.get(str(index).encode("utf8")))
  233. except:
  234. return None
  235. # img_name, img, info_lines
  236. file_name = data[0]
  237. bytes = data[1]
  238. info_lines = data[2] # raw data from TableMASTER annotation file.
  239. # parse info_lines
  240. raw_data = info_lines.strip().split("\n")
  241. raw_name, text = (
  242. raw_data[0],
  243. raw_data[1],
  244. ) # don't filter the samples's length over max_seq_len.
  245. text = text.split(",")
  246. bbox_str_list = raw_data[2:]
  247. bbox_split = ","
  248. bboxes = [
  249. {"bbox": convert_bbox(bsl.strip().split(bbox_split)), "tokens": ["1", "2"]}
  250. for bsl in bbox_str_list
  251. ]
  252. # advance parse bbox
  253. # import pdb;pdb.set_trace()
  254. line_info = {}
  255. line_info["file_name"] = file_name
  256. line_info["structure"] = text
  257. line_info["cells"] = bboxes
  258. line_info["image"] = bytes
  259. return line_info
  260. def __getitem__(self, idx):
  261. lmdb_idx, file_idx = self.data_idx_order_list[idx]
  262. lmdb_idx = int(lmdb_idx)
  263. file_idx = int(file_idx)
  264. data = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]["txn"], file_idx)
  265. if data is None:
  266. return self.__getitem__(np.random.randint(self.__len__()))
  267. outs = transform(data, self.ops)
  268. if outs is None:
  269. return self.__getitem__(np.random.randint(self.__len__()))
  270. return outs
  271. def __len__(self):
  272. return self.data_idx_order_list.shape[0]