# dataset.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import bisect
import itertools
import math
import warnings
from typing import Iterable

import paddle

from ... import framework
  20. class Dataset:
  21. """
  22. An abstract class to encapsulate methods and behaviors of datasets.
  23. All datasets in map-style(dataset samples can be get by a given key)
  24. should be a subclass of `paddle.io.Dataset`. All subclasses should
  25. implement following methods:
  26. :code:`__getitem__`: get sample from dataset with a given index. This
  27. method is required by reading dataset sample in :code:`paddle.io.DataLoader`.
  28. :code:`__len__`: return dataset sample number. This method is required
  29. by some implements of :code:`paddle.io.BatchSampler`
  30. see :code:`paddle.io.DataLoader`.
  31. Examples:
  32. .. code-block:: python
  33. >>> import numpy as np
  34. >>> from paddle.io import Dataset
  35. >>> # define a random dataset
  36. >>> class RandomDataset(Dataset):
  37. ... def __init__(self, num_samples):
  38. ... self.num_samples = num_samples
  39. ...
  40. ... def __getitem__(self, idx):
  41. ... image = np.random.random([784]).astype('float32')
  42. ... label = np.random.randint(0, 9, (1, )).astype('int64')
  43. ... return image, label
  44. ...
  45. ... def __len__(self):
  46. ... return self.num_samples
  47. ...
  48. >>> dataset = RandomDataset(10)
  49. >>> for i in range(len(dataset)):
  50. ... image, label = dataset[i]
  51. ... # do something
  52. """
  53. def __init__(self):
  54. pass
  55. def __getitem__(self, idx):
  56. raise NotImplementedError(
  57. "'{}' not implement in class "
  58. "{}".format('__getitem__', self.__class__.__name__)
  59. )
  60. def __len__(self):
  61. raise NotImplementedError(
  62. "'{}' not implement in class "
  63. "{}".format('__len__', self.__class__.__name__)
  64. )
  65. class IterableDataset(Dataset):
  66. """
  67. An abstract class to encapsulate methods and behaviors of iterable datasets.
  68. All datasets in iterable-style (can only get sample one by one sequentially, like
  69. a Python iterator) should be a subclass of :ref:`api_paddle_io_IterableDataset` . All subclasses should
  70. implement following methods:
  71. :code:`__iter__`: yield sample sequentially. This method is required by reading dataset sample in :ref:`api_paddle_io_DataLoader` .
  72. .. note::
  73. do not implement :code:`__getitem__` and :code:`__len__` in IterableDataset, should not be called either.
  74. see :ref:`api_paddle_io_DataLoader` .
  75. Examples:
  76. .. code-block:: python
  77. :name: code-example1
  78. >>> import numpy as np
  79. >>> from paddle.io import IterableDataset
  80. >>> # define a random dataset
  81. >>> class RandomDataset(IterableDataset):
  82. ... def __init__(self, num_samples):
  83. ... self.num_samples = num_samples
  84. ...
  85. ... def __iter__(self):
  86. ... for i in range(self.num_samples):
  87. ... image = np.random.random([784]).astype('float32')
  88. ... label = np.random.randint(0, 9, (1, )).astype('int64')
  89. ... yield image, label
  90. ...
  91. >>> dataset = RandomDataset(10)
  92. >>> for img, label in dataset:
  93. ... # do something
  94. ... ...
  95. When :attr:`num_workers > 0`, each worker has a different copy of the dataset object and
  96. will yield whole dataset samples, which means samples in dataset will be repeated in
  97. :attr:`num_workers` times. If it is required for each sample to yield only once, there
  98. are two methods to configure different copy in each worker process to avoid duplicate data
  99. among workers as follows. In both the methods, worker information that can be getted in
  100. a worker process by `paddle.io.get_worker_info` will be needed.
  101. splitting data copy in each worker in :code:`__iter__`
  102. .. code-block:: python
  103. :name: code-example2
  104. >>> import math
  105. >>> import paddle
  106. >>> import numpy as np
  107. >>> from paddle.io import IterableDataset, DataLoader, get_worker_info
  108. >>> class SplitedIterableDataset(IterableDataset):
  109. ... def __init__(self, start, end):
  110. ... self.start = start
  111. ... self.end = end
  112. ...
  113. ... def __iter__(self):
  114. ... worker_info = get_worker_info()
  115. ... if worker_info is None:
  116. ... iter_start = self.start
  117. ... iter_end = self.end
  118. ... else:
  119. ... per_worker = int(
  120. ... math.ceil((self.end - self.start) / float(
  121. ... worker_info.num_workers)))
  122. ... worker_id = worker_info.id
  123. ... iter_start = self.start + worker_id * per_worker
  124. ... iter_end = min(iter_start + per_worker, self.end)
  125. ...
  126. ... for i in range(iter_start, iter_end):
  127. ... yield np.array([i])
  128. ...
  129. >>> dataset = SplitedIterableDataset(start=2, end=9)
  130. >>> dataloader = DataLoader(
  131. ... dataset,
  132. ... num_workers=2,
  133. ... batch_size=1,
  134. ... drop_last=True)
  135. ...
  136. >>> for data in dataloader:
  137. ... print(data) # doctest: +SKIP("The output depends on the environment.")
  138. Tensor(shape=[1, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
  139. [[2]])
  140. Tensor(shape=[1, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
  141. [[3]])
  142. Tensor(shape=[1, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
  143. [[4]])
  144. Tensor(shape=[1, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
  145. [[5]])
  146. Tensor(shape=[1, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
  147. [[6]])
  148. Tensor(shape=[1, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
  149. [[7]])
  150. Tensor(shape=[1, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
  151. [[8]])
  152. splitting data copy in each worker by :code:`worker_init_fn`
  153. .. code-block:: python
  154. :name: code-example3
  155. >>> import math
  156. >>> import paddle
  157. >>> import numpy as np
  158. >>> from paddle.io import IterableDataset, DataLoader, get_worker_info
  159. >>> class RangeIterableDataset(IterableDataset):
  160. ... def __init__(self, start, end):
  161. ... self.start = start
  162. ... self.end = end
  163. ...
  164. ... def __iter__(self):
  165. ... for i in range(self.start, self.end):
  166. ... yield np.array([i])
  167. ...
  168. >>> dataset = RangeIterableDataset(start=2, end=9)
  169. >>> def worker_init_fn(worker_id):
  170. ... worker_info = get_worker_info()
  171. ...
  172. ... dataset = worker_info.dataset
  173. ... start = dataset.start
  174. ... end = dataset.end
  175. ... num_per_worker = int(
  176. ... math.ceil((end - start) / float(worker_info.num_workers)))
  177. ...
  178. ... worker_id = worker_info.id
  179. ... dataset.start = start + worker_id * num_per_worker
  180. ... dataset.end = min(dataset.start + num_per_worker, end)
  181. ...
  182. >>> dataloader = DataLoader(
  183. ... dataset,
  184. ... num_workers=2,
  185. ... batch_size=1,
  186. ... drop_last=True,
  187. ... worker_init_fn=worker_init_fn)
  188. ...
  189. >>> for data in dataloader:
  190. ... print(data) # doctest: +SKIP("The output depends on the environment.")
  191. Tensor(shape=[1, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
  192. [[2]])
  193. Tensor(shape=[1, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
  194. [[3]])
  195. Tensor(shape=[1, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
  196. [[4]])
  197. Tensor(shape=[1, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
  198. [[5]])
  199. Tensor(shape=[1, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
  200. [[6]])
  201. Tensor(shape=[1, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
  202. [[7]])
  203. Tensor(shape=[1, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
  204. [[8]])
  205. """
  206. def __init__(self):
  207. pass
  208. def __iter__(self):
  209. raise NotImplementedError(
  210. "'{}' not implement in class "
  211. "{}".format('__iter__', self.__class__.__name__)
  212. )
  213. def __getitem__(self, idx):
  214. raise RuntimeError(
  215. "'{}' should not be called for IterableDataset"
  216. "{}".format('__getitem__', self.__class__.__name__)
  217. )
  218. def __len__(self):
  219. raise RuntimeError(
  220. "'{}' should not be called for IterableDataset"
  221. "{}".format('__len__', self.__class__.__name__)
  222. )
  223. class TensorDataset(Dataset):
  224. """
  225. Dataset defined by a list of tensors.
  226. Each tensor should be in shape of [N, ...], while N is the sample number,
  227. and each tensor contains a field of sample, :code:`TensorDataset` retrieve
  228. each sample by indexing tensors in the 1st dimension.
  229. Args:
  230. tensors(list|tuple): A list/tuple of tensors with same shape in the 1st dimension.
  231. Returns:
  232. Dataset: a Dataset instance wrapping tensors.
  233. Examples:
  234. .. code-block:: python
  235. >>> import numpy as np
  236. >>> import paddle
  237. >>> from paddle.io import TensorDataset
  238. >>> input_np = np.random.random([2, 3, 4]).astype('float32')
  239. >>> input = paddle.to_tensor(input_np)
  240. >>> label_np = np.random.random([2, 1]).astype('int32')
  241. >>> label = paddle.to_tensor(label_np)
  242. >>> dataset = TensorDataset([input, label])
  243. >>> for i in range(len(dataset)):
  244. ... input, label = dataset[i]
  245. ... # do something
  246. """
  247. def __init__(self, tensors):
  248. if not framework.in_dynamic_mode():
  249. raise RuntimeError(
  250. "TensorDataset con only be used in imperative mode"
  251. )
  252. assert all(
  253. tensor.shape[0] == tensors[0].shape[0] for tensor in tensors
  254. ), "tensors not have same shape of the 1st dimension"
  255. self.tensors = tensors
  256. def __getitem__(self, index):
  257. return tuple(tensor[index] for tensor in self.tensors)
  258. def __len__(self):
  259. return self.tensors[0].shape[0]
  260. def to_list(value):
  261. if value is None:
  262. return value
  263. if isinstance(value, (list, tuple)):
  264. return list(value)
  265. return [value]
  266. class ComposeDataset(Dataset):
  267. """
  268. A Dataset which composes fields of multiple datasets.
  269. This dataset is used for composing fields of multiple map-style
  270. datasets of same length.
  271. Args:
  272. datasets(list of Dataset): List of datasets to be composed.
  273. Returns:
  274. Dataset: A Dataset which composes fields of multiple datasets.
  275. Examples:
  276. .. code-block:: python
  277. >>> import numpy as np
  278. >>> import paddle
  279. >>> from paddle.io import Dataset, ComposeDataset
  280. >>> # define a random dataset
  281. >>> class RandomDataset(Dataset):
  282. ... def __init__(self, num_samples):
  283. ... self.num_samples = num_samples
  284. ...
  285. ... def __getitem__(self, idx):
  286. ... image = np.random.random([32]).astype('float32')
  287. ... label = np.random.randint(0, 9, (1, )).astype('int64')
  288. ... return image, label
  289. ...
  290. ... def __len__(self):
  291. ... return self.num_samples
  292. ...
  293. >>> dataset = ComposeDataset([RandomDataset(10), RandomDataset(10)])
  294. >>> for i in range(len(dataset)):
  295. ... image1, label1, image2, label2 = dataset[i]
  296. ... # do something
  297. """
  298. def __init__(self, datasets):
  299. self.datasets = list(datasets)
  300. assert len(self.datasets) > 0, "input datasets should not be empty"
  301. for i, dataset in enumerate(self.datasets):
  302. assert isinstance(
  303. dataset, Dataset
  304. ), "each input dataset should be paddle.io.Dataset"
  305. assert not isinstance(
  306. dataset, IterableDataset
  307. ), "paddle.io.IterableDataset not supported"
  308. if i > 0:
  309. assert len(dataset) == len(
  310. self.datasets[i - 1]
  311. ), "lengths of datasets should be same"
  312. def __len__(self):
  313. return len(self.datasets[0])
  314. def __getitem__(self, idx):
  315. sample = []
  316. for dataset in self.datasets:
  317. sample.extend(to_list(dataset[idx]))
  318. return tuple(sample)
  319. class ChainDataset(IterableDataset):
  320. """
  321. A Dataset which chains multiple iterable-style datasets.
  322. This dataset is used for assembling multiple datasets which should
  323. be :ref:`api_paddle_io_IterableDataset`.
  324. Args:
  325. datasets(list of IterableDatasets): List of datasets to be chainned.
  326. Returns:
  327. paddle.io.IterableDataset: A Dataset which chains fields of multiple datasets.
  328. Examples:
  329. .. code-block:: python
  330. >>> import numpy as np
  331. >>> import paddle
  332. >>> from paddle.io import IterableDataset, ChainDataset
  333. >>> # define a random dataset
  334. >>> class RandomDataset(IterableDataset):
  335. ... def __init__(self, num_samples):
  336. ... self.num_samples = num_samples
  337. ...
  338. ... def __iter__(self):
  339. ... for i in range(10):
  340. ... image = np.random.random([32]).astype('float32')
  341. ... label = np.random.randint(0, 9, (1, )).astype('int64')
  342. ... yield image, label
  343. ...
  344. >>> dataset = ChainDataset([RandomDataset(10), RandomDataset(10)])
  345. >>> for image, label in iter(dataset):
  346. ... # do something
  347. ... ...
  348. """
  349. def __init__(self, datasets):
  350. self.datasets = list(datasets)
  351. assert len(self.datasets) > 0, "input datasets should not be empty"
  352. for i, dataset in enumerate(self.datasets):
  353. assert isinstance(
  354. dataset, IterableDataset
  355. ), "ChainDataset only support paddle.io.IterableDataset"
  356. def __iter__(self):
  357. for dataset in self.datasets:
  358. yield from dataset
  359. class Subset(Dataset):
  360. """
  361. Subset of a dataset at specified indices.
  362. Args:
  363. dataset (Dataset): The whole Dataset.
  364. indices (sequence): Indices in the whole set selected for subset.
  365. Returns:
  366. List[Dataset]: A Dataset which is the subset of the original dataset.
  367. Examples:
  368. .. code-block:: python
  369. >>> import paddle
  370. >>> from paddle.io import Subset
  371. >>> # example 1:
  372. >>> a = paddle.io.Subset(dataset=range(1, 4), indices=[0, 2])
  373. >>> print(list(a))
  374. [1, 3]
  375. >>> # example 2:
  376. >>> b = paddle.io.Subset(dataset=range(1, 4), indices=[1, 1])
  377. >>> print(list(b))
  378. [2, 2]
  379. """
  380. def __init__(self, dataset, indices):
  381. self.dataset = dataset
  382. self.indices = indices
  383. def __getitem__(self, idx):
  384. return self.dataset[self.indices[idx]]
  385. def __len__(self):
  386. return len(self.indices)
  387. def random_split(dataset, lengths, generator=None):
  388. """
  389. Randomly split a dataset into non-overlapping new datasets of given lengths.
  390. Optionally fix the generator for reproducible results, e.g.:
  391. Args:
  392. dataset (Dataset): Dataset to be split
  393. lengths (sequence): lengths or fractions of splits to be produced
  394. generator (Generator, optional): Generator used for the random permutation. Default is None then the DefaultGenerator is used in manual_seed().
  395. Returns:
  396. Datasets: A list of subset Datasets, which are the non-overlapping subsets of the original Dataset.
  397. Examples:
  398. .. code-block:: python
  399. >>> import paddle
  400. >>> paddle.seed(2023)
  401. >>> a_list = paddle.io.random_split(range(10), [3, 7])
  402. >>> print(len(a_list))
  403. 2
  404. >>> # output of the first subset
  405. >>> for idx, v in enumerate(a_list[0]):
  406. ... print(idx, v) # doctest: +SKIP("The output depends on the environment.")
  407. 0 7
  408. 1 6
  409. 2 5
  410. >>> # output of the second subset
  411. >>> for idx, v in enumerate(a_list[1]):
  412. ... print(idx, v) # doctest: +SKIP("The output depends on the environment.")
  413. 0 1
  414. 1 9
  415. 2 4
  416. 3 2
  417. 4 0
  418. 5 3
  419. 6 8
  420. """
  421. if math.isclose(sum(lengths), 1) and sum(lengths) <= 1:
  422. subset_lengths = []
  423. for i, frac in enumerate(lengths):
  424. if frac < 0 or frac > 1:
  425. raise ValueError(
  426. f"Fraction at index {i} is not between 0 and 1"
  427. )
  428. n_items_in_split = int(math.floor(len(dataset) * frac))
  429. subset_lengths.append(n_items_in_split)
  430. remainder = len(dataset) - sum(subset_lengths)
  431. for i in range(remainder):
  432. idx_to_add_at = i % len(subset_lengths)
  433. subset_lengths[idx_to_add_at] += 1
  434. lengths = subset_lengths
  435. for i, length in enumerate(lengths):
  436. if length == 0:
  437. warnings.warn(
  438. f"Length of split at index {i} is 0. "
  439. f"This might result in an empty dataset."
  440. )
  441. # Cannot verify that dataset is Sized
  442. if sum(lengths) != len(dataset): # type: ignore
  443. raise ValueError(
  444. "Sum of input lengths does not equal the length of the input dataset!"
  445. )
  446. # TODO(@Joejiong): support Variable or Tensor type with .tolist class member function.
  447. # For example var.item() and var.tolist()
  448. indices = paddle.randperm(sum(lengths)).tolist()
  449. return [
  450. Subset(dataset, indices[offset - length : offset])
  451. for offset, length in zip(_accumulate(lengths), lengths)
  452. ]
  453. def _accumulate(iterable, fn=lambda x, y: x + y):
  454. """
  455. Return running totals
  456. Args:
  457. iterable: any iterable object for example dataset.
  458. y (x): one element in the iterable object.
  459. fn (x, y): Defaults to lambdax.
  460. Yields:
  461. yields total from beginning iterator to current iterator.
  462. Example code:
  463. .. code-block:: python
  464. >>> list(_accumulate([1, 2, 3, 4, 5]))
  465. [1, 3, 6, 10, 15]
  466. >>> import operator
  467. >>> list(_accumulate([1, 2, 3, 4, 5], operator.mul))
  468. [1, 2, 6, 24, 120]
  469. """
  470. it = iter(iterable)
  471. try:
  472. total = next(it)
  473. except StopIteration:
  474. return
  475. yield total
  476. for element in it:
  477. total = fn(total, element)
  478. yield total
  479. class ConcatDataset(Dataset):
  480. """
  481. Dataset as a concatenation of multiple datasets.
  482. This class is useful to assemble different existing datasets.
  483. Args:
  484. datasets (sequence): List of datasets to be concatenated
  485. Returns:
  486. Dataset: A Dataset which concatenated by multiple datasets.
  487. Examples:
  488. .. code-block:: python
  489. >>> import numpy as np
  490. >>> import paddle
  491. >>> from paddle.io import Dataset, ConcatDataset
  492. >>> # define a random dataset
  493. >>> class RandomDataset(Dataset):
  494. ... def __init__(self, num_samples):
  495. ... self.num_samples = num_samples
  496. ...
  497. ... def __getitem__(self, idx):
  498. ... image = np.random.random([32]).astype('float32')
  499. ... label = np.random.randint(0, 9, (1, )).astype('int64')
  500. ... return image, label
  501. ...
  502. ... def __len__(self):
  503. ... return self.num_samples
  504. ...
  505. >>> dataset = ConcatDataset([RandomDataset(10), RandomDataset(10)])
  506. >>> for i in range(len(dataset)):
  507. ... image, label = dataset[i]
  508. ... # do something
  509. """
  510. @staticmethod
  511. def cumsum(sequence):
  512. r, s = [], 0
  513. for e in sequence:
  514. l = len(e)
  515. r.append(l + s)
  516. s += l
  517. return r
  518. def __init__(self, datasets: Iterable[Dataset]):
  519. self.datasets = list(datasets)
  520. assert (
  521. len(self.datasets) > 0
  522. ), 'datasets should not be an empty iterable'
  523. for d in self.datasets:
  524. assert not isinstance(
  525. d, IterableDataset
  526. ), "ConcatDataset does not support IterableDataset"
  527. self.cumulative_sizes = self.cumsum(self.datasets)
  528. def __len__(self):
  529. return self.cumulative_sizes[-1]
  530. def __getitem__(self, idx):
  531. if idx < 0:
  532. if -idx > len(self):
  533. raise ValueError(
  534. "absolute value of index should not exceed dataset length"
  535. )
  536. idx = len(self) + idx
  537. dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
  538. if dataset_idx == 0:
  539. sample_idx = idx
  540. else:
  541. sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
  542. return self.datasets[dataset_idx][sample_idx]