functools.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375
  1. from __future__ import annotations
  2. __all__ = (
  3. "AsyncCacheInfo",
  4. "AsyncCacheParameters",
  5. "AsyncLRUCacheWrapper",
  6. "cache",
  7. "lru_cache",
  8. "reduce",
  9. )
  10. import functools
  11. import sys
  12. from collections import OrderedDict
  13. from collections.abc import (
  14. AsyncIterable,
  15. Awaitable,
  16. Callable,
  17. Coroutine,
  18. Hashable,
  19. Iterable,
  20. )
  21. from functools import update_wrapper
  22. from inspect import iscoroutinefunction
  23. from typing import (
  24. Any,
  25. Generic,
  26. NamedTuple,
  27. TypedDict,
  28. TypeVar,
  29. cast,
  30. final,
  31. overload,
  32. )
  33. from weakref import WeakKeyDictionary
  34. from ._core._synchronization import Lock
  35. from .lowlevel import RunVar, checkpoint
  36. if sys.version_info >= (3, 11):
  37. from typing import ParamSpec
  38. else:
  39. from typing_extensions import ParamSpec
# Result type of the wrapped callables (and the accumulator type in reduce())
T = TypeVar("T")
# Element type of the iterable passed to reduce() when it differs from T
S = TypeVar("S")
# Parameter specification of the wrapped coroutine functions
P = ParamSpec("P")

# Run variable holding the cached entries of every AsyncLRUCacheWrapper.
# The outer mapping is keyed weakly by wrapper; the inner ordered mapping is keyed by
# the call key built in AsyncLRUCacheWrapper.__call__() and holds either
# (initial_missing, lock) while a value is pending, or (value, None) once cached.
lru_cache_items: RunVar[
    WeakKeyDictionary[
        AsyncLRUCacheWrapper[Any, Any],
        OrderedDict[Hashable, tuple[_InitialMissingType, Lock] | tuple[Any, None]],
    ]
] = RunVar("lru_cache_items")
  49. class _InitialMissingType:
  50. pass
  51. initial_missing: _InitialMissingType = _InitialMissingType()
class AsyncCacheInfo(NamedTuple):
    """Cache statistics, as returned by ``AsyncLRUCacheWrapper.cache_info()``."""

    # Number of calls answered from the cache
    hits: int
    # Number of calls that had to go to the wrapped function
    misses: int
    # Configured size limit (``None`` means no limit)
    maxsize: int | None
    # The wrapper's current entry count
    currsize: int
class AsyncCacheParameters(TypedDict):
    """
    Configuration of a cache, as returned by
    ``AsyncLRUCacheWrapper.cache_parameters()``.
    """

    # Size limit of the cache (``None`` means no limit)
    maxsize: int | None
    # Whether the argument types are made part of the cache key
    typed: bool
    # Whether every call (including cache hits) goes through a checkpoint
    always_checkpoint: bool
  61. class _LRUMethodWrapper(Generic[T]):
  62. def __init__(self, wrapper: AsyncLRUCacheWrapper[..., T], instance: object):
  63. self.__wrapper = wrapper
  64. self.__instance = instance
  65. def cache_info(self) -> AsyncCacheInfo:
  66. return self.__wrapper.cache_info()
  67. def cache_parameters(self) -> AsyncCacheParameters:
  68. return self.__wrapper.cache_parameters()
  69. def cache_clear(self) -> None:
  70. self.__wrapper.cache_clear()
  71. async def __call__(self, *args: Any, **kwargs: Any) -> T:
  72. if self.__instance is None:
  73. return await self.__wrapper(*args, **kwargs)
  74. return await self.__wrapper(self.__instance, *args, **kwargs)
  75. @final
  76. class AsyncLRUCacheWrapper(Generic[P, T]):
  77. def __init__(
  78. self,
  79. func: Callable[P, Awaitable[T]],
  80. maxsize: int | None,
  81. typed: bool,
  82. always_checkpoint: bool,
  83. ):
  84. self.__wrapped__ = func
  85. self._hits: int = 0
  86. self._misses: int = 0
  87. self._maxsize = max(maxsize, 0) if maxsize is not None else None
  88. self._currsize: int = 0
  89. self._typed = typed
  90. self._always_checkpoint = always_checkpoint
  91. update_wrapper(self, func)
  92. def cache_info(self) -> AsyncCacheInfo:
  93. return AsyncCacheInfo(self._hits, self._misses, self._maxsize, self._currsize)
  94. def cache_parameters(self) -> AsyncCacheParameters:
  95. return {
  96. "maxsize": self._maxsize,
  97. "typed": self._typed,
  98. "always_checkpoint": self._always_checkpoint,
  99. }
  100. def cache_clear(self) -> None:
  101. if cache := lru_cache_items.get(None):
  102. cache.pop(self, None)
  103. self._hits = self._misses = self._currsize = 0
  104. async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
  105. # Easy case first: if maxsize == 0, no caching is done
  106. if self._maxsize == 0:
  107. value = await self.__wrapped__(*args, **kwargs)
  108. self._misses += 1
  109. return value
  110. # The key is constructed as a flat tuple to avoid memory overhead
  111. key: tuple[Any, ...] = args
  112. if kwargs:
  113. # initial_missing is used as a separator
  114. key += (initial_missing,) + sum(kwargs.items(), ())
  115. if self._typed:
  116. key += tuple(type(arg) for arg in args)
  117. if kwargs:
  118. key += (initial_missing,) + tuple(type(val) for val in kwargs.values())
  119. try:
  120. cache = lru_cache_items.get()
  121. except LookupError:
  122. cache = WeakKeyDictionary()
  123. lru_cache_items.set(cache)
  124. try:
  125. cache_entry = cache[self]
  126. except KeyError:
  127. cache_entry = cache[self] = OrderedDict()
  128. cached_value: T | _InitialMissingType
  129. try:
  130. cached_value, lock = cache_entry[key]
  131. except KeyError:
  132. # We're the first task to call this function
  133. cached_value, lock = (
  134. initial_missing,
  135. Lock(fast_acquire=not self._always_checkpoint),
  136. )
  137. cache_entry[key] = cached_value, lock
  138. if lock is None:
  139. # The value was already cached
  140. self._hits += 1
  141. cache_entry.move_to_end(key)
  142. if self._always_checkpoint:
  143. await checkpoint()
  144. return cast(T, cached_value)
  145. async with lock:
  146. # Check if another task filled the cache while we acquired the lock
  147. if (cached_value := cache_entry[key][0]) is initial_missing:
  148. self._misses += 1
  149. if self._maxsize is not None and self._currsize >= self._maxsize:
  150. cache_entry.popitem(last=False)
  151. else:
  152. self._currsize += 1
  153. value = await self.__wrapped__(*args, **kwargs)
  154. cache_entry[key] = value, None
  155. else:
  156. # Another task filled the cache while we were waiting for the lock
  157. self._hits += 1
  158. cache_entry.move_to_end(key)
  159. value = cast(T, cached_value)
  160. return value
  161. def __get__(
  162. self, instance: object, owner: type | None = None
  163. ) -> _LRUMethodWrapper[T]:
  164. wrapper = _LRUMethodWrapper(self, instance)
  165. update_wrapper(wrapper, self.__wrapped__)
  166. return wrapper
  167. class _LRUCacheWrapper(Generic[T]):
  168. def __init__(self, maxsize: int | None, typed: bool, always_checkpoint: bool):
  169. self._maxsize = maxsize
  170. self._typed = typed
  171. self._always_checkpoint = always_checkpoint
  172. @overload
  173. def __call__( # type: ignore[overload-overlap]
  174. self, func: Callable[P, Coroutine[Any, Any, T]], /
  175. ) -> AsyncLRUCacheWrapper[P, T]: ...
  176. @overload
  177. def __call__(
  178. self, func: Callable[..., T], /
  179. ) -> functools._lru_cache_wrapper[T]: ...
  180. def __call__(
  181. self, f: Callable[P, Coroutine[Any, Any, T]] | Callable[..., T], /
  182. ) -> AsyncLRUCacheWrapper[P, T] | functools._lru_cache_wrapper[T]:
  183. if iscoroutinefunction(f):
  184. return AsyncLRUCacheWrapper(
  185. f, self._maxsize, self._typed, self._always_checkpoint
  186. )
  187. return functools.lru_cache(maxsize=self._maxsize, typed=self._typed)(f) # type: ignore[arg-type]
  188. @overload
  189. def cache( # type: ignore[overload-overlap]
  190. func: Callable[P, Coroutine[Any, Any, T]], /
  191. ) -> AsyncLRUCacheWrapper[P, T]: ...
  192. @overload
  193. def cache(func: Callable[..., T], /) -> functools._lru_cache_wrapper[T]: ...
  194. def cache(
  195. func: Callable[..., T] | Callable[P, Coroutine[Any, Any, T]], /
  196. ) -> AsyncLRUCacheWrapper[P, T] | functools._lru_cache_wrapper[T]:
  197. """
  198. A convenient shortcut for :func:`lru_cache` with ``maxsize=None``.
  199. This is the asynchronous equivalent to :func:`functools.cache`.
  200. """
  201. return lru_cache(maxsize=None)(func)
  202. @overload
  203. def lru_cache(
  204. *, maxsize: int | None = ..., typed: bool = ..., always_checkpoint: bool = ...
  205. ) -> _LRUCacheWrapper[Any]: ...
  206. @overload
  207. def lru_cache( # type: ignore[overload-overlap]
  208. func: Callable[P, Coroutine[Any, Any, T]], /
  209. ) -> AsyncLRUCacheWrapper[P, T]: ...
  210. @overload
  211. def lru_cache(func: Callable[..., T], /) -> functools._lru_cache_wrapper[T]: ...
  212. def lru_cache(
  213. func: Callable[P, Coroutine[Any, Any, T]] | Callable[..., T] | None = None,
  214. /,
  215. *,
  216. maxsize: int | None = 128,
  217. typed: bool = False,
  218. always_checkpoint: bool = False,
  219. ) -> (
  220. AsyncLRUCacheWrapper[P, T] | functools._lru_cache_wrapper[T] | _LRUCacheWrapper[Any]
  221. ):
  222. """
  223. An asynchronous version of :func:`functools.lru_cache`.
  224. If a synchronous function is passed, the standard library
  225. :func:`functools.lru_cache` is applied instead.
  226. :param always_checkpoint: if ``True``, every call to the cached function will be
  227. guaranteed to yield control to the event loop at least once
  228. .. note:: Caches and locks are managed on a per-event loop basis.
  229. """
  230. if func is None:
  231. return _LRUCacheWrapper[Any](maxsize, typed, always_checkpoint)
  232. if not callable(func):
  233. raise TypeError("the first argument must be callable")
  234. return _LRUCacheWrapper[T](maxsize, typed, always_checkpoint)(func)
  235. @overload
  236. async def reduce(
  237. function: Callable[[T, S], Awaitable[T]],
  238. iterable: Iterable[S] | AsyncIterable[S],
  239. /,
  240. initial: T,
  241. ) -> T: ...
  242. @overload
  243. async def reduce(
  244. function: Callable[[T, T], Awaitable[T]],
  245. iterable: Iterable[T] | AsyncIterable[T],
  246. /,
  247. ) -> T: ...
  248. async def reduce( # type: ignore[misc]
  249. function: Callable[[T, T], Awaitable[T]] | Callable[[T, S], Awaitable[T]],
  250. iterable: Iterable[T] | Iterable[S] | AsyncIterable[T] | AsyncIterable[S],
  251. /,
  252. initial: T | _InitialMissingType = initial_missing,
  253. ) -> T:
  254. """
  255. Asynchronous version of :func:`functools.reduce`.
  256. :param function: a coroutine function that takes two arguments: the accumulated
  257. value and the next element from the iterable
  258. :param iterable: an iterable or async iterable
  259. :param initial: the initial value (if missing, the first element of the iterable is
  260. used as the initial value)
  261. """
  262. element: Any
  263. function_called = False
  264. if isinstance(iterable, AsyncIterable):
  265. async_it = iterable.__aiter__()
  266. if initial is initial_missing:
  267. try:
  268. value = cast(T, await async_it.__anext__())
  269. except StopAsyncIteration:
  270. raise TypeError(
  271. "reduce() of empty sequence with no initial value"
  272. ) from None
  273. else:
  274. value = cast(T, initial)
  275. async for element in async_it:
  276. value = await function(value, element)
  277. function_called = True
  278. elif isinstance(iterable, Iterable):
  279. it = iter(iterable)
  280. if initial is initial_missing:
  281. try:
  282. value = cast(T, next(it))
  283. except StopIteration:
  284. raise TypeError(
  285. "reduce() of empty sequence with no initial value"
  286. ) from None
  287. else:
  288. value = cast(T, initial)
  289. for element in it:
  290. value = await function(value, element)
  291. function_called = True
  292. else:
  293. raise TypeError("reduce() argument 2 must be an iterable or async iterable")
  294. # Make sure there is at least one checkpoint, even if an empty iterable and an
  295. # initial value were given
  296. if not function_called:
  297. await checkpoint()
  298. return value