datapipe.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
  1. import functools
  2. import pickle
  3. from collections.abc import Iterable, Iterator
  4. from typing import Callable, Optional, TypeVar
  5. from torch.utils._import_utils import import_dill
  6. from torch.utils.data.datapipes._hook_iterator import _SnapshotState
  7. from torch.utils.data.datapipes._typing import _DataPipeMeta, _IterDataPipeMeta
  8. from torch.utils.data.datapipes.utils.common import (
  9. _deprecation_warning,
  10. _iter_deprecated_functional_names,
  11. _map_deprecated_functional_names,
  12. )
  13. from torch.utils.data.dataset import Dataset, IterableDataset
# `dill` holds whatever `import_dill()` returns. It is compared against None
# right below, so it may be None (presumably when dill is unavailable —
# confirm against `import_dill`).
dill = import_dill()
HAS_DILL = dill is not None

# Public names exported by this module.
__all__ = [
    "DataChunk",
    "DFIterDataPipe",
    "IterDataPipe",
    "MapDataPipe",
]

# Type variables used by the generic classes defined in this module.
_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)

# Functional names whose results are never passed through
# `trace_as_dataframe()` by `IterDataPipe.register_datapipe_as_function`.
UNTRACABLE_DATAFRAME_PIPES = [
    "batch",  # As it returns DataChunks
    "groupby",  # As it returns DataChunks
    "_dataframes_as_tuples",  # As it unpacks DF
    "trace_as_dataframe",  # As it is used to mark DF for tracing
]
  30. class DataChunk(list[_T]):
  31. def __init__(self, items: Iterable[_T]) -> None:
  32. items = list(items)
  33. super().__init__(items)
  34. self.items = items
  35. def as_str(self, indent: str = "") -> str:
  36. return indent + "[" + ", ".join(str(i) for i in iter(self)) + "]"
  37. def __iter__(self) -> Iterator[_T]:
  38. yield from super().__iter__()
  39. def raw_iterator(self) -> Iterator[_T]:
  40. yield from self.items
  41. class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta):
  42. r"""
  43. Iterable-style DataPipe.
  44. All DataPipes that represent an iterable of data samples should subclass this.
  45. This style of DataPipes is particularly useful when data come from a stream, or
  46. when the number of samples is too large to fit them all in memory. ``IterDataPipe`` is lazily initialized and its
  47. elements are computed only when ``next()`` is called on the iterator of an ``IterDataPipe``.
  48. All subclasses should overwrite :meth:`__iter__`, which would return an
  49. iterator of samples in this DataPipe. Calling ``__iter__`` of an ``IterDataPipe`` automatically invokes its
  50. method ``reset()``, which by default performs no operation. When writing a custom ``IterDataPipe``, users should
  51. override ``reset()`` if necessary. The common usages include resetting buffers, pointers,
  52. and various state variables within the custom ``IterDataPipe``.
  53. Note:
  54. Only `one` iterator can be valid for each ``IterDataPipe`` at a time,
  55. and the creation a second iterator will invalidate the first one. This constraint is necessary because
  56. some ``IterDataPipe`` have internal buffers, whose states can become invalid if there are multiple iterators.
  57. The code example below presents details on how this constraint looks in practice.
  58. If you have any feedback related to this constraint, please see `GitHub IterDataPipe Single Iterator Issue`_.
  59. These DataPipes can be invoked in two ways, using the class constructor or applying their
  60. functional form onto an existing ``IterDataPipe`` (recommended, available to most but not all DataPipes).
  61. You can chain multiple `IterDataPipe` together to form a pipeline that will perform multiple
  62. operations in succession.
  63. .. _GitHub IterDataPipe Single Iterator Issue:
  64. https://github.com/pytorch/data/issues/45
  65. Note:
  66. When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
  67. item in the DataPipe will be yielded from the :class:`~torch.utils.data.DataLoader`
  68. iterator. When :attr:`num_workers > 0`, each worker process will have a
  69. different copy of the DataPipe object, so it is often desired to configure
  70. each copy independently to avoid having duplicate data returned from the
  71. workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
  72. process, returns information about the worker. It can be used in either the
  73. dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
  74. :attr:`worker_init_fn` option to modify each copy's behavior.
  75. Examples:
  76. General Usage:
  77. >>> # xdoctest: +SKIP
  78. >>> from torchdata.datapipes.iter import IterableWrapper, Mapper
  79. >>> dp = IterableWrapper(range(10))
  80. >>> map_dp_1 = Mapper(dp, lambda x: x + 1) # Using class constructor
  81. >>> map_dp_2 = dp.map(
  82. ... lambda x: x + 1
  83. ... ) # Using functional form (recommended)
  84. >>> list(map_dp_1)
  85. [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
  86. >>> list(map_dp_2)
  87. [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
  88. >>> filter_dp = map_dp_1.filter(lambda x: x % 2 == 0)
  89. >>> list(filter_dp)
  90. [2, 4, 6, 8, 10]
  91. Single Iterator Constraint Example:
  92. >>> from torchdata.datapipes.iter import IterableWrapper, Mapper
  93. >>> source_dp = IterableWrapper(range(10))
  94. >>> it1 = iter(source_dp)
  95. >>> list(it1)
  96. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
  97. >>> it1 = iter(source_dp)
  98. >>> it2 = iter(
  99. ... source_dp
  100. ... ) # The creation of a new iterator invalidates `it1`
  101. >>> next(it2)
  102. 0
  103. >>> next(it1) # Further usage of `it1` will raise a `RunTimeError`
  104. """
  105. functions: dict[str, Callable] = {}
  106. reduce_ex_hook: Optional[Callable] = None
  107. getstate_hook: Optional[Callable] = None
  108. str_hook: Optional[Callable] = None
  109. repr_hook: Optional[Callable] = None
  110. _valid_iterator_id: Optional[int] = None
  111. _number_of_samples_yielded: int = 0
  112. _snapshot_state: _SnapshotState = _SnapshotState.NotStarted
  113. _fast_forward_iterator: Optional[Iterator] = None
  114. def __iter__(self) -> Iterator[_T_co]:
  115. return self
  116. def __getattr__(self, attribute_name):
  117. if attribute_name in IterDataPipe.functions:
  118. if attribute_name in _iter_deprecated_functional_names:
  119. kwargs = _iter_deprecated_functional_names[attribute_name]
  120. _deprecation_warning(**kwargs)
  121. f = IterDataPipe.functions[attribute_name]
  122. function = functools.partial(f, self)
  123. functools.update_wrapper(wrapper=function, wrapped=f, assigned=("__doc__",))
  124. return function
  125. else:
  126. raise AttributeError(
  127. f"'{self.__class__.__name__}' object has no attribute '{attribute_name}"
  128. )
  129. @classmethod
  130. def register_function(cls, function_name, function):
  131. cls.functions[function_name] = function
  132. @classmethod
  133. def register_datapipe_as_function(
  134. cls, function_name, cls_to_register, enable_df_api_tracing=False
  135. ):
  136. if function_name in cls.functions:
  137. raise Exception( # noqa: TRY002
  138. f"Unable to add DataPipe function name {function_name} as it is already taken"
  139. )
  140. def class_function(cls, enable_df_api_tracing, source_dp, *args, **kwargs):
  141. result_pipe = cls(source_dp, *args, **kwargs)
  142. if isinstance(result_pipe, IterDataPipe):
  143. if enable_df_api_tracing or isinstance(source_dp, DFIterDataPipe):
  144. if function_name not in UNTRACABLE_DATAFRAME_PIPES:
  145. result_pipe = result_pipe.trace_as_dataframe()
  146. return result_pipe
  147. function = functools.partial(
  148. class_function, cls_to_register, enable_df_api_tracing
  149. )
  150. functools.update_wrapper(
  151. wrapper=function, wrapped=cls_to_register, assigned=("__doc__",)
  152. )
  153. cls.functions[function_name] = function
  154. def __getstate__(self):
  155. """
  156. Serialize `lambda` functions when `dill` is available.
  157. If this doesn't cover your custom DataPipe's use case, consider writing custom methods for
  158. `__getstate__` and `__setstate__`, or use `pickle.dumps` for serialization.
  159. """
  160. state = self.__dict__
  161. if IterDataPipe.getstate_hook is not None:
  162. return IterDataPipe.getstate_hook(state)
  163. return state
  164. def __reduce_ex__(self, *args, **kwargs):
  165. if IterDataPipe.reduce_ex_hook is not None:
  166. try:
  167. return IterDataPipe.reduce_ex_hook(self)
  168. except NotImplementedError:
  169. pass
  170. return super().__reduce_ex__(*args, **kwargs)
  171. @classmethod
  172. def set_getstate_hook(cls, hook_fn):
  173. if IterDataPipe.getstate_hook is not None and hook_fn is not None:
  174. raise RuntimeError("Attempt to override existing getstate_hook")
  175. IterDataPipe.getstate_hook = hook_fn
  176. @classmethod
  177. def set_reduce_ex_hook(cls, hook_fn):
  178. if IterDataPipe.reduce_ex_hook is not None and hook_fn is not None:
  179. raise RuntimeError("Attempt to override existing reduce_ex_hook")
  180. IterDataPipe.reduce_ex_hook = hook_fn
  181. def __repr__(self):
  182. if self.repr_hook is not None:
  183. return self.repr_hook(self)
  184. # Instead of showing <torch. ... .MapperIterDataPipe object at 0x.....>, return the class name
  185. return str(self.__class__.__qualname__)
  186. def __str__(self):
  187. if self.str_hook is not None:
  188. return self.str_hook(self)
  189. # Instead of showing <torch. ... .MapperIterDataPipe object at 0x.....>, return the class name
  190. return str(self.__class__.__qualname__)
  191. def __dir__(self):
  192. # for auto-completion in a REPL (e.g. Jupyter notebook)
  193. return list(super().__dir__()) + list(self.functions.keys())
  194. def reset(self) -> None:
  195. r"""
  196. Reset the `IterDataPipe` to the initial state.
  197. By default, no-op. For subclasses of `IterDataPipe`, depending on their functionalities,
  198. they may want to override this method with implementations that
  199. may clear the buffers and reset pointers of the DataPipe.
  200. The `reset` method is always called when `__iter__` is called as part of `hook_iterator`.
  201. """
  202. class DFIterDataPipe(IterDataPipe):
  203. def _is_dfpipe(self):
  204. return True
  205. class MapDataPipe(Dataset[_T_co], metaclass=_DataPipeMeta):
  206. r"""
  207. Map-style DataPipe.
  208. All datasets that represent a map from keys to data samples should subclass this.
  209. Subclasses should overwrite :meth:`__getitem__`, supporting fetching a
  210. data sample for a given, unique key. Subclasses can also optionally overwrite
  211. :meth:`__len__`, which is expected to return the size of the dataset by many
  212. :class:`~torch.utils.data.Sampler` implementations and the default options
  213. of :class:`~torch.utils.data.DataLoader`.
  214. These DataPipes can be invoked in two ways, using the class constructor or applying their
  215. functional form onto an existing `MapDataPipe` (recommend, available to most but not all DataPipes).
  216. Note:
  217. :class:`~torch.utils.data.DataLoader` by default constructs an index
  218. sampler that yields integral indices. To make it work with a map-style
  219. DataPipe with non-integral indices/keys, a custom sampler must be provided.
  220. Example:
  221. >>> # xdoctest: +SKIP
  222. >>> from torchdata.datapipes.map import SequenceWrapper, Mapper
  223. >>> dp = SequenceWrapper(range(10))
  224. >>> map_dp_1 = dp.map(lambda x: x + 1) # Using functional form (recommended)
  225. >>> list(map_dp_1)
  226. [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
  227. >>> map_dp_2 = Mapper(dp, lambda x: x + 1) # Using class constructor
  228. >>> list(map_dp_2)
  229. [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
  230. >>> batch_dp = map_dp_1.batch(batch_size=2)
  231. >>> list(batch_dp)
  232. [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
  233. """
  234. functions: dict[str, Callable] = {}
  235. reduce_ex_hook: Optional[Callable] = None
  236. getstate_hook: Optional[Callable] = None
  237. str_hook: Optional[Callable] = None
  238. repr_hook: Optional[Callable] = None
  239. def __getattr__(self, attribute_name):
  240. if attribute_name in MapDataPipe.functions:
  241. if attribute_name in _map_deprecated_functional_names:
  242. kwargs = _map_deprecated_functional_names[attribute_name]
  243. _deprecation_warning(**kwargs)
  244. f = MapDataPipe.functions[attribute_name]
  245. function = functools.partial(f, self)
  246. functools.update_wrapper(wrapper=function, wrapped=f, assigned=("__doc__",))
  247. return function
  248. else:
  249. raise AttributeError(
  250. f"'{self.__class__.__name__}' object has no attribute '{attribute_name}"
  251. )
  252. @classmethod
  253. def register_function(cls, function_name, function):
  254. cls.functions[function_name] = function
  255. @classmethod
  256. def register_datapipe_as_function(cls, function_name, cls_to_register):
  257. if function_name in cls.functions:
  258. raise Exception( # noqa: TRY002
  259. f"Unable to add DataPipe function name {function_name} as it is already taken"
  260. )
  261. def class_function(cls, source_dp, *args, **kwargs):
  262. result_pipe = cls(source_dp, *args, **kwargs)
  263. return result_pipe
  264. function = functools.partial(class_function, cls_to_register)
  265. functools.update_wrapper(
  266. wrapper=function, wrapped=cls_to_register, assigned=("__doc__",)
  267. )
  268. cls.functions[function_name] = function
  269. def __getstate__(self):
  270. """
  271. Serialize `lambda` functions when `dill` is available.
  272. If this doesn't cover your custom DataPipe's use case, consider writing custom methods for
  273. `__getstate__` and `__setstate__`, or use `pickle.dumps` for serialization.
  274. """
  275. state = self.__dict__
  276. if MapDataPipe.getstate_hook is not None:
  277. return MapDataPipe.getstate_hook(state)
  278. return state
  279. def __reduce_ex__(self, *args, **kwargs):
  280. if MapDataPipe.reduce_ex_hook is not None:
  281. try:
  282. return MapDataPipe.reduce_ex_hook(self)
  283. except NotImplementedError:
  284. pass
  285. return super().__reduce_ex__(*args, **kwargs)
  286. @classmethod
  287. def set_getstate_hook(cls, hook_fn):
  288. if MapDataPipe.getstate_hook is not None and hook_fn is not None:
  289. raise RuntimeError("Attempt to override existing getstate_hook")
  290. MapDataPipe.getstate_hook = hook_fn
  291. @classmethod
  292. def set_reduce_ex_hook(cls, hook_fn):
  293. if MapDataPipe.reduce_ex_hook is not None and hook_fn is not None:
  294. raise RuntimeError("Attempt to override existing reduce_ex_hook")
  295. MapDataPipe.reduce_ex_hook = hook_fn
  296. def __repr__(self):
  297. if self.repr_hook is not None:
  298. return self.repr_hook(self)
  299. # Instead of showing <torch. ... .MapperMapDataPipe object at 0x.....>, return the class name
  300. return str(self.__class__.__qualname__)
  301. def __str__(self):
  302. if self.str_hook is not None:
  303. return self.str_hook(self)
  304. # Instead of showing <torch. ... .MapperMapDataPipe object at 0x.....>, return the class name
  305. return str(self.__class__.__qualname__)
  306. def __dir__(self):
  307. # for auto-completion in a REPL (e.g. Jupyter notebook)
  308. return list(super().__dir__()) + list(self.functions.keys())
  309. class _DataPipeSerializationWrapper:
  310. def __init__(self, datapipe):
  311. self._datapipe = datapipe
  312. def __getstate__(self):
  313. use_dill = False
  314. try:
  315. value = pickle.dumps(self._datapipe)
  316. except Exception:
  317. if HAS_DILL:
  318. value = dill.dumps(self._datapipe)
  319. use_dill = True
  320. else:
  321. raise
  322. return (value, use_dill)
  323. def __setstate__(self, state):
  324. value, use_dill = state
  325. if use_dill:
  326. self._datapipe = dill.loads(value)
  327. else:
  328. self._datapipe = pickle.loads(value)
  329. def __len__(self):
  330. try:
  331. return len(self._datapipe)
  332. except Exception as e:
  333. raise TypeError(
  334. f"{type(self).__name__} instance doesn't have valid length"
  335. ) from e
  336. class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataPipe):
  337. def __init__(self, datapipe: IterDataPipe[_T_co]):
  338. super().__init__(datapipe)
  339. self._datapipe_iter: Optional[Iterator[_T_co]] = None
  340. def __iter__(self) -> "_IterDataPipeSerializationWrapper":
  341. self._datapipe_iter = iter(self._datapipe)
  342. return self
  343. def __next__(self) -> _T_co: # type: ignore[type-var]
  344. assert self._datapipe_iter is not None
  345. return next(self._datapipe_iter)
  346. class _MapDataPipeSerializationWrapper(_DataPipeSerializationWrapper, MapDataPipe):
  347. def __getitem__(self, idx):
  348. return self._datapipe[idx]