folder.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. from PIL import Image
  16. import paddle
  17. from paddle.io import Dataset
  18. from paddle.utils import try_import
# Nothing is exported through ``from ... import *`` here; the docstring
# examples import the dataset classes from ``paddle.vision.datasets`` instead.
__all__ = []
  20. def has_valid_extension(filename, extensions):
  21. """Checks if a file is a valid extension.
  22. Args:
  23. filename (str): path to a file
  24. extensions (list[str]|tuple[str]): extensions to consider
  25. Returns:
  26. bool: True if the filename ends with one of given extensions
  27. """
  28. assert isinstance(
  29. extensions, (list, tuple)
  30. ), "`extensions` must be list or tuple."
  31. extensions = tuple([x.lower() for x in extensions])
  32. return filename.lower().endswith(extensions)
  33. def make_dataset(dir, class_to_idx, extensions, is_valid_file=None):
  34. images = []
  35. dir = os.path.expanduser(dir)
  36. if extensions is not None:
  37. def is_valid_file(x):
  38. return has_valid_extension(x, extensions)
  39. for target in sorted(class_to_idx.keys()):
  40. d = os.path.join(dir, target)
  41. if not os.path.isdir(d):
  42. continue
  43. for root, _, fnames in sorted(os.walk(d, followlinks=True)):
  44. for fname in sorted(fnames):
  45. path = os.path.join(root, fname)
  46. if is_valid_file(path):
  47. item = (path, class_to_idx[target])
  48. images.append(item)
  49. return images
  50. class DatasetFolder(Dataset):
  51. """A generic data loader where the samples are arranged in this way:
  52. .. code-block:: text
  53. root/class_a/1.ext
  54. root/class_a/2.ext
  55. root/class_a/3.ext
  56. root/class_b/123.ext
  57. root/class_b/456.ext
  58. root/class_b/789.ext
  59. Args:
  60. root (str): Root directory path.
  61. loader (Callable, optional): A function to load a sample given its path. Default: None.
  62. extensions (list[str]|tuple[str], optional): A list of allowed extensions.
  63. Both :attr:`extensions` and :attr:`is_valid_file` should not be passed.
  64. If this value is not set, the default is to use ('.jpg', '.jpeg', '.png',
  65. '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'). Default: None.
  66. transform (Callable, optional): A function/transform that takes in
  67. a sample and returns a transformed version. Default: None.
  68. is_valid_file (Callable, optional): A function that takes path of a file
  69. and check if the file is a valid file. Both :attr:`extensions` and
  70. :attr:`is_valid_file` should not be passed. Default: None.
  71. Returns:
  72. :ref:`api_paddle_io_Dataset`. An instance of DatasetFolder.
  73. Attributes:
  74. classes (list[str]): List of the class names.
  75. class_to_idx (dict[str, int]): Dict with items (class_name, class_index).
  76. samples (list[tuple[str, int]]): List of (sample_path, class_index) tuples.
  77. targets (list[int]): The class_index value for each image in the dataset.
  78. Example:
  79. .. code-block:: python
  80. >>> import shutil
  81. >>> import tempfile
  82. >>> import cv2
  83. >>> import numpy as np
  84. >>> import paddle.vision.transforms as T
  85. >>> from pathlib import Path
  86. >>> from paddle.vision.datasets import DatasetFolder
  87. >>> def make_fake_file(img_path: str):
  88. ... if img_path.endswith((".jpg", ".png", ".jpeg")):
  89. ... fake_img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
  90. ... cv2.imwrite(img_path, fake_img)
  91. ... elif img_path.endswith(".txt"):
  92. ... with open(img_path, "w") as f:
  93. ... f.write("This is a fake file.")
  94. >>> def make_directory(root, directory_hierarchy, file_maker=make_fake_file):
  95. ... root = Path(root)
  96. ... root.mkdir(parents=True, exist_ok=True)
  97. ... for subpath in directory_hierarchy:
  98. ... if isinstance(subpath, str):
  99. ... filepath = root / subpath
  100. ... file_maker(str(filepath))
  101. ... else:
  102. ... dirname = list(subpath.keys())[0]
  103. ... make_directory(root / dirname, subpath[dirname])
  104. >>> directory_hierarchy = [
  105. ... {"class_0": [
  106. ... "abc.jpg",
  107. ... "def.png"]},
  108. ... {"class_1": [
  109. ... "ghi.jpeg",
  110. ... "jkl.png",
  111. ... {"mno": [
  112. ... "pqr.jpeg",
  113. ... "stu.jpg"]}]},
  114. ... "this_will_be_ignored.txt",
  115. ... ]
  116. >>> # You can replace this with any directory to explore the structure
  117. >>> # of generated data. e.g. fake_data_dir = "./temp_dir"
  118. >>> fake_data_dir = tempfile.mkdtemp()
  119. >>> make_directory(fake_data_dir, directory_hierarchy)
  120. >>> data_folder_1 = DatasetFolder(fake_data_dir)
  121. >>> print(data_folder_1.classes)
  122. ['class_0', 'class_1']
  123. >>> print(data_folder_1.class_to_idx)
  124. {'class_0': 0, 'class_1': 1}
  125. >>> print(data_folder_1.samples)
  126. >>> # doctest: +SKIP(it's different with windows)
  127. [('./temp_dir/class_0/abc.jpg', 0), ('./temp_dir/class_0/def.png', 0),
  128. ('./temp_dir/class_1/ghi.jpeg', 1), ('./temp_dir/class_1/jkl.png', 1),
  129. ('./temp_dir/class_1/mno/pqr.jpeg', 1), ('./temp_dir/class_1/mno/stu.jpg', 1)]
  130. >>> # doctest: -SKIP
  131. >>> print(data_folder_1.targets)
  132. [0, 0, 1, 1, 1, 1]
  133. >>> print(len(data_folder_1))
  134. 6
  135. >>> for i in range(len(data_folder_1)):
  136. ... img, label = data_folder_1[i]
  137. ... # do something with img and label
  138. ... print(type(img), img.size, label)
  139. ... # <class 'PIL.Image.Image'> (32, 32) 0
  140. >>> transform = T.Compose(
  141. ... [
  142. ... T.Resize(64),
  143. ... T.ToTensor(),
  144. ... T.Normalize(
  145. ... mean=[0.5, 0.5, 0.5],
  146. ... std=[0.5, 0.5, 0.5],
  147. ... to_rgb=True,
  148. ... ),
  149. ... ]
  150. ... )
  151. >>> data_folder_2 = DatasetFolder(
  152. ... fake_data_dir,
  153. ... loader=lambda x: cv2.imread(x), # load image with OpenCV
  154. ... extensions=(".jpg",), # only load *.jpg files
  155. ... transform=transform, # apply transform to every image
  156. ... )
  157. >>> print([img_path for img_path, label in data_folder_2.samples])
  158. >>> # doctest: +SKIP(it's different with windows)
  159. ['./temp_dir/class_0/abc.jpg', './temp_dir/class_1/mno/stu.jpg']
  160. >>> # doctest: -SKIP
  161. >>> print(len(data_folder_2))
  162. 2
  163. >>> for img, label in iter(data_folder_2):
  164. ... # do something with img and label
  165. ... print(type(img), img.shape, label)
  166. ... # <class 'paddle.Tensor'> [3, 64, 64] 0
  167. >>> shutil.rmtree(fake_data_dir)
  168. """
  169. def __init__(
  170. self,
  171. root,
  172. loader=None,
  173. extensions=None,
  174. transform=None,
  175. is_valid_file=None,
  176. ):
  177. self.root = root
  178. self.transform = transform
  179. if extensions is None:
  180. extensions = IMG_EXTENSIONS
  181. classes, class_to_idx = self._find_classes(self.root)
  182. samples = make_dataset(
  183. self.root, class_to_idx, extensions, is_valid_file
  184. )
  185. if len(samples) == 0:
  186. raise (
  187. RuntimeError(
  188. "Found 0 directories in subfolders of: " + self.root + "\n"
  189. "Supported extensions are: " + ",".join(extensions)
  190. )
  191. )
  192. self.loader = default_loader if loader is None else loader
  193. self.extensions = extensions
  194. self.classes = classes
  195. self.class_to_idx = class_to_idx
  196. self.samples = samples
  197. self.targets = [s[1] for s in samples]
  198. self.dtype = paddle.get_default_dtype()
  199. def _find_classes(self, dir):
  200. """
  201. Finds the class folders in a dataset.
  202. Args:
  203. dir (string): Root directory path.
  204. Returns:
  205. tuple: (classes, class_to_idx) where classes are relative to (dir),
  206. and class_to_idx is a dictionary.
  207. """
  208. classes = [d.name for d in os.scandir(dir) if d.is_dir()]
  209. classes.sort()
  210. class_to_idx = {classes[i]: i for i in range(len(classes))}
  211. return classes, class_to_idx
  212. def __getitem__(self, index):
  213. """
  214. Args:
  215. index (int): Index
  216. Returns:
  217. tuple: (sample, target) where target is class_index of the target class.
  218. """
  219. path, target = self.samples[index]
  220. sample = self.loader(path)
  221. if self.transform is not None:
  222. sample = self.transform(sample)
  223. return sample, target
  224. def __len__(self):
  225. return len(self.samples)
  226. IMG_EXTENSIONS = (
  227. '.jpg',
  228. '.jpeg',
  229. '.png',
  230. '.ppm',
  231. '.bmp',
  232. '.pgm',
  233. '.tif',
  234. '.tiff',
  235. '.webp',
  236. )
  237. def pil_loader(path):
  238. with open(path, 'rb') as f:
  239. img = Image.open(f)
  240. return img.convert('RGB')
  241. def cv2_loader(path):
  242. cv2 = try_import('cv2')
  243. return cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
  244. def default_loader(path):
  245. from paddle.vision import get_image_backend
  246. if get_image_backend() == 'cv2':
  247. return cv2_loader(path)
  248. else:
  249. return pil_loader(path)
  250. class ImageFolder(Dataset):
  251. """A generic data loader where the samples are arranged in this way:
  252. .. code-block:: text
  253. root/1.ext
  254. root/2.ext
  255. root/sub_dir/3.ext
  256. Args:
  257. root (str): Root directory path.
  258. loader (Callable, optional): A function to load a sample given its path. Default: None.
  259. extensions (list[str]|tuple[str], optional): A list of allowed extensions.
  260. Both :attr:`extensions` and :attr:`is_valid_file` should not be passed.
  261. If this value is not set, the default is to use ('.jpg', '.jpeg', '.png',
  262. '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'). Default: None.
  263. transform (Callable, optional): A function/transform that takes in
  264. a sample and returns a transformed version. Default: None.
  265. is_valid_file (Callable, optional): A function that takes path of a file
  266. and check if the file is a valid file. Both :attr:`extensions` and
  267. :attr:`is_valid_file` should not be passed. Default: None.
  268. Returns:
  269. :ref:`api_paddle_io_Dataset`. An instance of ImageFolder.
  270. Attributes:
  271. samples (list[str]): List of sample path.
  272. Example:
  273. .. code-block:: python
  274. >>> import shutil
  275. >>> import tempfile
  276. >>> import cv2
  277. >>> import numpy as np
  278. >>> import paddle.vision.transforms as T
  279. >>> from pathlib import Path
  280. >>> from paddle.vision.datasets import ImageFolder
  281. >>> def make_fake_file(img_path: str):
  282. ... if img_path.endswith((".jpg", ".png", ".jpeg")):
  283. ... fake_img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
  284. ... cv2.imwrite(img_path, fake_img)
  285. ... elif img_path.endswith(".txt"):
  286. ... with open(img_path, "w") as f:
  287. ... f.write("This is a fake file.")
  288. >>> def make_directory(root, directory_hierarchy, file_maker=make_fake_file):
  289. ... root = Path(root)
  290. ... root.mkdir(parents=True, exist_ok=True)
  291. ... for subpath in directory_hierarchy:
  292. ... if isinstance(subpath, str):
  293. ... filepath = root / subpath
  294. ... file_maker(str(filepath))
  295. ... else:
  296. ... dirname = list(subpath.keys())[0]
  297. ... make_directory(root / dirname, subpath[dirname])
  298. >>> directory_hierarchy = [
  299. ... "abc.jpg",
  300. ... "def.png",
  301. ... {"ghi": [
  302. ... "jkl.jpeg",
  303. ... {"mno": [
  304. ... "pqr.jpg"]}]},
  305. ... "this_will_be_ignored.txt",
  306. ... ]
  307. >>> # You can replace this with any directory to explore the structure
  308. >>> # of generated data. e.g. fake_data_dir = "./temp_dir"
  309. >>> fake_data_dir = tempfile.mkdtemp()
  310. >>> make_directory(fake_data_dir, directory_hierarchy)
  311. >>> image_folder_1 = ImageFolder(fake_data_dir)
  312. >>> print(image_folder_1.samples)
  313. >>> # doctest: +SKIP(it's different with windows)
  314. ['./temp_dir/abc.jpg', './temp_dir/def.png',
  315. './temp_dir/ghi/jkl.jpeg', './temp_dir/ghi/mno/pqr.jpg']
  316. >>> # doctest: -SKIP
  317. >>> print(len(image_folder_1))
  318. 4
  319. >>> for i in range(len(image_folder_1)):
  320. ... (img,) = image_folder_1[i]
  321. ... # do something with img
  322. ... print(type(img), img.size)
  323. ... # <class 'PIL.Image.Image'> (32, 32)
  324. >>> transform = T.Compose(
  325. ... [
  326. ... T.Resize(64),
  327. ... T.ToTensor(),
  328. ... T.Normalize(
  329. ... mean=[0.5, 0.5, 0.5],
  330. ... std=[0.5, 0.5, 0.5],
  331. ... to_rgb=True,
  332. ... ),
  333. ... ]
  334. ... )
  335. >>> image_folder_2 = ImageFolder(
  336. ... fake_data_dir,
  337. ... loader=lambda x: cv2.imread(x), # load image with OpenCV
  338. ... extensions=(".jpg",), # only load *.jpg files
  339. ... transform=transform, # apply transform to every image
  340. ... )
  341. >>> print(image_folder_2.samples)
  342. >>> # doctest: +SKIP(it's different with windows)
  343. ['./temp_dir/abc.jpg', './temp_dir/ghi/mno/pqr.jpg']
  344. >>> # doctest: -SKIP
  345. >>> print(len(image_folder_2))
  346. 2
  347. >>> for (img,) in iter(image_folder_2):
  348. ... # do something with img
  349. ... print(type(img), img.shape)
  350. ... # <class 'paddle.Tensor'> [3, 64, 64]
  351. >>> shutil.rmtree(fake_data_dir)
  352. """
  353. def __init__(
  354. self,
  355. root,
  356. loader=None,
  357. extensions=None,
  358. transform=None,
  359. is_valid_file=None,
  360. ):
  361. self.root = root
  362. if extensions is None:
  363. extensions = IMG_EXTENSIONS
  364. samples = []
  365. path = os.path.expanduser(root)
  366. if extensions is not None:
  367. def is_valid_file(x):
  368. return has_valid_extension(x, extensions)
  369. for root, _, fnames in sorted(os.walk(path, followlinks=True)):
  370. for fname in sorted(fnames):
  371. f = os.path.join(root, fname)
  372. if is_valid_file(f):
  373. samples.append(f)
  374. if len(samples) == 0:
  375. raise (
  376. RuntimeError(
  377. "Found 0 files in subfolders of: " + self.root + "\n"
  378. "Supported extensions are: " + ",".join(extensions)
  379. )
  380. )
  381. self.loader = default_loader if loader is None else loader
  382. self.extensions = extensions
  383. self.samples = samples
  384. self.transform = transform
  385. def __getitem__(self, index):
  386. """
  387. Args:
  388. index (int): Index
  389. Returns:
  390. sample of specific index.
  391. """
  392. path = self.samples[index]
  393. sample = self.loader(path)
  394. if self.transform is not None:
  395. sample = self.transform(sample)
  396. return [sample]
  397. def __len__(self):
  398. return len(self.samples)