cifar.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import pickle
  15. import tarfile
  16. import numpy as np
  17. from PIL import Image
  18. import paddle
  19. from paddle.dataset.common import _check_exists_and_download
  20. from paddle.io import Dataset
  21. __all__ = []
  22. URL_PREFIX = 'https://dataset.bj.bcebos.com/cifar/'
  23. CIFAR10_URL = URL_PREFIX + 'cifar-10-python.tar.gz'
  24. CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a'
  25. CIFAR100_URL = URL_PREFIX + 'cifar-100-python.tar.gz'
  26. CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85'
  27. MODE_FLAG_MAP = {
  28. 'train10': 'data_batch',
  29. 'test10': 'test_batch',
  30. 'train100': 'train',
  31. 'test100': 'test',
  32. }
  33. class Cifar10(Dataset):
  34. """
  35. Implementation of `Cifar-10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_
  36. dataset, which has 10 categories.
  37. Args:
  38. data_file (str, optional): Path to data file, can be set None if
  39. :attr:`download` is True. Default None, default data path: ~/.cache/paddle/dataset/cifar
  40. mode (str, optional): Either train or test mode. Default 'train'.
  41. transform (Callable, optional): transform to perform on image, None for no transform. Default: None.
  42. download (bool, optional): download dataset automatically if :attr:`data_file` is None. Default True.
  43. backend (str, optional): Specifies which type of image to be returned:
  44. PIL.Image or numpy.ndarray. Should be one of {'pil', 'cv2'}.
  45. If this option is not set, will get backend from :ref:`paddle.vision.get_image_backend <api_paddle_vision_get_image_backend>`,
  46. default backend is 'pil'. Default: None.
  47. Returns:
  48. :ref:`api_paddle_io_Dataset`. An instance of Cifar10 dataset.
  49. Examples:
  50. .. code-block:: python
  51. >>> # doctest: +TIMEOUT(60)
  52. >>> import itertools
  53. >>> import paddle.vision.transforms as T
  54. >>> from paddle.vision.datasets import Cifar10
  55. >>> cifar10 = Cifar10()
  56. >>> print(len(cifar10))
  57. 50000
  58. >>> for i in range(5): # only show first 5 images
  59. ... img, label = cifar10[i]
  60. ... # do something with img and label
  61. ... print(type(img), img.size, label)
  62. ... # <class 'PIL.Image.Image'> (32, 32) 6
  63. >>> transform = T.Compose(
  64. ... [
  65. ... T.Resize(64),
  66. ... T.ToTensor(),
  67. ... T.Normalize(
  68. ... mean=[0.5, 0.5, 0.5],
  69. ... std=[0.5, 0.5, 0.5],
  70. ... to_rgb=True,
  71. ... ),
  72. ... ]
  73. ... )
  74. >>> cifar10_test = Cifar10(
  75. ... mode="test",
  76. ... transform=transform, # apply transform to every image
  77. ... backend="cv2", # use OpenCV as image transform backend
  78. ... )
  79. >>> print(len(cifar10_test))
  80. 10000
  81. >>> for img, label in itertools.islice(iter(cifar10_test), 5): # only show first 5 images
  82. ... # do something with img and label
  83. ... print(type(img), img.shape, label)
  84. ... # <class 'paddle.Tensor'> [3, 64, 64] 3
  85. """
  86. def __init__(
  87. self,
  88. data_file=None,
  89. mode='train',
  90. transform=None,
  91. download=True,
  92. backend=None,
  93. ):
  94. assert mode.lower() in [
  95. 'train',
  96. 'test',
  97. ], f"mode.lower() should be 'train' or 'test', but got {mode}"
  98. self.mode = mode.lower()
  99. if backend is None:
  100. backend = paddle.vision.get_image_backend()
  101. if backend not in ['pil', 'cv2']:
  102. raise ValueError(
  103. f"Expected backend are one of ['pil', 'cv2'], but got {backend}"
  104. )
  105. self.backend = backend
  106. self._init_url_md5_flag()
  107. self.data_file = data_file
  108. if self.data_file is None:
  109. assert (
  110. download
  111. ), "data_file is not set and downloading automatically is disabled"
  112. self.data_file = _check_exists_and_download(
  113. data_file, self.data_url, self.data_md5, 'cifar', download
  114. )
  115. self.transform = transform
  116. # read dataset into memory
  117. self._load_data()
  118. self.dtype = paddle.get_default_dtype()
  119. def _init_url_md5_flag(self):
  120. self.data_url = CIFAR10_URL
  121. self.data_md5 = CIFAR10_MD5
  122. self.flag = MODE_FLAG_MAP[self.mode + '10']
  123. def _load_data(self):
  124. self.data = []
  125. with tarfile.open(self.data_file, mode='r') as f:
  126. names = (
  127. each_item.name for each_item in f if self.flag in each_item.name
  128. )
  129. names = sorted(names)
  130. for name in names:
  131. batch = pickle.load(f.extractfile(name), encoding='bytes')
  132. data = batch[b'data']
  133. labels = batch.get(b'labels', batch.get(b'fine_labels', None))
  134. assert labels is not None
  135. for sample, label in zip(data, labels):
  136. self.data.append((sample, label))
  137. def __getitem__(self, idx):
  138. image, label = self.data[idx]
  139. image = np.reshape(image, [3, 32, 32])
  140. image = image.transpose([1, 2, 0])
  141. if self.backend == 'pil':
  142. image = Image.fromarray(image.astype('uint8'))
  143. if self.transform is not None:
  144. image = self.transform(image)
  145. if self.backend == 'pil':
  146. return image, np.array(label).astype('int64')
  147. return image.astype(self.dtype), np.array(label).astype('int64')
  148. def __len__(self):
  149. return len(self.data)
  150. class Cifar100(Cifar10):
  151. """
  152. Implementation of `Cifar-100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_
  153. dataset, which has 100 categories.
  154. Args:
  155. data_file (str, optional): path to data file, can be set None if
  156. :attr:`download` is True. Default: None, default data path: ~/.cache/paddle/dataset/cifar
  157. mode (str, optional): Either train or test mode. Default 'train'.
  158. transform (Callable, optional): transform to perform on image, None for no transform. Default: None.
  159. download (bool, optional): download dataset automatically if :attr:`data_file` is None. Default True.
  160. backend (str, optional): Specifies which type of image to be returned:
  161. PIL.Image or numpy.ndarray. Should be one of {'pil', 'cv2'}.
  162. If this option is not set, will get backend from :ref:`paddle.vision.get_image_backend <api_paddle_vision_get_image_backend>`,
  163. default backend is 'pil'. Default: None.
  164. Returns:
  165. :ref:`api_paddle_io_Dataset`. An instance of Cifar100 dataset.
  166. Examples:
  167. .. code-block:: python
  168. >>> # doctest: +TIMEOUT(60)
  169. >>> import itertools
  170. >>> import paddle.vision.transforms as T
  171. >>> from paddle.vision.datasets import Cifar100
  172. >>> cifar100 = Cifar100()
  173. >>> print(len(cifar100))
  174. 50000
  175. >>> for i in range(5): # only show first 5 images
  176. ... img, label = cifar100[i]
  177. ... # do something with img and label
  178. ... print(type(img), img.size, label)
  179. ... # <class 'PIL.Image.Image'> (32, 32) 19
  180. >>> transform = T.Compose(
  181. ... [
  182. ... T.Resize(64),
  183. ... T.ToTensor(),
  184. ... T.Normalize(
  185. ... mean=[0.5, 0.5, 0.5],
  186. ... std=[0.5, 0.5, 0.5],
  187. ... to_rgb=True,
  188. ... ),
  189. ... ]
  190. ... )
  191. >>> cifar100_test = Cifar100(
  192. ... mode="test",
  193. ... transform=transform, # apply transform to every image
  194. ... backend="cv2", # use OpenCV as image transform backend
  195. ... )
  196. >>> print(len(cifar100_test))
  197. 10000
  198. >>> for img, label in itertools.islice(iter(cifar100_test), 5): # only show first 5 images
  199. ... # do something with img and label
  200. ... print(type(img), img.shape, label)
  201. ... # <class 'paddle.Tensor'> [3, 64, 64] 49
  202. """
  203. def __init__(
  204. self,
  205. data_file=None,
  206. mode='train',
  207. transform=None,
  208. download=True,
  209. backend=None,
  210. ):
  211. super().__init__(data_file, mode, transform, download, backend)
  212. def _init_url_md5_flag(self):
  213. self.data_url = CIFAR100_URL
  214. self.data_md5 = CIFAR100_MD5
  215. self.flag = MODE_FLAG_MAP[self.mode + '100']