sharing.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. """
  2. A module for sharing intermediates between contractions.
  3. Copyright (c) 2018 Uber Technologies
  4. """
  5. import contextlib
  6. import functools
  7. import numbers
  8. import threading
  9. from collections import Counter, defaultdict
  10. from .parser import alpha_canonicalize, parse_einsum_input
  11. __all__ = [
  12. "currently_sharing", "get_sharing_cache", "shared_intermediates", "count_cached_ops", "transpose_cache_wrap",
  13. "einsum_cache_wrap", "to_backend_cache_wrap"
  14. ]
  15. _SHARING_STACK = defaultdict(list)
  16. def currently_sharing():
  17. """Check if we are currently sharing a cache -- thread specific.
  18. """
  19. return threading.get_ident() in _SHARING_STACK
  20. def get_sharing_cache():
  21. """Return the most recent sharing cache -- thread specific.
  22. """
  23. return _SHARING_STACK[threading.get_ident()][-1]
  24. def _add_sharing_cache(cache):
  25. _SHARING_STACK[threading.get_ident()].append(cache)
  26. def _remove_sharing_cache():
  27. tid = threading.get_ident()
  28. _SHARING_STACK[tid].pop()
  29. if not _SHARING_STACK[tid]:
  30. del _SHARING_STACK[tid]
  31. @contextlib.contextmanager
  32. def shared_intermediates(cache=None):
  33. """Context in which contract intermediate results are shared.
  34. Note that intermediate computations will not be garbage collected until
  35. 1. this context exits, and
  36. 2. the yielded cache is garbage collected (if it was captured).
  37. Parameters
  38. ----------
  39. cache : dict
  40. If specified, a user-stored dict in which intermediate results will
  41. be stored. This can be used to interleave sharing contexts.
  42. Returns
  43. -------
  44. cache : dict
  45. A dictionary in which sharing results are stored. If ignored,
  46. sharing results will be garbage collected when this context is
  47. exited. This dict can be passed to another context to resume
  48. sharing.
  49. """
  50. if cache is None:
  51. cache = {}
  52. _add_sharing_cache(cache)
  53. try:
  54. yield cache
  55. finally:
  56. _remove_sharing_cache()
  57. def count_cached_ops(cache):
  58. """Returns a counter of the types of each op in the cache.
  59. This is useful for profiling to increase sharing.
  60. """
  61. return Counter(key[0] for key in cache.keys())
  62. def _save_tensors(*tensors):
  63. """Save tensors in the cache to prevent their ids from being recycled.
  64. This is needed to prevent false cache lookups.
  65. """
  66. cache = get_sharing_cache()
  67. for tensor in tensors:
  68. cache['tensor', id(tensor)] = tensor
  69. def _memoize(key, fn, *args, **kwargs):
  70. """Memoize ``fn(*args, **kwargs)`` using the given ``key``.
  71. Results will be stored in the innermost ``cache`` yielded by
  72. :func:`shared_intermediates`.
  73. """
  74. cache = get_sharing_cache()
  75. if key in cache:
  76. return cache[key]
  77. result = fn(*args, **kwargs)
  78. cache[key] = result
  79. return result
  80. def transpose_cache_wrap(transpose):
  81. """Decorates a ``transpose()`` implementation to be memoized inside a
  82. :func:`shared_intermediates` context.
  83. """
  84. @functools.wraps(transpose)
  85. def cached_transpose(a, axes, backend='numpy'):
  86. if not currently_sharing():
  87. return transpose(a, axes, backend=backend)
  88. # hash by axes
  89. _save_tensors(a)
  90. axes = tuple(axes)
  91. key = 'transpose', backend, id(a), axes
  92. return _memoize(key, transpose, a, axes, backend=backend)
  93. return cached_transpose
  94. def tensordot_cache_wrap(tensordot):
  95. """Decorates a ``tensordot()`` implementation to be memoized inside a
  96. :func:`shared_intermediates` context.
  97. """
  98. @functools.wraps(tensordot)
  99. def cached_tensordot(x, y, axes=2, backend='numpy'):
  100. if not currently_sharing():
  101. return tensordot(x, y, axes, backend=backend)
  102. # hash based on the (axes_x,axes_y) form of axes
  103. _save_tensors(x, y)
  104. if isinstance(axes, numbers.Number):
  105. axes = list(range(len(x.shape)))[len(x.shape) - axes:], list(range(len(y.shape)))[:axes]
  106. axes = tuple(axes[0]), tuple(axes[1])
  107. key = 'tensordot', backend, id(x), id(y), axes
  108. return _memoize(key, tensordot, x, y, axes, backend=backend)
  109. return cached_tensordot
  110. def einsum_cache_wrap(einsum):
  111. """Decorates an ``einsum()`` implementation to be memoized inside a
  112. :func:`shared_intermediates` context.
  113. """
  114. @functools.wraps(einsum)
  115. def cached_einsum(*args, **kwargs):
  116. if not currently_sharing():
  117. return einsum(*args, **kwargs)
  118. # hash modulo commutativity by computing a canonical ordering and names
  119. backend = kwargs.pop('backend', 'numpy')
  120. equation = args[0]
  121. inputs, output, operands = parse_einsum_input(args)
  122. inputs = inputs.split(',')
  123. _save_tensors(*operands)
  124. # Build canonical key
  125. canonical = sorted(zip(inputs, map(id, operands)), key=lambda x: x[1])
  126. canonical_ids = tuple(id_ for _, id_ in canonical)
  127. canonical_inputs = ','.join(input_ for input_, _ in canonical)
  128. canonical_equation = alpha_canonicalize(canonical_inputs + "->" + output)
  129. key = 'einsum', backend, canonical_equation, canonical_ids
  130. return _memoize(key, einsum, equation, *operands, backend=backend)
  131. return cached_einsum
  132. def to_backend_cache_wrap(to_backend=None, constants=False):
  133. """Decorates an ``to_backend()`` implementation to be memoized inside a
  134. :func:`shared_intermediates` context (e.g. ``to_cupy``, ``to_torch``).
  135. """
  136. # manage the case that decorator is called with args
  137. if to_backend is None:
  138. return functools.partial(to_backend_cache_wrap, constants=constants)
  139. if constants:
  140. @functools.wraps(to_backend)
  141. def cached_to_backend(array, constant=False):
  142. if not currently_sharing():
  143. return to_backend(array, constant=constant)
  144. # hash by id
  145. key = to_backend.__name__, id(array), constant
  146. return _memoize(key, to_backend, array, constant=constant)
  147. else:
  148. @functools.wraps(to_backend)
  149. def cached_to_backend(array):
  150. if not currently_sharing():
  151. return to_backend(array)
  152. # hash by id
  153. key = to_backend.__name__, id(array)
  154. return _memoize(key, to_backend, array)
  155. return cached_to_backend