simple_dataset.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import cv2
  16. import math
  17. import os
  18. import json
  19. import random
  20. import traceback
  21. from paddle.io import Dataset
  22. from .imaug import transform, create_operators
  23. from paddle import get_device
  24. class SimpleDataSet(Dataset):
  25. def __init__(self, config, mode, logger, seed=None):
  26. super(SimpleDataSet, self).__init__()
  27. self.logger = logger
  28. self.mode = mode.lower()
  29. global_config = config["Global"]
  30. dataset_config = config[mode]["dataset"]
  31. loader_config = config[mode]["loader"]
  32. self.delimiter = dataset_config.get("delimiter", "\t")
  33. label_file_list = dataset_config.pop("label_file_list")
  34. data_source_num = len(label_file_list)
  35. ratio_list = dataset_config.get("ratio_list", 1.0)
  36. if isinstance(ratio_list, (float, int)):
  37. ratio_list = [float(ratio_list)] * int(data_source_num)
  38. assert (
  39. len(ratio_list) == data_source_num
  40. ), "The length of ratio_list should be the same as the file_list."
  41. self.data_dir = dataset_config["data_dir"]
  42. self.do_shuffle = loader_config["shuffle"]
  43. self.seed = seed
  44. logger.info("Initialize indexes of datasets:%s" % label_file_list)
  45. self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
  46. self.data_idx_order_list = list(range(len(self.data_lines)))
  47. if self.mode == "train" and self.do_shuffle:
  48. self.shuffle_data_random()
  49. self.ops = create_operators(dataset_config["transforms"], global_config)
  50. self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx", 2)
  51. self.need_reset = True in [x < 1 for x in ratio_list]
  52. def get_image_info_list(self, file_list, ratio_list):
  53. if isinstance(file_list, str):
  54. file_list = [file_list]
  55. data_lines = []
  56. for idx, file in enumerate(file_list):
  57. with open(file, "rb") as f:
  58. lines = f.readlines()
  59. if self.mode == "train" or ratio_list[idx] < 1.0:
  60. random.seed(self.seed)
  61. lines = random.sample(lines, round(len(lines) * ratio_list[idx]))
  62. data_lines.extend(lines)
  63. return data_lines
  64. def shuffle_data_random(self):
  65. random.seed(self.seed)
  66. random.shuffle(self.data_lines)
  67. return
  68. def _try_parse_filename_list(self, file_name):
  69. # multiple images -> one gt label
  70. if len(file_name) > 0 and file_name[0] == "[":
  71. try:
  72. info = json.loads(file_name)
  73. file_name = random.choice(info)
  74. except:
  75. pass
  76. return file_name
  77. def get_ext_data(self):
  78. ext_data_num = 0
  79. for op in self.ops:
  80. if hasattr(op, "ext_data_num"):
  81. ext_data_num = getattr(op, "ext_data_num")
  82. break
  83. load_data_ops = self.ops[: self.ext_op_transform_idx]
  84. ext_data = []
  85. while len(ext_data) < ext_data_num:
  86. file_idx = self.data_idx_order_list[np.random.randint(self.__len__())]
  87. data_line = self.data_lines[file_idx]
  88. data_line = data_line.decode("utf-8")
  89. substr = data_line.strip("\n").split(self.delimiter)
  90. file_name = substr[0]
  91. file_name = self._try_parse_filename_list(file_name)
  92. label = substr[1]
  93. img_path = os.path.join(self.data_dir, file_name)
  94. data = {"img_path": img_path, "label": label}
  95. if not os.path.exists(img_path):
  96. continue
  97. with open(data["img_path"], "rb") as f:
  98. img = f.read()
  99. data["image"] = img
  100. data = transform(data, load_data_ops)
  101. if data is None:
  102. continue
  103. if "polys" in data.keys():
  104. if data["polys"].shape[1] != 4:
  105. continue
  106. ext_data.append(data)
  107. return ext_data
  108. def __getitem__(self, idx):
  109. file_idx = self.data_idx_order_list[idx]
  110. data_line = self.data_lines[file_idx]
  111. try:
  112. data_line = data_line.decode("utf-8")
  113. substr = data_line.strip("\n").split(self.delimiter)
  114. file_name = substr[0]
  115. file_name = self._try_parse_filename_list(file_name)
  116. label = substr[1]
  117. img_path = os.path.join(self.data_dir, file_name)
  118. data = {"img_path": img_path, "label": label}
  119. if not os.path.exists(img_path):
  120. raise Exception("{} does not exist!".format(img_path))
  121. with open(data["img_path"], "rb") as f:
  122. img = f.read()
  123. data["image"] = img
  124. data["ext_data"] = self.get_ext_data()
  125. data["filename"] = data["img_path"]
  126. outs = transform(data, self.ops)
  127. except:
  128. self.logger.error(
  129. "When parsing line {}, error happened with msg: {}".format(
  130. data_line, traceback.format_exc()
  131. )
  132. )
  133. outs = None
  134. if outs is None:
  135. # during evaluation, we should fix the idx to get same results for many times of evaluation.
  136. rnd_idx = (
  137. np.random.randint(self.__len__())
  138. if self.mode == "train"
  139. else (idx + 1) % self.__len__()
  140. )
  141. return self.__getitem__(rnd_idx)
  142. return outs
  143. def __len__(self):
  144. return len(self.data_idx_order_list)
  145. class MultiScaleDataSet(SimpleDataSet):
  146. def __init__(self, config, mode, logger, seed=None):
  147. super(MultiScaleDataSet, self).__init__(config, mode, logger, seed)
  148. self.ds_width = config[mode]["dataset"].get("ds_width", False)
  149. if self.ds_width:
  150. self.wh_aware()
  151. def wh_aware(self):
  152. data_line_new = []
  153. wh_ratio = []
  154. for line in self.data_lines:
  155. data_line_new.append(line)
  156. line = line.decode("utf-8")
  157. name, label, w, h = line.strip("\n").split(self.delimiter)
  158. wh_ratio.append(float(w) / float(h))
  159. self.data_lines = data_line_new
  160. self.wh_ratio = np.array(wh_ratio)
  161. self.wh_ratio_sort = np.argsort(self.wh_ratio)
  162. self.data_idx_order_list = list(range(len(self.data_lines)))
  163. def resize_norm_img(self, data, imgW, imgH, padding=True):
  164. img = data["image"]
  165. h = img.shape[0]
  166. w = img.shape[1]
  167. if not padding:
  168. resized_image = cv2.resize(
  169. img, (imgW, imgH), interpolation=cv2.INTER_LINEAR
  170. )
  171. resized_w = imgW
  172. else:
  173. ratio = w / float(h)
  174. if math.ceil(imgH * ratio) > imgW:
  175. resized_w = imgW
  176. else:
  177. resized_w = int(math.ceil(imgH * ratio))
  178. resized_image = cv2.resize(img, (resized_w, imgH))
  179. resized_image = resized_image.astype("float32")
  180. resized_image = resized_image.transpose((2, 0, 1)) / 255
  181. resized_image -= 0.5
  182. resized_image /= 0.5
  183. padding_im = np.zeros((3, imgH, imgW), dtype=np.float32)
  184. padding_im[:, :, :resized_w] = resized_image
  185. valid_ratio = min(1.0, float(resized_w / imgW))
  186. data["image"] = padding_im
  187. data["valid_ratio"] = valid_ratio
  188. if "iluvatar_gpu" in get_device():
  189. data["valid_ratio"] = np.float32(valid_ratio)
  190. return data
  191. def __getitem__(self, properties):
  192. # properties is a tuple, contains (width, height, index)
  193. img_height = properties[1]
  194. idx = properties[2]
  195. if self.ds_width and properties[3] is not None:
  196. wh_ratio = properties[3]
  197. img_width = img_height * (
  198. 1 if int(round(wh_ratio)) == 0 else int(round(wh_ratio))
  199. )
  200. file_idx = self.wh_ratio_sort[idx]
  201. else:
  202. file_idx = self.data_idx_order_list[idx]
  203. img_width = properties[0]
  204. wh_ratio = None
  205. data_line = self.data_lines[file_idx]
  206. try:
  207. data_line = data_line.decode("utf-8")
  208. substr = data_line.strip("\n").split(self.delimiter)
  209. file_name = substr[0]
  210. file_name = self._try_parse_filename_list(file_name)
  211. label = substr[1]
  212. img_path = os.path.join(self.data_dir, file_name)
  213. data = {"img_path": img_path, "label": label}
  214. if not os.path.exists(img_path):
  215. raise Exception("{} does not exist!".format(img_path))
  216. with open(data["img_path"], "rb") as f:
  217. img = f.read()
  218. data["image"] = img
  219. data["ext_data"] = self.get_ext_data()
  220. outs = transform(data, self.ops[:-1])
  221. if outs is not None:
  222. outs = self.resize_norm_img(outs, img_width, img_height)
  223. outs = transform(outs, self.ops[-1:])
  224. except:
  225. self.logger.error(
  226. "When parsing line {}, error happened with msg: {}".format(
  227. data_line, traceback.format_exc()
  228. )
  229. )
  230. outs = None
  231. if outs is None:
  232. # during evaluation, we should fix the idx to get same results for many times of evaluation.
  233. rnd_idx = (idx + 1) % self.__len__()
  234. return self.__getitem__([img_width, img_height, rnd_idx, wh_ratio])
  235. return outs