sampler.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367
  1. # mypy: allow-untyped-defs
  2. import itertools
  3. from collections.abc import Iterable, Iterator, Sequence, Sized
  4. from typing import Generic, Optional, TypeVar, Union
  5. import torch
  6. # Note: For benchmarking changes to samplers, see:
  7. # /benchmarks/data/samplers_bench.py
  8. # This benchmark compares the performance of different sampler implementations
  9. # and can be used to evaluate the impact of optimizations.
# Public names of this module: the sampler classes defined below. This list is
# what ``from <module> import *`` exports.
__all__ = [
    "BatchSampler",
    "RandomSampler",
    "Sampler",
    "SequentialSampler",
    "SubsetRandomSampler",
    "WeightedRandomSampler",
]
  18. _T_co = TypeVar("_T_co", covariant=True)
  19. class Sampler(Generic[_T_co]):
  20. r"""Base class for all Samplers.
  21. Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
  22. way to iterate over indices or lists of indices (batches) of dataset elements,
  23. and may provide a :meth:`__len__` method that returns the length of the returned iterators.
  24. Args:
  25. data_source (Dataset): This argument is not used and will be removed in 2.2.0.
  26. You may still have custom implementation that utilizes it.
  27. Example:
  28. >>> # xdoctest: +SKIP
  29. >>> class AccedingSequenceLengthSampler(Sampler[int]):
  30. >>> def __init__(self, data: List[str]) -> None:
  31. >>> self.data = data
  32. >>>
  33. >>> def __len__(self) -> int:
  34. >>> return len(self.data)
  35. >>>
  36. >>> def __iter__(self) -> Iterator[int]:
  37. >>> sizes = torch.tensor([len(x) for x in self.data])
  38. >>> yield from torch.argsort(sizes).tolist()
  39. >>>
  40. >>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]):
  41. >>> def __init__(self, data: List[str], batch_size: int) -> None:
  42. >>> self.data = data
  43. >>> self.batch_size = batch_size
  44. >>>
  45. >>> def __len__(self) -> int:
  46. >>> return (len(self.data) + self.batch_size - 1) // self.batch_size
  47. >>>
  48. >>> def __iter__(self) -> Iterator[List[int]]:
  49. >>> sizes = torch.tensor([len(x) for x in self.data])
  50. >>> for batch in torch.chunk(torch.argsort(sizes), len(self)):
  51. >>> yield batch.tolist()
  52. .. note:: The :meth:`__len__` method isn't strictly required by
  53. :class:`~torch.utils.data.DataLoader`, but is expected in any
  54. calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
  55. """
  56. def __init__(self, data_source: Optional[Sized] = None) -> None:
  57. if data_source is not None:
  58. import warnings
  59. warnings.warn(
  60. "`data_source` argument is not used and will be removed in 2.2.0."
  61. "You may still have custom implementation that utilizes it."
  62. )
  63. def __iter__(self) -> Iterator[_T_co]:
  64. raise NotImplementedError
  65. # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
  66. #
  67. # Many times we have an abstract class representing a collection/iterable of
  68. # data, e.g., `torch.utils.data.Sampler`, with its subclasses optionally
  69. # implementing a `__len__` method. In such cases, we must make sure to not
  70. # provide a default implementation, because both straightforward default
  71. # implementations have their issues:
  72. #
  73. # + `return NotImplemented`:
  74. # Calling `len(subclass_instance)` raises:
  75. # TypeError: 'NotImplementedType' object cannot be interpreted as an integer
  76. #
  77. # + `raise NotImplementedError`:
  78. # This prevents triggering some fallback behavior. E.g., the built-in
  79. # `list(X)` tries to call `len(X)` first, and executes a different code
  80. # path if the method is not found or `NotImplemented` is returned, while
  81. # raising a `NotImplementedError` will propagate and make the call fail
  82. # where it could have used `__iter__` to complete the call.
  83. #
  84. # Thus, the only two sensible things to do are
  85. #
  86. # + **not** provide a default `__len__`.
  87. #
  88. # + raise a `TypeError` instead, which is what Python uses when users call
  89. # a method that is not defined on an object.
  90. # (@ssnl verifies that this works on at least Python 3.7.)
  91. class SequentialSampler(Sampler[int]):
  92. r"""Samples elements sequentially, always in the same order.
  93. Args:
  94. data_source (Dataset): dataset to sample from
  95. """
  96. data_source: Sized
  97. def __init__(self, data_source: Sized) -> None:
  98. self.data_source = data_source
  99. def __iter__(self) -> Iterator[int]:
  100. return iter(range(len(self.data_source)))
  101. def __len__(self) -> int:
  102. return len(self.data_source)
  103. class RandomSampler(Sampler[int]):
  104. r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
  105. If with replacement, then user can specify :attr:`num_samples` to draw.
  106. Args:
  107. data_source (Dataset): dataset to sample from
  108. replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
  109. num_samples (int): number of samples to draw, default=`len(dataset)`.
  110. generator (Generator): Generator used in sampling.
  111. """
  112. data_source: Sized
  113. replacement: bool
  114. def __init__(
  115. self,
  116. data_source: Sized,
  117. replacement: bool = False,
  118. num_samples: Optional[int] = None,
  119. generator=None,
  120. ) -> None:
  121. self.data_source = data_source
  122. self.replacement = replacement
  123. self._num_samples = num_samples
  124. self.generator = generator
  125. if not isinstance(self.replacement, bool):
  126. raise TypeError(
  127. f"replacement should be a boolean value, but got replacement={self.replacement}"
  128. )
  129. if not isinstance(self.num_samples, int) or self.num_samples <= 0:
  130. raise ValueError(
  131. f"num_samples should be a positive integer value, but got num_samples={self.num_samples}"
  132. )
  133. @property
  134. def num_samples(self) -> int:
  135. # dataset size might change at runtime
  136. if self._num_samples is None:
  137. return len(self.data_source)
  138. return self._num_samples
  139. def __iter__(self) -> Iterator[int]:
  140. n = len(self.data_source)
  141. if self.generator is None:
  142. seed = int(torch.empty((), dtype=torch.int64).random_().item())
  143. generator = torch.Generator()
  144. generator.manual_seed(seed)
  145. else:
  146. generator = self.generator
  147. if self.replacement:
  148. for _ in range(self.num_samples // 32):
  149. yield from torch.randint(
  150. high=n, size=(32,), dtype=torch.int64, generator=generator
  151. ).tolist()
  152. yield from torch.randint(
  153. high=n,
  154. size=(self.num_samples % 32,),
  155. dtype=torch.int64,
  156. generator=generator,
  157. ).tolist()
  158. else:
  159. for _ in range(self.num_samples // n):
  160. yield from torch.randperm(n, generator=generator).tolist()
  161. yield from torch.randperm(n, generator=generator).tolist()[
  162. : self.num_samples % n
  163. ]
  164. def __len__(self) -> int:
  165. return self.num_samples
  166. class SubsetRandomSampler(Sampler[int]):
  167. r"""Samples elements randomly from a given list of indices, without replacement.
  168. Args:
  169. indices (sequence): a sequence of indices
  170. generator (Generator): Generator used in sampling.
  171. """
  172. indices: Sequence[int]
  173. def __init__(self, indices: Sequence[int], generator=None) -> None:
  174. self.indices = indices
  175. self.generator = generator
  176. def __iter__(self) -> Iterator[int]:
  177. for i in torch.randperm(len(self.indices), generator=self.generator).tolist():
  178. yield self.indices[i]
  179. def __len__(self) -> int:
  180. return len(self.indices)
  181. class WeightedRandomSampler(Sampler[int]):
  182. r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).
  183. Args:
  184. weights (sequence) : a sequence of weights, not necessary summing up to one
  185. num_samples (int): number of samples to draw
  186. replacement (bool): if ``True``, samples are drawn with replacement.
  187. If not, they are drawn without replacement, which means that when a
  188. sample index is drawn for a row, it cannot be drawn again for that row.
  189. generator (Generator): Generator used in sampling.
  190. Example:
  191. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  192. >>> list(
  193. ... WeightedRandomSampler(
  194. ... [0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True
  195. ... )
  196. ... )
  197. [4, 4, 1, 4, 5]
  198. >>> list(
  199. ... WeightedRandomSampler(
  200. ... [0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False
  201. ... )
  202. ... )
  203. [0, 1, 4, 3, 2]
  204. """
  205. weights: torch.Tensor
  206. num_samples: int
  207. replacement: bool
  208. def __init__(
  209. self,
  210. weights: Sequence[float],
  211. num_samples: int,
  212. replacement: bool = True,
  213. generator=None,
  214. ) -> None:
  215. if (
  216. not isinstance(num_samples, int)
  217. or isinstance(num_samples, bool)
  218. or num_samples <= 0
  219. ):
  220. raise ValueError(
  221. f"num_samples should be a positive integer value, but got num_samples={num_samples}"
  222. )
  223. if not isinstance(replacement, bool):
  224. raise ValueError(
  225. f"replacement should be a boolean value, but got replacement={replacement}"
  226. )
  227. weights_tensor = torch.as_tensor(weights, dtype=torch.double)
  228. if len(weights_tensor.shape) != 1:
  229. raise ValueError(
  230. "weights should be a 1d sequence but given "
  231. f"weights have shape {tuple(weights_tensor.shape)}"
  232. )
  233. self.weights = weights_tensor
  234. self.num_samples = num_samples
  235. self.replacement = replacement
  236. self.generator = generator
  237. def __iter__(self) -> Iterator[int]:
  238. rand_tensor = torch.multinomial(
  239. self.weights, self.num_samples, self.replacement, generator=self.generator
  240. )
  241. yield from iter(rand_tensor.tolist())
  242. def __len__(self) -> int:
  243. return self.num_samples
  244. class BatchSampler(Sampler[list[int]]):
  245. r"""Wraps another sampler to yield a mini-batch of indices.
  246. Args:
  247. sampler (Sampler or Iterable): Base sampler. Can be any iterable object
  248. batch_size (int): Size of mini-batch.
  249. drop_last (bool): If ``True``, the sampler will drop the last batch if
  250. its size would be less than ``batch_size``
  251. Example:
  252. >>> list(
  253. ... BatchSampler(
  254. ... SequentialSampler(range(10)), batch_size=3, drop_last=False
  255. ... )
  256. ... )
  257. [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
  258. >>> list(
  259. ... BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)
  260. ... )
  261. [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
  262. """
  263. def __init__(
  264. self,
  265. sampler: Union[Sampler[int], Iterable[int]],
  266. batch_size: int,
  267. drop_last: bool,
  268. ) -> None:
  269. # Since collections.abc.Iterable does not check for `__getitem__`, which
  270. # is one way for an object to be an iterable, we don't do an `isinstance`
  271. # check here.
  272. if (
  273. not isinstance(batch_size, int)
  274. or isinstance(batch_size, bool)
  275. or batch_size <= 0
  276. ):
  277. raise ValueError(
  278. f"batch_size should be a positive integer value, but got batch_size={batch_size}"
  279. )
  280. if not isinstance(drop_last, bool):
  281. raise ValueError(
  282. f"drop_last should be a boolean value, but got drop_last={drop_last}"
  283. )
  284. self.sampler = sampler
  285. self.batch_size = batch_size
  286. self.drop_last = drop_last
  287. def __iter__(self) -> Iterator[list[int]]:
  288. sampler_iter = iter(self.sampler)
  289. if self.drop_last:
  290. # Create multiple references to the same iterator
  291. args = [sampler_iter] * self.batch_size
  292. for batch_droplast in zip(*args):
  293. yield [*batch_droplast]
  294. else:
  295. batch = [*itertools.islice(sampler_iter, self.batch_size)]
  296. while batch:
  297. yield batch
  298. batch = [*itertools.islice(sampler_iter, self.batch_size)]
  299. def __len__(self) -> int:
  300. # Can only be called if self.sampler has __len__ implemented
  301. # We cannot enforce this condition, so we turn off typechecking for the
  302. # implementation below.
  303. # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
  304. if self.drop_last:
  305. return len(self.sampler) // self.batch_size # type: ignore[arg-type]
  306. else:
  307. return (len(self.sampler) + self.batch_size - 1) // self.batch_size # type: ignore[arg-type]