mnist.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import gzip
  15. import struct
  16. import numpy as np
  17. from PIL import Image
  18. import paddle
  19. from paddle.dataset.common import _check_exists_and_download
  20. from paddle.io import Dataset
# Intentionally empty: a star-import of this module exports nothing.
# NOTE(review): presumably MNIST / FashionMNIST are re-exported explicitly by
# the package (the docstring examples import them from
# ``paddle.vision.datasets``) — confirm before populating this list.
__all__ = []
  22. class MNIST(Dataset):
  23. """
  24. Implementation of `MNIST <http://yann.lecun.com/exdb/mnist/>`_ dataset.
  25. Args:
  26. image_path (str, optional): Path to image file, can be set None if
  27. :attr:`download` is True. Default: None, default data path: ~/.cache/paddle/dataset/mnist.
  28. label_path (str, optional): Path to label file, can be set None if
  29. :attr:`download` is True. Default: None, default data path: ~/.cache/paddle/dataset/mnist.
  30. mode (str, optional): Either train or test mode. Default 'train'.
  31. transform (Callable, optional): Transform to perform on image, None for no transform. Default: None.
  32. download (bool, optional): Download dataset automatically if
  33. :attr:`image_path` :attr:`label_path` is not set. Default: True.
  34. backend (str, optional): Specifies which type of image to be returned:
  35. PIL.Image or numpy.ndarray. Should be one of {'pil', 'cv2'}.
  36. If this option is not set, will get backend from :ref:`paddle.vision.get_image_backend <api_paddle_vision_get_image_backend>`,
  37. default backend is 'pil'. Default: None.
  38. Returns:
  39. :ref:`api_paddle_io_Dataset`. An instance of MNIST dataset.
  40. Examples:
  41. .. code-block:: python
  42. >>> import itertools
  43. >>> import paddle.vision.transforms as T
  44. >>> from paddle.vision.datasets import MNIST
  45. >>> mnist = MNIST()
  46. >>> print(len(mnist))
  47. 60000
  48. >>> for i in range(5): # only show first 5 images
  49. ... img, label = mnist[i]
  50. ... # do something with img and label
  51. ... print(type(img), img.size, label)
  52. ... # <class 'PIL.Image.Image'> (28, 28) [5]
  53. >>> transform = T.Compose(
  54. ... [
  55. ... T.ToTensor(),
  56. ... T.Normalize(
  57. ... mean=[127.5],
  58. ... std=[127.5],
  59. ... ),
  60. ... ]
  61. ... )
  62. >>> mnist_test = MNIST(
  63. ... mode="test",
  64. ... transform=transform, # apply transform to every image
  65. ... backend="cv2", # use OpenCV as image transform backend
  66. ... )
  67. >>> print(len(mnist_test))
  68. 10000
  69. >>> for img, label in itertools.islice(iter(mnist_test), 5): # only show first 5 images
  70. ... # do something with img and label
  71. ... print(type(img), img.shape, label)
  72. ... # <class 'paddle.Tensor'> [1, 28, 28] [7]
  73. """
  74. NAME = 'mnist'
  75. URL_PREFIX = 'https://dataset.bj.bcebos.com/mnist/'
  76. TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz'
  77. TEST_IMAGE_MD5 = '9fb629c4189551a2d022fa330f9573f3'
  78. TEST_LABEL_URL = URL_PREFIX + 't10k-labels-idx1-ubyte.gz'
  79. TEST_LABEL_MD5 = 'ec29112dd5afa0611ce80d1b7f02629c'
  80. TRAIN_IMAGE_URL = URL_PREFIX + 'train-images-idx3-ubyte.gz'
  81. TRAIN_IMAGE_MD5 = 'f68b3c2dcbeaaa9fbdd348bbdeb94873'
  82. TRAIN_LABEL_URL = URL_PREFIX + 'train-labels-idx1-ubyte.gz'
  83. TRAIN_LABEL_MD5 = 'd53e105ee54ea40749a09fcbcd1e9432'
  84. def __init__(
  85. self,
  86. image_path=None,
  87. label_path=None,
  88. mode='train',
  89. transform=None,
  90. download=True,
  91. backend=None,
  92. ):
  93. assert mode.lower() in [
  94. 'train',
  95. 'test',
  96. ], f"mode should be 'train' or 'test', but got {mode}"
  97. if backend is None:
  98. backend = paddle.vision.get_image_backend()
  99. if backend not in ['pil', 'cv2']:
  100. raise ValueError(
  101. f"Expected backend are one of ['pil', 'cv2'], but got {backend}"
  102. )
  103. self.backend = backend
  104. self.mode = mode.lower()
  105. self.image_path = image_path
  106. if self.image_path is None:
  107. assert (
  108. download
  109. ), "image_path is not set and downloading automatically is disabled"
  110. image_url = (
  111. self.TRAIN_IMAGE_URL if mode == 'train' else self.TEST_IMAGE_URL
  112. )
  113. image_md5 = (
  114. self.TRAIN_IMAGE_MD5 if mode == 'train' else self.TEST_IMAGE_MD5
  115. )
  116. self.image_path = _check_exists_and_download(
  117. image_path, image_url, image_md5, self.NAME, download
  118. )
  119. self.label_path = label_path
  120. if self.label_path is None:
  121. assert (
  122. download
  123. ), "label_path is not set and downloading automatically is disabled"
  124. label_url = (
  125. self.TRAIN_LABEL_URL
  126. if self.mode == 'train'
  127. else self.TEST_LABEL_URL
  128. )
  129. label_md5 = (
  130. self.TRAIN_LABEL_MD5
  131. if self.mode == 'train'
  132. else self.TEST_LABEL_MD5
  133. )
  134. self.label_path = _check_exists_and_download(
  135. label_path, label_url, label_md5, self.NAME, download
  136. )
  137. self.transform = transform
  138. # read dataset into memory
  139. self._parse_dataset()
  140. self.dtype = paddle.get_default_dtype()
  141. def _parse_dataset(self, buffer_size=100):
  142. self.images = []
  143. self.labels = []
  144. with gzip.GzipFile(self.image_path, 'rb') as image_file:
  145. img_buf = image_file.read()
  146. with gzip.GzipFile(self.label_path, 'rb') as label_file:
  147. lab_buf = label_file.read()
  148. step_label = 0
  149. offset_img = 0
  150. # read from Big-endian
  151. # get file info from magic byte
  152. # image file : 16B
  153. magic_byte_img = '>IIII'
  154. magic_img, image_num, rows, cols = struct.unpack_from(
  155. magic_byte_img, img_buf, offset_img
  156. )
  157. offset_img += struct.calcsize(magic_byte_img)
  158. offset_lab = 0
  159. # label file : 8B
  160. magic_byte_lab = '>II'
  161. magic_lab, label_num = struct.unpack_from(
  162. magic_byte_lab, lab_buf, offset_lab
  163. )
  164. offset_lab += struct.calcsize(magic_byte_lab)
  165. while True:
  166. if step_label >= label_num:
  167. break
  168. fmt_label = '>' + str(buffer_size) + 'B'
  169. labels = struct.unpack_from(fmt_label, lab_buf, offset_lab)
  170. offset_lab += struct.calcsize(fmt_label)
  171. step_label += buffer_size
  172. fmt_images = '>' + str(buffer_size * rows * cols) + 'B'
  173. images_temp = struct.unpack_from(
  174. fmt_images, img_buf, offset_img
  175. )
  176. images = np.reshape(
  177. images_temp, (buffer_size, rows * cols)
  178. ).astype('float32')
  179. offset_img += struct.calcsize(fmt_images)
  180. for i in range(buffer_size):
  181. self.images.append(images[i, :])
  182. self.labels.append(
  183. np.array([labels[i]]).astype('int64')
  184. )
  185. def __getitem__(self, idx):
  186. image, label = self.images[idx], self.labels[idx]
  187. image = np.reshape(image, [28, 28])
  188. if self.backend == 'pil':
  189. image = Image.fromarray(image.astype('uint8'), mode='L')
  190. if self.transform is not None:
  191. image = self.transform(image)
  192. if self.backend == 'pil':
  193. return image, label.astype('int64')
  194. return image.astype(self.dtype), label.astype('int64')
  195. def __len__(self):
  196. return len(self.labels)
  197. class FashionMNIST(MNIST):
  198. """
  199. Implementation of `Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ dataset.
  200. Args:
  201. image_path (str, optional): Path to image file, can be set None if
  202. :attr:`download` is True. Default: None, default data path: ~/.cache/paddle/dataset/fashion-mnist.
  203. label_path (str, optional): Path to label file, can be set None if
  204. :attr:`download` is True. Default: None, default data path: ~/.cache/paddle/dataset/fashion-mnist.
  205. mode (str, optional): Either train or test mode. Default 'train'.
  206. transform (Callable, optional): Transform to perform on image, None for no transform. Default: None.
  207. download (bool, optional): Whether to download dataset automatically if
  208. :attr:`image_path` :attr:`label_path` is not set. Default: True.
  209. backend (str, optional): Specifies which type of image to be returned:
  210. PIL.Image or numpy.ndarray. Should be one of {'pil', 'cv2'}.
  211. If this option is not set, will get backend from :ref:`paddle.vision.get_image_backend <api_paddle_vision_get_image_backend>`,
  212. default backend is 'pil'. Default: None.
  213. Returns:
  214. :ref:`api_paddle_io_Dataset`. An instance of FashionMNIST dataset.
  215. Examples:
  216. .. code-block:: python
  217. >>> import itertools
  218. >>> import paddle.vision.transforms as T
  219. >>> from paddle.vision.datasets import FashionMNIST
  220. >>> fashion_mnist = FashionMNIST()
  221. >>> print(len(fashion_mnist))
  222. 60000
  223. >>> for i in range(5): # only show first 5 images
  224. ... img, label = fashion_mnist[i]
  225. ... # do something with img and label
  226. ... print(type(img), img.size, label)
  227. ... # <class 'PIL.Image.Image'> (28, 28) [9]
  228. >>> transform = T.Compose(
  229. ... [
  230. ... T.ToTensor(),
  231. ... T.Normalize(
  232. ... mean=[127.5],
  233. ... std=[127.5],
  234. ... ),
  235. ... ]
  236. ... )
  237. >>> fashion_mnist_test = FashionMNIST(
  238. ... mode="test",
  239. ... transform=transform, # apply transform to every image
  240. ... backend="cv2", # use OpenCV as image transform backend
  241. ... )
  242. >>> print(len(fashion_mnist_test))
  243. 10000
  244. >>> for img, label in itertools.islice(iter(fashion_mnist_test), 5): # only show first 5 images
  245. ... # do something with img and label
  246. ... print(type(img), img.shape, label)
  247. ... # <class 'paddle.Tensor'> [1, 28, 28] [9]
  248. """
  249. NAME = 'fashion-mnist'
  250. URL_PREFIX = 'https://dataset.bj.bcebos.com/fashion_mnist/'
  251. TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz'
  252. TEST_IMAGE_MD5 = 'bef4ecab320f06d8554ea6380940ec79'
  253. TEST_LABEL_URL = URL_PREFIX + 't10k-labels-idx1-ubyte.gz'
  254. TEST_LABEL_MD5 = 'bb300cfdad3c16e7a12a480ee83cd310'
  255. TRAIN_IMAGE_URL = URL_PREFIX + 'train-images-idx3-ubyte.gz'
  256. TRAIN_IMAGE_MD5 = '8d4fb7e6c68d591d4c3dfef9ec88bf0d'
  257. TRAIN_LABEL_URL = URL_PREFIX + 'train-labels-idx1-ubyte.gz'
  258. TRAIN_LABEL_MD5 = '25c81989df183df01b3e8a0aad5dffbe'