itertools.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. """
  2. Python polyfills for itertools
  3. """
  4. from __future__ import annotations
  5. import itertools
  6. import operator
  7. import sys
  8. from typing import Callable, Optional, overload, TYPE_CHECKING, TypeVar
  9. from typing_extensions import TypeAlias
  10. from ..decorators import substitute_in_graph
  11. if TYPE_CHECKING:
  12. from collections.abc import Iterable, Iterator
# Names exported by this module. "pairwise" is not listed here; it is appended
# further down, and only when running on Python >= 3.10.
__all__ = [
    "accumulate",
    "chain",
    "chain_from_iterable",
    "compress",
    "cycle",
    "dropwhile",
    "filterfalse",
    "islice",
    "tee",
    "zip_longest",
]

# Generic element types used in the polyfill signatures below.
_T = TypeVar("_T")
_U = TypeVar("_U")
# One-argument callable; its result is only ever tested for truthiness
# (see dropwhile and filterfalse).
_Predicate: TypeAlias = Callable[[_T], object]
# Per-position element types used by the zip_longest overloads.
_T1 = TypeVar("_T1")
_T2 = TypeVar("_T2")
  30. # Reference: https://docs.python.org/3/library/itertools.html#itertools.chain
  31. @substitute_in_graph(itertools.chain, is_embedded_type=True) # type: ignore[arg-type]
  32. def chain(*iterables: Iterable[_T]) -> Iterator[_T]:
  33. for iterable in iterables:
  34. yield from iterable
  35. # Reference: https://docs.python.org/3/library/itertools.html#itertools.accumulate
  36. @substitute_in_graph(itertools.accumulate, is_embedded_type=True) # type: ignore[arg-type]
  37. def accumulate(
  38. iterable: Iterable[_T],
  39. func: Optional[Callable[[_T, _T], _T]] = None,
  40. *,
  41. initial: Optional[_T] = None,
  42. ) -> Iterator[_T]:
  43. # call iter outside of the generator to match cypthon behavior
  44. iterator = iter(iterable)
  45. if func is None:
  46. func = operator.add
  47. def _accumulate(iterator: Iterator[_T]) -> Iterator[_T]:
  48. total = initial
  49. if total is None:
  50. try:
  51. total = next(iterator)
  52. except StopIteration:
  53. return
  54. yield total
  55. for element in iterator:
  56. total = func(total, element)
  57. yield total
  58. return _accumulate(iterator)
  59. @substitute_in_graph(itertools.chain.from_iterable) # type: ignore[arg-type]
  60. def chain_from_iterable(iterable: Iterable[Iterable[_T]], /) -> Iterator[_T]:
  61. # previous version of this code was:
  62. # return itertools.chain(*iterable)
  63. # If iterable is an infinite generator, this will lead to infinite recursion
  64. for it in iterable:
  65. yield from it
  66. chain.from_iterable = chain_from_iterable # type: ignore[attr-defined]
  67. # Reference: https://docs.python.org/3/library/itertools.html#itertools.compress
  68. @substitute_in_graph(itertools.compress, is_embedded_type=True) # type: ignore[arg-type]
  69. def compress(data: Iterable[_T], selectors: Iterable[_U], /) -> Iterator[_T]:
  70. return (datum for datum, selector in zip(data, selectors) if selector)
  71. # Reference: https://docs.python.org/3/library/itertools.html#itertools.cycle
  72. @substitute_in_graph(itertools.cycle, is_embedded_type=True) # type: ignore[arg-type]
  73. def cycle(iterable: Iterable[_T]) -> Iterator[_T]:
  74. iterator = iter(iterable)
  75. def _cycle(iterator: Iterator[_T]) -> Iterator[_T]:
  76. saved = []
  77. for element in iterable:
  78. yield element
  79. saved.append(element)
  80. while saved:
  81. for element in saved:
  82. yield element
  83. return _cycle(iterator)
  84. # Reference: https://docs.python.org/3/library/itertools.html#itertools.dropwhile
  85. @substitute_in_graph(itertools.dropwhile, is_embedded_type=True) # type: ignore[arg-type]
  86. def dropwhile(predicate: _Predicate[_T], iterable: Iterable[_T], /) -> Iterator[_T]:
  87. # dropwhile(lambda x: x < 5, [1, 4, 6, 3, 8]) -> 6 3 8
  88. iterator = iter(iterable)
  89. for x in iterator:
  90. if not predicate(x):
  91. yield x
  92. break
  93. yield from iterator
  94. @substitute_in_graph(itertools.filterfalse, is_embedded_type=True) # type: ignore[arg-type]
  95. def filterfalse(function: _Predicate[_T], iterable: Iterable[_T], /) -> Iterator[_T]:
  96. it = iter(iterable)
  97. if function is None:
  98. return filter(operator.not_, it)
  99. else:
  100. return filter(lambda x: not function(x), it)
  101. # Reference: https://docs.python.org/3/library/itertools.html#itertools.islice
  102. @substitute_in_graph(itertools.islice, is_embedded_type=True) # type: ignore[arg-type]
  103. def islice(iterable: Iterable[_T], /, *args: int | None) -> Iterator[_T]:
  104. s = slice(*args)
  105. start = 0 if s.start is None else s.start
  106. stop = s.stop
  107. step = 1 if s.step is None else s.step
  108. if start < 0 or (stop is not None and stop < 0) or step <= 0:
  109. raise ValueError(
  110. "Indices for islice() must be None or an integer: 0 <= x <= sys.maxsize.",
  111. )
  112. if stop is None:
  113. # TODO: use indices = itertools.count() and merge implementation with the else branch
  114. # when we support infinite iterators
  115. next_i = start
  116. for i, element in enumerate(iterable):
  117. if i == next_i:
  118. yield element
  119. next_i += step
  120. else:
  121. indices = range(max(start, stop))
  122. next_i = start
  123. for i, element in zip(indices, iterable):
  124. if i == next_i:
  125. yield element
  126. next_i += step
  127. # Reference: https://docs.python.org/3/library/itertools.html#itertools.pairwise
  128. if sys.version_info >= (3, 10):
  129. @substitute_in_graph(itertools.pairwise, is_embedded_type=True) # type: ignore[arg-type]
  130. def pairwise(iterable: Iterable[_T], /) -> Iterator[tuple[_T, _T]]:
  131. a = None
  132. first = True
  133. for b in iterable:
  134. if first:
  135. first = False
  136. else:
  137. yield a, b # type: ignore[misc]
  138. a = b
  139. __all__ += ["pairwise"]
  140. # Reference: https://docs.python.org/3/library/itertools.html#itertools.tee
  141. @substitute_in_graph(itertools.tee)
  142. def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]:
  143. iterator = iter(iterable)
  144. shared_link = [None, None]
  145. def _tee(link) -> Iterator[_T]: # type: ignore[no-untyped-def]
  146. try:
  147. while True:
  148. if link[1] is None:
  149. link[0] = next(iterator)
  150. link[1] = [None, None]
  151. value, link = link
  152. yield value
  153. except StopIteration:
  154. return
  155. return tuple(_tee(shared_link) for _ in range(n))
  156. @overload
  157. def zip_longest(
  158. iter1: Iterable[_T1],
  159. /,
  160. *,
  161. fillvalue: _U = ...,
  162. ) -> Iterator[tuple[_T1]]: ...
  163. @overload
  164. def zip_longest(
  165. iter1: Iterable[_T1],
  166. iter2: Iterable[_T2],
  167. /,
  168. ) -> Iterator[tuple[_T1 | None, _T2 | None]]: ...
  169. @overload
  170. def zip_longest(
  171. iter1: Iterable[_T1],
  172. iter2: Iterable[_T2],
  173. /,
  174. *,
  175. fillvalue: _U = ...,
  176. ) -> Iterator[tuple[_T1 | _U, _T2 | _U]]: ...
  177. @overload
  178. def zip_longest(
  179. iter1: Iterable[_T],
  180. iter2: Iterable[_T],
  181. iter3: Iterable[_T],
  182. /,
  183. *iterables: Iterable[_T],
  184. ) -> Iterator[tuple[_T | None, ...]]: ...
  185. @overload
  186. def zip_longest(
  187. iter1: Iterable[_T],
  188. iter2: Iterable[_T],
  189. iter3: Iterable[_T],
  190. /,
  191. *iterables: Iterable[_T],
  192. fillvalue: _U = ...,
  193. ) -> Iterator[tuple[_T | _U, ...]]: ...
  194. # Reference: https://docs.python.org/3/library/itertools.html#itertools.zip_longest
  195. @substitute_in_graph(itertools.zip_longest, is_embedded_type=True) # type: ignore[arg-type,misc]
  196. def zip_longest(
  197. *iterables: Iterable[_T],
  198. fillvalue: _U = None, # type: ignore[assignment]
  199. ) -> Iterator[tuple[_T | _U, ...]]:
  200. # zip_longest('ABCD', 'xy', fillvalue='-') -> Ax By C- D-
  201. iterators = list(map(iter, iterables))
  202. num_active = len(iterators)
  203. if not num_active:
  204. return
  205. while True:
  206. values = []
  207. for i, iterator in enumerate(iterators):
  208. try:
  209. value = next(iterator)
  210. except StopIteration:
  211. num_active -= 1
  212. if not num_active:
  213. return
  214. iterators[i] = itertools.repeat(fillvalue) # type: ignore[arg-type]
  215. value = fillvalue # type: ignore[assignment]
  216. values.append(value)
  217. yield tuple(values)