# base_dataset.py (2.8 KB)
  1. # -*- coding: utf-8 -*-
  2. # @Time : 2019/12/4 13:12
  3. # @Author : zhoujun
  4. import copy
  5. from paddle.io import Dataset
  6. from data_loader.modules import *
  7. class BaseDataSet(Dataset):
  8. def __init__(
  9. self,
  10. data_path: str,
  11. img_mode,
  12. pre_processes,
  13. filter_keys,
  14. ignore_tags,
  15. transform=None,
  16. target_transform=None,
  17. ):
  18. assert img_mode in ["RGB", "BRG", "GRAY"]
  19. self.ignore_tags = ignore_tags
  20. self.data_list = self.load_data(data_path)
  21. item_keys = ["img_path", "img_name", "text_polys", "texts", "ignore_tags"]
  22. for item in item_keys:
  23. assert (
  24. item in self.data_list[0]
  25. ), "data_list from load_data must contains {}".format(item_keys)
  26. self.img_mode = img_mode
  27. self.filter_keys = filter_keys
  28. self.transform = transform
  29. self.target_transform = target_transform
  30. self._init_pre_processes(pre_processes)
  31. def _init_pre_processes(self, pre_processes):
  32. self.aug = []
  33. if pre_processes is not None:
  34. for aug in pre_processes:
  35. if "args" not in aug:
  36. args = {}
  37. else:
  38. args = aug["args"]
  39. if isinstance(args, dict):
  40. cls = eval(aug["type"])(**args)
  41. else:
  42. cls = eval(aug["type"])(args)
  43. self.aug.append(cls)
  44. def load_data(self, data_path: str) -> list:
  45. """
  46. 把数据加载为一个list:
  47. :params data_path: 存储数据的文件夹或者文件
  48. return a dict ,包含了,'img_path','img_name','text_polys','texts','ignore_tags'
  49. """
  50. raise NotImplementedError
  51. def apply_pre_processes(self, data):
  52. for aug in self.aug:
  53. data = aug(data)
  54. return data
  55. def __getitem__(self, index):
  56. try:
  57. data = copy.deepcopy(self.data_list[index])
  58. im = cv2.imread(data["img_path"], 1 if self.img_mode != "GRAY" else 0)
  59. if self.img_mode == "RGB":
  60. im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
  61. data["img"] = im
  62. data["shape"] = [im.shape[0], im.shape[1]]
  63. data = self.apply_pre_processes(data)
  64. if self.transform:
  65. data["img"] = self.transform(data["img"])
  66. data["text_polys"] = data["text_polys"].tolist()
  67. if len(self.filter_keys):
  68. data_dict = {}
  69. for k, v in data.items():
  70. if k not in self.filter_keys:
  71. data_dict[k] = v
  72. return data_dict
  73. else:
  74. return data
  75. except:
  76. return self.__getitem__(np.random.randint(self.__len__()))
  77. def __len__(self):
  78. return len(self.data_list)