flowers.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module downloads the dataset from
http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html
and parses its train/test/validation splits into paddle reader creators.
This set contains images of flowers belonging to 102 different categories.
The images were acquired by searching the web and taking pictures. There are a
minimum of 40 images for each category.
The database was used in:
Nilsback, M-E. and Zisserman, A. Automated flower classification over a large
number of classes. Proceedings of the Indian Conference on Computer Vision,
Graphics and Image Processing (2008)
http://www.robots.ox.ac.uk/~vgg/publications/papers/nilsback08.{pdf,ps.gz}.
"""
  27. import functools
  28. import tarfile
  29. from multiprocessing import cpu_count
  30. from paddle.dataset.image import load_image_bytes, simple_transform
  31. from paddle.reader import map_readers, xmap_readers
  32. from paddle.utils import deprecated, try_import
  33. from .common import download
  34. __all__ = []
  35. DATA_URL = 'http://paddlemodels.bj.bcebos.com/flowers/102flowers.tgz'
  36. LABEL_URL = 'http://paddlemodels.bj.bcebos.com/flowers/imagelabels.mat'
  37. SETID_URL = 'http://paddlemodels.bj.bcebos.com/flowers/setid.mat'
  38. DATA_MD5 = '52808999861908f626f3c1f4e79d11fa'
  39. LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
  40. SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'
  41. # In official 'readme', tstid is the flag of test data
  42. # and trnid is the flag of train data. But test data is more than train data.
  43. # So we exchange the train data and test data.
  44. TRAIN_FLAG = 'tstid'
  45. TEST_FLAG = 'trnid'
  46. VALID_FLAG = 'valid'
  47. def default_mapper(is_train, sample):
  48. '''
  49. map image bytes data to type needed by model input layer
  50. '''
  51. img, label = sample
  52. img = load_image_bytes(img)
  53. img = simple_transform(
  54. img, 256, 224, is_train, mean=[103.94, 116.78, 123.68]
  55. )
  56. return img.flatten().astype('float32'), label
  57. train_mapper = functools.partial(default_mapper, True)
  58. test_mapper = functools.partial(default_mapper, False)
  59. def reader_creator(
  60. data_file,
  61. label_file,
  62. setid_file,
  63. dataset_name,
  64. mapper,
  65. buffered_size=1024,
  66. use_xmap=True,
  67. cycle=False,
  68. ):
  69. '''
  70. 1. read images from tar file and
  71. merge images into batch files in 102flowers.tgz_batch/
  72. 2. get a reader to read sample from batch file
  73. :param data_file: downloaded data file
  74. :type data_file: string
  75. :param label_file: downloaded label file
  76. :type label_file: string
  77. :param setid_file: downloaded setid file containing information
  78. about how to split dataset
  79. :type setid_file: string
  80. :param dataset_name: data set name (tstid|trnid|valid)
  81. :type dataset_name: string
  82. :param mapper: a function to map image bytes data to type
  83. needed by model input layer
  84. :type mapper: callable
  85. :param buffered_size: the size of buffer used to process images
  86. :type buffered_size: int
  87. :param cycle: whether to cycle through the dataset
  88. :type cycle: bool
  89. :return: data reader
  90. :rtype: callable
  91. '''
  92. def reader():
  93. scio = try_import('scipy.io')
  94. labels = scio.loadmat(label_file)['labels'][0]
  95. indexes = scio.loadmat(setid_file)[dataset_name][0]
  96. img2label = {}
  97. for i in indexes:
  98. img = "jpg/image_%05d.jpg" % i
  99. img2label[img] = labels[i - 1]
  100. tf = tarfile.open(data_file)
  101. mems = tf.getmembers()
  102. file_id = 0
  103. for mem in mems:
  104. if mem.name in img2label:
  105. image = tf.extractfile(mem).read()
  106. label = img2label[mem.name]
  107. yield image, int(label) - 1
  108. if use_xmap:
  109. return xmap_readers(mapper, reader, min(4, cpu_count()), buffered_size)
  110. else:
  111. return map_readers(mapper, reader)
  112. @deprecated(
  113. since="2.0.0",
  114. update_to="paddle.vision.datasets.Flowers",
  115. level=1,
  116. reason="Please use new dataset API which supports paddle.io.DataLoader",
  117. )
  118. def train(mapper=train_mapper, buffered_size=1024, use_xmap=True, cycle=False):
  119. '''
  120. Create flowers training set reader.
  121. It returns a reader, each sample in the reader is
  122. image pixels in [0, 1] and label in [1, 102]
  123. translated from original color image by steps:
  124. 1. resize to 256*256
  125. 2. random crop to 224*224
  126. 3. flatten
  127. :param mapper: a function to map sample.
  128. :type mapper: callable
  129. :param buffered_size: the size of buffer used to process images
  130. :type buffered_size: int
  131. :param cycle: whether to cycle through the dataset
  132. :type cycle: bool
  133. :return: train data reader
  134. :rtype: callable
  135. '''
  136. return reader_creator(
  137. download(DATA_URL, 'flowers', DATA_MD5),
  138. download(LABEL_URL, 'flowers', LABEL_MD5),
  139. download(SETID_URL, 'flowers', SETID_MD5),
  140. TRAIN_FLAG,
  141. mapper,
  142. buffered_size,
  143. use_xmap,
  144. cycle=cycle,
  145. )
  146. @deprecated(
  147. since="2.0.0",
  148. update_to="paddle.vision.datasets.Flowers",
  149. level=1,
  150. reason="Please use new dataset API which supports paddle.io.DataLoader",
  151. )
  152. def test(mapper=test_mapper, buffered_size=1024, use_xmap=True, cycle=False):
  153. '''
  154. Create flowers test set reader.
  155. It returns a reader, each sample in the reader is
  156. image pixels in [0, 1] and label in [1, 102]
  157. translated from original color image by steps:
  158. 1. resize to 256*256
  159. 2. random crop to 224*224
  160. 3. flatten
  161. :param mapper: a function to map sample.
  162. :type mapper: callable
  163. :param buffered_size: the size of buffer used to process images
  164. :type buffered_size: int
  165. :param cycle: whether to cycle through the dataset
  166. :type cycle: bool
  167. :return: test data reader
  168. :rtype: callable
  169. '''
  170. return reader_creator(
  171. download(DATA_URL, 'flowers', DATA_MD5),
  172. download(LABEL_URL, 'flowers', LABEL_MD5),
  173. download(SETID_URL, 'flowers', SETID_MD5),
  174. TEST_FLAG,
  175. mapper,
  176. buffered_size,
  177. use_xmap,
  178. cycle=cycle,
  179. )
  180. @deprecated(
  181. since="2.0.0",
  182. update_to="paddle.vision.datasets.Flowers",
  183. level=1,
  184. reason="Please use new dataset API which supports paddle.io.DataLoader",
  185. )
  186. def valid(mapper=test_mapper, buffered_size=1024, use_xmap=True):
  187. '''
  188. Create flowers validation set reader.
  189. It returns a reader, each sample in the reader is
  190. image pixels in [0, 1] and label in [1, 102]
  191. translated from original color image by steps:
  192. 1. resize to 256*256
  193. 2. random crop to 224*224
  194. 3. flatten
  195. :param mapper: a function to map sample.
  196. :type mapper: callable
  197. :param buffered_size: the size of buffer used to process images
  198. :type buffered_size: int
  199. :return: test data reader
  200. :rtype: callable
  201. '''
  202. return reader_creator(
  203. download(DATA_URL, 'flowers', DATA_MD5),
  204. download(LABEL_URL, 'flowers', LABEL_MD5),
  205. download(SETID_URL, 'flowers', SETID_MD5),
  206. VALID_FLAG,
  207. mapper,
  208. buffered_size,
  209. use_xmap,
  210. )
  211. def fetch():
  212. download(DATA_URL, 'flowers', DATA_MD5)
  213. download(LABEL_URL, 'flowers', LABEL_MD5)
  214. download(SETID_URL, 'flowers', SETID_MD5)