data.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. import numpy as np
  2. import paddle
  3. import os
  4. import cv2
  5. import glob
  6. def transform(data, ops=None):
  7. """transform"""
  8. if ops is None:
  9. ops = []
  10. for op in ops:
  11. data = op(data)
  12. if data is None:
  13. return None
  14. return data
  15. def create_operators(op_param_list, global_config=None):
  16. """
  17. create operators based on the config
  18. Args:
  19. params(list): a dict list, used to create some operators
  20. """
  21. assert isinstance(op_param_list, list), "operator config should be a list"
  22. ops = []
  23. for operator in op_param_list:
  24. assert isinstance(operator, dict) and len(operator) == 1, "yaml format error"
  25. op_name = list(operator)[0]
  26. param = {} if operator[op_name] is None else operator[op_name]
  27. if global_config is not None:
  28. param.update(global_config)
  29. op = eval(op_name)(**param)
  30. ops.append(op)
  31. return ops
  32. class DecodeImage(object):
  33. """decode image"""
  34. def __init__(self, img_mode="RGB", channel_first=False, **kwargs):
  35. self.img_mode = img_mode
  36. self.channel_first = channel_first
  37. def __call__(self, data):
  38. img = data["image"]
  39. assert type(img) is bytes and len(img) > 0, "invalid input 'img' in DecodeImage"
  40. img = np.frombuffer(img, dtype="uint8")
  41. img = cv2.imdecode(img, 1)
  42. if img is None:
  43. return None
  44. if self.img_mode == "GRAY":
  45. img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
  46. elif self.img_mode == "RGB":
  47. assert img.shape[2] == 3, "invalid shape of image[%s]" % (img.shape)
  48. img = img[:, :, ::-1]
  49. if self.channel_first:
  50. img = img.transpose((2, 0, 1))
  51. data["image"] = img
  52. data["src_image"] = img
  53. return data
  54. class NormalizeImage(object):
  55. """normalize image such as subtract mean, divide std"""
  56. def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
  57. if isinstance(scale, str):
  58. scale = eval(scale)
  59. self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
  60. mean = mean if mean is not None else [0.485, 0.456, 0.406]
  61. std = std if std is not None else [0.229, 0.224, 0.225]
  62. shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
  63. self.mean = np.array(mean).reshape(shape).astype("float32")
  64. self.std = np.array(std).reshape(shape).astype("float32")
  65. def __call__(self, data):
  66. img = data["image"]
  67. from PIL import Image
  68. if isinstance(img, Image.Image):
  69. img = np.array(img)
  70. assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
  71. data["image"] = (img.astype("float32") * self.scale - self.mean) / self.std
  72. return data
  73. class ToCHWImage(object):
  74. """convert hwc image to chw image"""
  75. def __init__(self, **kwargs):
  76. pass
  77. def __call__(self, data):
  78. img = data["image"]
  79. from PIL import Image
  80. if isinstance(img, Image.Image):
  81. img = np.array(img)
  82. data["image"] = img.transpose((2, 0, 1))
  83. src_img = data["src_image"]
  84. from PIL import Image
  85. if isinstance(img, Image.Image):
  86. src_img = np.array(src_img)
  87. data["src_image"] = img.transpose((2, 0, 1))
  88. return data
  89. class SimpleDataset(nn.Dataset):
  90. def __init__(self, config, mode, logger, seed=None):
  91. self.logger = logger
  92. self.mode = mode.lower()
  93. data_dir = config["Train"]["data_dir"]
  94. imgs_list = self.get_image_list(data_dir)
  95. self.ops = create_operators(cfg["transforms"], None)
  96. def get_image_list(self, img_dir):
  97. imgs = glob.glob(os.path.join(img_dir, "*.png"))
  98. if len(imgs) == 0:
  99. raise ValueError(f"not any images founded in {img_dir}")
  100. return imgs
  101. def __getitem__(self, idx):
  102. return None