dataset.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. """ Quick n Simple Image Folder, Tarfile based DataSet
  2. Hacked together by / Copyright 2019, Ross Wightman
  3. """
  4. import io
  5. import logging
  6. from typing import Optional
  7. import torch
  8. import torch.utils.data as data
  9. from PIL import Image
  10. from .readers import create_reader
# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)
# Limit on consecutive sample-loading failures in ImageDataset.__getitem__: while the
# running failure count stays below this value the next index is tried instead,
# otherwise the last error is re-raised.
_ERROR_RETRY = 50
  13. class ImageDataset(data.Dataset):
  14. def __init__(
  15. self,
  16. root,
  17. reader=None,
  18. split='train',
  19. class_map=None,
  20. load_bytes=False,
  21. input_img_mode='RGB',
  22. transform=None,
  23. target_transform=None,
  24. additional_features=None,
  25. **kwargs,
  26. ):
  27. if reader is None or isinstance(reader, str):
  28. reader = create_reader(
  29. reader or '',
  30. root=root,
  31. split=split,
  32. class_map=class_map,
  33. additional_features=additional_features,
  34. **kwargs,
  35. )
  36. self.reader = reader
  37. self.load_bytes = load_bytes
  38. self.input_img_mode = input_img_mode
  39. self.transform = transform
  40. self.target_transform = target_transform
  41. self.additional_features = additional_features
  42. self._consecutive_errors = 0
  43. def __getitem__(self, index):
  44. img, target, *features = self.reader[index]
  45. try:
  46. img = img.read() if self.load_bytes else Image.open(img)
  47. except Exception as e:
  48. _logger.warning(f'Skipped sample (index {index}, file {self.reader.filename(index)}). {str(e)}')
  49. self._consecutive_errors += 1
  50. if self._consecutive_errors < _ERROR_RETRY:
  51. return self.__getitem__((index + 1) % len(self.reader))
  52. else:
  53. raise e
  54. self._consecutive_errors = 0
  55. if self.input_img_mode and not self.load_bytes:
  56. img = img.convert(self.input_img_mode)
  57. if self.transform is not None:
  58. img = self.transform(img)
  59. if target is None:
  60. target = -1
  61. elif self.target_transform is not None:
  62. target = self.target_transform(target)
  63. if self.additional_features is None:
  64. return img, target
  65. else:
  66. return img, target, *features
  67. def __len__(self):
  68. return len(self.reader)
  69. def filename(self, index, basename=False, absolute=False):
  70. return self.reader.filename(index, basename, absolute)
  71. def filenames(self, basename=False, absolute=False):
  72. return self.reader.filenames(basename, absolute)
  73. class IterableImageDataset(data.IterableDataset):
  74. def __init__(
  75. self,
  76. root,
  77. reader=None,
  78. split='train',
  79. class_map=None,
  80. is_training=False,
  81. batch_size=1,
  82. num_samples=None,
  83. seed=42,
  84. repeats=0,
  85. download=False,
  86. input_img_mode='RGB',
  87. input_key=None,
  88. target_key=None,
  89. transform=None,
  90. target_transform=None,
  91. max_steps=None,
  92. **kwargs,
  93. ):
  94. assert reader is not None
  95. if isinstance(reader, str):
  96. self.reader = create_reader(
  97. reader,
  98. root=root,
  99. split=split,
  100. class_map=class_map,
  101. is_training=is_training,
  102. batch_size=batch_size,
  103. num_samples=num_samples,
  104. seed=seed,
  105. repeats=repeats,
  106. download=download,
  107. input_img_mode=input_img_mode,
  108. input_key=input_key,
  109. target_key=target_key,
  110. max_steps=max_steps,
  111. **kwargs,
  112. )
  113. else:
  114. self.reader = reader
  115. self.transform = transform
  116. self.target_transform = target_transform
  117. self._consecutive_errors = 0
  118. def __iter__(self):
  119. for img, target in self.reader:
  120. if self.transform is not None:
  121. img = self.transform(img)
  122. if self.target_transform is not None:
  123. target = self.target_transform(target)
  124. yield img, target
  125. def __len__(self):
  126. if hasattr(self.reader, '__len__'):
  127. return len(self.reader)
  128. else:
  129. return 0
  130. def set_epoch(self, count):
  131. # TFDS and WDS need external epoch count for deterministic cross process shuffle
  132. if hasattr(self.reader, 'set_epoch'):
  133. self.reader.set_epoch(count)
  134. def set_loader_cfg(
  135. self,
  136. num_workers: Optional[int] = None,
  137. ):
  138. # TFDS and WDS readers need # workers for correct # samples estimate before loader processes created
  139. if hasattr(self.reader, 'set_loader_cfg'):
  140. self.reader.set_loader_cfg(num_workers=num_workers)
  141. def filename(self, index, basename=False, absolute=False):
  142. assert False, 'Filename lookup by index not supported, use filenames().'
  143. def filenames(self, basename=False, absolute=False):
  144. return self.reader.filenames(basename, absolute)
  145. class AugMixDataset(torch.utils.data.Dataset):
  146. """Dataset wrapper to perform AugMix or other clean/augmentation mixes"""
  147. def __init__(self, dataset, num_splits=2):
  148. self.augmentation = None
  149. self.normalize = None
  150. self.dataset = dataset
  151. if self.dataset.transform is not None:
  152. self._set_transforms(self.dataset.transform)
  153. self.num_splits = num_splits
  154. def _set_transforms(self, x):
  155. assert isinstance(x, (list, tuple)) and len(x) == 3, 'Expecting a tuple/list of 3 transforms'
  156. self.dataset.transform = x[0]
  157. self.augmentation = x[1]
  158. self.normalize = x[2]
  159. @property
  160. def transform(self):
  161. return self.dataset.transform
  162. @transform.setter
  163. def transform(self, x):
  164. self._set_transforms(x)
  165. def _normalize(self, x):
  166. return x if self.normalize is None else self.normalize(x)
  167. def __getitem__(self, i):
  168. x, y = self.dataset[i] # all splits share the same dataset base transform
  169. x_list = [self._normalize(x)] # first split only normalizes (this is the 'clean' split)
  170. # run the full augmentation on the remaining splits
  171. for _ in range(self.num_splits - 1):
  172. x_list.append(self._normalize(self.augmentation(x)))
  173. return tuple(x_list), y
  174. def __len__(self):
  175. return len(self.dataset)