flowers.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import tarfile
  16. import numpy as np
  17. from PIL import Image
  18. import paddle
  19. from paddle.dataset.common import _check_exists_and_download
  20. from paddle.io import Dataset
  21. from paddle.utils import try_import
  22. __all__ = []
  23. DATA_URL = 'http://paddlemodels.bj.bcebos.com/flowers/102flowers.tgz'
  24. LABEL_URL = 'http://paddlemodels.bj.bcebos.com/flowers/imagelabels.mat'
  25. SETID_URL = 'http://paddlemodels.bj.bcebos.com/flowers/setid.mat'
  26. DATA_MD5 = '52808999861908f626f3c1f4e79d11fa'
  27. LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
  28. SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'
  29. # In official 'readme', tstid is the flag of test data
  30. # and trnid is the flag of train data. But test data is more than train data.
  31. # So we exchange the train data and test data.
  32. MODE_FLAG_MAP = {'train': 'tstid', 'test': 'trnid', 'valid': 'valid'}
  33. class Flowers(Dataset):
  34. """
  35. Implementation of `Flowers102 <https://www.robots.ox.ac.uk/~vgg/data/flowers/>`_
  36. dataset.
  37. Args:
  38. data_file (str, optional): Path to data file, can be set None if
  39. :attr:`download` is True. Default: None, default data path: ~/.cache/paddle/dataset/flowers/.
  40. label_file (str, optional): Path to label file, can be set None if
  41. :attr:`download` is True. Default: None, default data path: ~/.cache/paddle/dataset/flowers/.
  42. setid_file (str, optional): Path to subset index file, can be set
  43. None if :attr:`download` is True. Default: None, default data path: ~/.cache/paddle/dataset/flowers/.
  44. mode (str, optional): Either train or test mode. Default 'train'.
  45. transform (Callable, optional): transform to perform on image, None for no transform. Default: None.
  46. download (bool, optional): download dataset automatically if :attr:`data_file` is None. Default: True.
  47. backend (str, optional): Specifies which type of image to be returned:
  48. PIL.Image or numpy.ndarray. Should be one of {'pil', 'cv2'}.
  49. If this option is not set, will get backend from :ref:`paddle.vision.get_image_backend <api_paddle_vision_get_image_backend>`,
  50. default backend is 'pil'. Default: None.
  51. Returns:
  52. :ref:`api_paddle_io_Dataset`. An instance of Flowers dataset.
  53. Examples:
  54. .. code-block:: python
  55. >>> # doctest: +TIMEOUT(60)
  56. >>> import itertools
  57. >>> import paddle.vision.transforms as T
  58. >>> from paddle.vision.datasets import Flowers
  59. >>> flowers = Flowers()
  60. >>> print(len(flowers))
  61. 6149
  62. >>> for i in range(5): # only show first 5 images
  63. ... img, label = flowers[i]
  64. ... # do something with img and label
  65. ... print(type(img), img.size, label)
  66. ... # <class 'PIL.JpegImagePlugin.JpegImageFile'> (523, 500) [1]
  67. >>> transform = T.Compose(
  68. ... [
  69. ... T.Resize(64),
  70. ... T.ToTensor(),
  71. ... T.Normalize(
  72. ... mean=[0.5, 0.5, 0.5],
  73. ... std=[0.5, 0.5, 0.5],
  74. ... to_rgb=True,
  75. ... ),
  76. ... ]
  77. ... )
  78. >>> flowers_test = Flowers(
  79. ... mode="test",
  80. ... transform=transform, # apply transform to every image
  81. ... backend="cv2", # use OpenCV as image transform backend
  82. ... )
  83. >>> print(len(flowers_test))
  84. 1020
  85. >>> for img, label in itertools.islice(iter(flowers_test), 5): # only show first 5 images
  86. ... # do something with img and label
  87. ... print(type(img), img.shape, label)
  88. ... # <class 'paddle.Tensor'> [3, 64, 96] [1]
  89. """
  90. def __init__(
  91. self,
  92. data_file=None,
  93. label_file=None,
  94. setid_file=None,
  95. mode='train',
  96. transform=None,
  97. download=True,
  98. backend=None,
  99. ):
  100. assert mode.lower() in [
  101. 'train',
  102. 'valid',
  103. 'test',
  104. ], f"mode should be 'train', 'valid' or 'test', but got {mode}"
  105. if backend is None:
  106. backend = paddle.vision.get_image_backend()
  107. if backend not in ['pil', 'cv2']:
  108. raise ValueError(
  109. f"Expected backend are one of ['pil', 'cv2'], but got {backend}"
  110. )
  111. self.backend = backend
  112. flag = MODE_FLAG_MAP[mode.lower()]
  113. if not data_file:
  114. assert (
  115. download
  116. ), "data_file is not set and downloading automatically is disabled"
  117. data_file = _check_exists_and_download(
  118. data_file, DATA_URL, DATA_MD5, 'flowers', download
  119. )
  120. if not label_file:
  121. assert (
  122. download
  123. ), "label_file is not set and downloading automatically is disabled"
  124. label_file = _check_exists_and_download(
  125. label_file, LABEL_URL, LABEL_MD5, 'flowers', download
  126. )
  127. if not setid_file:
  128. assert (
  129. download
  130. ), "setid_file is not set and downloading automatically is disabled"
  131. setid_file = _check_exists_and_download(
  132. setid_file, SETID_URL, SETID_MD5, 'flowers', download
  133. )
  134. self.transform = transform
  135. data_tar = tarfile.open(data_file)
  136. self.data_path = data_file.replace(".tgz", "/")
  137. if not os.path.exists(self.data_path):
  138. os.mkdir(self.data_path)
  139. data_tar.extractall(self.data_path)
  140. scio = try_import('scipy.io')
  141. self.labels = scio.loadmat(label_file)['labels'][0]
  142. self.indexes = scio.loadmat(setid_file)[flag][0]
  143. def __getitem__(self, idx):
  144. index = self.indexes[idx]
  145. label = np.array([self.labels[index - 1]])
  146. img_name = "jpg/image_%05d.jpg" % index
  147. image = os.path.join(self.data_path, img_name)
  148. if self.backend == 'pil':
  149. image = Image.open(image)
  150. elif self.backend == 'cv2':
  151. image = np.array(Image.open(image))
  152. if self.transform is not None:
  153. image = self.transform(image)
  154. if self.backend == 'pil':
  155. return image, label.astype('int64')
  156. return image.astype(paddle.get_default_dtype()), label.astype('int64')
  157. def __len__(self):
  158. return len(self.indexes)