mnist.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
"""
MNIST dataset.

This module downloads the MNIST archives from
https://dataset.bj.bcebos.com/mnist/ (the dataset itself is described at
http://yann.lecun.com/exdb/mnist/) and parses the training set and the test
set into paddle reader creators.
"""
  19. import gzip
  20. import struct
  21. import numpy
  22. import paddle.dataset.common
  23. from paddle.utils import deprecated
  24. __all__ = []
  25. URL_PREFIX = 'https://dataset.bj.bcebos.com/mnist/'
  26. TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz'
  27. TEST_IMAGE_MD5 = '9fb629c4189551a2d022fa330f9573f3'
  28. TEST_LABEL_URL = URL_PREFIX + 't10k-labels-idx1-ubyte.gz'
  29. TEST_LABEL_MD5 = 'ec29112dd5afa0611ce80d1b7f02629c'
  30. TRAIN_IMAGE_URL = URL_PREFIX + 'train-images-idx3-ubyte.gz'
  31. TRAIN_IMAGE_MD5 = 'f68b3c2dcbeaaa9fbdd348bbdeb94873'
  32. TRAIN_LABEL_URL = URL_PREFIX + 'train-labels-idx1-ubyte.gz'
  33. TRAIN_LABEL_MD5 = 'd53e105ee54ea40749a09fcbcd1e9432'
  34. def reader_creator(image_filename, label_filename, buffer_size):
  35. def reader():
  36. with gzip.GzipFile(image_filename, 'rb') as image_file:
  37. img_buf = image_file.read()
  38. with gzip.GzipFile(label_filename, 'rb') as label_file:
  39. lab_buf = label_file.read()
  40. step_label = 0
  41. offset_img = 0
  42. # read from Big-endian
  43. # get file info from magic byte
  44. # image file : 16B
  45. magic_byte_img = '>IIII'
  46. magic_img, image_num, rows, cols = struct.unpack_from(
  47. magic_byte_img, img_buf, offset_img
  48. )
  49. offset_img += struct.calcsize(magic_byte_img)
  50. offset_lab = 0
  51. # label file : 8B
  52. magic_byte_lab = '>II'
  53. magic_lab, label_num = struct.unpack_from(
  54. magic_byte_lab, lab_buf, offset_lab
  55. )
  56. offset_lab += struct.calcsize(magic_byte_lab)
  57. while True:
  58. if step_label >= label_num:
  59. break
  60. fmt_label = '>' + str(buffer_size) + 'B'
  61. labels = struct.unpack_from(fmt_label, lab_buf, offset_lab)
  62. offset_lab += struct.calcsize(fmt_label)
  63. step_label += buffer_size
  64. fmt_images = '>' + str(buffer_size * rows * cols) + 'B'
  65. images_temp = struct.unpack_from(
  66. fmt_images, img_buf, offset_img
  67. )
  68. images = numpy.reshape(
  69. images_temp, (buffer_size, rows * cols)
  70. ).astype('float32')
  71. offset_img += struct.calcsize(fmt_images)
  72. images = images / 255.0
  73. images = images * 2.0
  74. images = images - 1.0
  75. for i in range(buffer_size):
  76. yield images[i, :], int(labels[i])
  77. return reader
  78. @deprecated(
  79. since="2.0.0",
  80. update_to="paddle.vision.datasets.MNIST",
  81. level=1,
  82. reason="Please use new dataset API which supports paddle.io.DataLoader",
  83. )
  84. def train():
  85. """
  86. MNIST training set creator.
  87. It returns a reader creator, each sample in the reader is image pixels in
  88. [-1, 1] and label in [0, 9].
  89. :return: Training reader creator
  90. :rtype: callable
  91. """
  92. return reader_creator(
  93. paddle.dataset.common.download(
  94. TRAIN_IMAGE_URL, 'mnist', TRAIN_IMAGE_MD5
  95. ),
  96. paddle.dataset.common.download(
  97. TRAIN_LABEL_URL, 'mnist', TRAIN_LABEL_MD5
  98. ),
  99. 100,
  100. )
  101. @deprecated(
  102. since="2.0.0",
  103. update_to="paddle.vision.datasets.MNIST",
  104. level=1,
  105. reason="Please use new dataset API which supports paddle.io.DataLoader",
  106. )
  107. def test():
  108. """
  109. MNIST test set creator.
  110. It returns a reader creator, each sample in the reader is image pixels in
  111. [-1, 1] and label in [0, 9].
  112. :return: Test reader creator.
  113. :rtype: callable
  114. """
  115. return reader_creator(
  116. paddle.dataset.common.download(TEST_IMAGE_URL, 'mnist', TEST_IMAGE_MD5),
  117. paddle.dataset.common.download(TEST_LABEL_URL, 'mnist', TEST_LABEL_MD5),
  118. 100,
  119. )
  120. @deprecated(
  121. since="2.0.0",
  122. update_to="paddle.vision.datasets.MNIST",
  123. level=1,
  124. reason="Please use new dataset API which supports paddle.io.DataLoader",
  125. )
  126. def fetch():
  127. paddle.dataset.common.download(TRAIN_IMAGE_URL, 'mnist', TRAIN_IMAGE_MD5)
  128. paddle.dataset.common.download(TRAIN_LABEL_URL, 'mnist', TRAIN_LABEL_MD5)
  129. paddle.dataset.common.download(TEST_IMAGE_URL, 'mnist', TEST_IMAGE_MD5)
  130. paddle.dataset.common.download(TEST_LABEL_URL, 'mnist', TEST_LABEL_MD5)