| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203 |
- """
- A module for sharing intermediates between contractions.
- Copyright (c) 2018 Uber Technologies
- """
- import contextlib
- import functools
- import numbers
- import threading
- from collections import Counter, defaultdict
- from .parser import alpha_canonicalize, parse_einsum_input
__all__ = [
    "currently_sharing",
    "get_sharing_cache",
    "shared_intermediates",
    "count_cached_ops",
    "transpose_cache_wrap",
    "tensordot_cache_wrap",
    "einsum_cache_wrap",
    "to_backend_cache_wrap",
]

# Maps a thread id (``threading.get_ident()``) to a list used as a stack of
# the sharing caches pushed on that thread; the innermost cache is last.
# ``_remove_sharing_cache`` deletes a thread's entry once its stack empties,
# which is what ``currently_sharing`` relies on.
_SHARING_STACK = defaultdict(list)
def currently_sharing():
    """Report whether a sharing cache is active on the calling thread.

    Returns
    -------
    bool
        ``True`` if the current thread has an entry in the sharing stack.
    """
    thread_id = threading.get_ident()
    return thread_id in _SHARING_STACK
def get_sharing_cache():
    """Return the most recent sharing cache -- thread specific.

    Returns
    -------
    cache : dict
        The innermost cache pushed by :func:`shared_intermediates` on the
        calling thread.

    Raises
    ------
    IndexError
        If no sharing context is active on the calling thread.
    """
    # Look up with .get() rather than indexing: indexing the defaultdict would
    # insert an empty stack for this thread, after which currently_sharing()
    # would wrongly report True and every cached op would fail.
    stack = _SHARING_STACK.get(threading.get_ident())
    if not stack:
        raise IndexError("get_sharing_cache() called outside of a shared_intermediates() context")
    return stack[-1]
def _add_sharing_cache(cache):
    """Push ``cache`` onto the calling thread's sharing stack.

    The per-thread list is created on demand by the ``defaultdict``.
    """
    stack = _SHARING_STACK[threading.get_ident()]
    stack.append(cache)
def _remove_sharing_cache():
    """Pop the innermost sharing cache of the calling thread.

    When the thread's stack becomes empty its entry is deleted, so that
    ``currently_sharing()`` reports ``False`` again.
    """
    thread_id = threading.get_ident()
    stack = _SHARING_STACK[thread_id]
    stack.pop()
    if not stack:
        del _SHARING_STACK[thread_id]
@contextlib.contextmanager
def shared_intermediates(cache=None):
    """Context in which contract intermediate results are shared.

    Intermediate computations are not garbage collected until both

    1. this context exits, and
    2. the yielded cache is garbage collected (if it was captured).

    Parameters
    ----------
    cache : dict, optional
        A user-owned dict in which intermediate results will be stored.
        Passing the same dict to several contexts lets them interleave
        and resume sharing. A fresh dict is used when omitted.

    Yields
    ------
    cache : dict
        The dictionary holding the shared results. If it is not kept by the
        caller, the results are released once the context exits. It can be
        passed to another ``shared_intermediates`` call to resume sharing.
    """
    active_cache = {} if cache is None else cache
    _add_sharing_cache(active_cache)
    try:
        yield active_cache
    finally:
        # Always pop, even if the body raised, so the thread's stack stays balanced.
        _remove_sharing_cache()
def count_cached_ops(cache):
    """Count the cache entries per operation type.

    Each key's first element names the op (e.g. ``'tensor'``,
    ``'einsum'``). This is handy when profiling how much sharing occurs.

    Parameters
    ----------
    cache : dict
        A cache yielded by :func:`shared_intermediates`.

    Returns
    -------
    collections.Counter
        Mapping from op name to the number of cached entries of that op.
    """
    op_counts = Counter()
    for key in cache:
        op_counts[key[0]] += 1
    return op_counts
def _save_tensors(*tensors):
    """Store ``tensors`` in the active sharing cache under ``('tensor', id)``.

    Holding a reference keeps each tensor alive as long as the cache, so its
    ``id`` cannot be reused by another object. Reuse would otherwise cause
    false hits in the id-keyed cache lookups.
    """
    cache = get_sharing_cache()
    cache.update((('tensor', id(t)), t) for t in tensors)
def _memoize(key, fn, *args, **kwargs):
    """Return ``fn(*args, **kwargs)``, memoized under ``key``.

    The value is looked up in, and on a miss stored into, the innermost
    cache yielded by :func:`shared_intermediates`.
    """
    cache = get_sharing_cache()
    if key not in cache:
        value = fn(*args, **kwargs)
        cache[key] = value
        return value
    return cache[key]
def transpose_cache_wrap(transpose):
    """Wrap a ``transpose(a, axes, backend=...)`` implementation so that it
    is memoized inside a :func:`shared_intermediates` context.

    Outside a sharing context the wrapped function is called unchanged.
    """

    @functools.wraps(transpose)
    def cached_transpose(a, axes, backend='numpy'):
        if currently_sharing():
            # Keep ``a`` alive so its id stays unique, then key on the
            # concrete permutation tuple.
            _save_tensors(a)
            perm = tuple(axes)
            return _memoize(('transpose', backend, id(a), perm), transpose, a, perm, backend=backend)
        return transpose(a, axes, backend=backend)

    return cached_transpose
def tensordot_cache_wrap(tensordot):
    """Decorates a ``tensordot()`` implementation to be memoized inside a
    :func:`shared_intermediates` context.

    Inside a sharing context ``axes`` is normalized to the explicit form
    ``(tuple_of_x_axes, tuple_of_y_axes)`` before it is hashed and forwarded.
    Accepted forms are:

    - an integer ``N`` (last ``N`` axes of ``x`` with first ``N`` of ``y``),
    - a pair of sequences,
    - a pair whose elements are single integers (e.g. ``(1, 0)``).

    Parameters
    ----------
    tensordot : callable
        Implementation called as ``tensordot(x, y, axes, backend=backend)``.

    Returns
    -------
    callable
        Wrapper with signature ``(x, y, axes=2, backend='numpy')``.
    """

    def _as_axes_tuple(side):
        # One side of a pair-form ``axes`` may be a lone int (numpy allows
        # ``axes=(1, 0)``); tuple() on an int would raise TypeError.
        if isinstance(side, numbers.Number):
            return (side,)
        return tuple(side)

    @functools.wraps(tensordot)
    def cached_tensordot(x, y, axes=2, backend='numpy'):
        if not currently_sharing():
            return tensordot(x, y, axes, backend=backend)

        # Keep the operands alive so their ids are not recycled while cached.
        _save_tensors(x, y)

        # hash based on the (axes_x, axes_y) form of axes
        if isinstance(axes, numbers.Number):
            ndim_x, ndim_y = len(x.shape), len(y.shape)
            axes = list(range(ndim_x))[ndim_x - axes:], list(range(ndim_y))[:axes]
        axes = _as_axes_tuple(axes[0]), _as_axes_tuple(axes[1])

        key = 'tensordot', backend, id(x), id(y), axes
        return _memoize(key, tensordot, x, y, axes, backend=backend)

    return cached_tensordot
def einsum_cache_wrap(einsum):
    """Decorates an ``einsum()`` implementation to be memoized inside a
    :func:`shared_intermediates` context.

    Calls are keyed modulo operand order and index renaming. The inputs are
    sorted by operand ``id`` and the resulting equation is passed through
    ``alpha_canonicalize``.

    Only calls whose keyword arguments are at most ``backend`` are memoized.
    Any other keyword (e.g. ``out``, ``dtype``, ``order``, ``casting``) may
    change the result or must write into a caller-supplied buffer, so such
    calls are forwarded unchanged instead of being cached.

    Parameters
    ----------
    einsum : callable
        Implementation accepting either ``einsum(equation, *operands, backend=...)``
        or the interleaved ``einsum(op0, sublist0, op1, sublist1, ..., backend=...)``
        form.

    Returns
    -------
    callable
        The memoizing wrapper.
    """

    @functools.wraps(einsum)
    def cached_einsum(*args, **kwargs):
        # Extra kwargs used to be silently dropped here; a cache hit would then
        # skip writing ``out`` or ignore ``dtype``. Bypass the cache instead.
        if not currently_sharing() or any(name != 'backend' for name in kwargs):
            return einsum(*args, **kwargs)

        backend = kwargs.get('backend', 'numpy')
        inputs, output, operands = parse_einsum_input(args)
        inputs = inputs.split(',')

        if isinstance(args[0], str):
            equation = args[0]
        else:
            # Interleaved (operand, sublist, ...) format: args[0] is an operand,
            # so rebuild an explicit subscripts string from the parsed input.
            equation = ','.join(inputs) + '->' + output

        _save_tensors(*operands)

        # hash modulo commutativity by computing a canonical ordering and names
        canonical = sorted(zip(inputs, map(id, operands)), key=lambda pair: pair[1])
        canonical_ids = tuple(id_ for _, id_ in canonical)
        canonical_inputs = ','.join(input_ for input_, _ in canonical)
        canonical_equation = alpha_canonicalize(canonical_inputs + "->" + output)

        key = 'einsum', backend, canonical_equation, canonical_ids
        return _memoize(key, einsum, equation, *operands, backend=backend)

    return cached_einsum
def to_backend_cache_wrap(to_backend=None, constants=False):
    """Decorates an ``to_backend()`` implementation to be memoized inside a
    :func:`shared_intermediates` context (e.g. ``to_cupy``, ``to_torch``).

    Can be used bare (``@to_backend_cache_wrap``) or with arguments
    (``@to_backend_cache_wrap(constants=True)``).

    Parameters
    ----------
    to_backend : callable, optional
        The conversion function. When ``None``, a decorator with the given
        ``constants`` setting is returned.
    constants : bool, optional
        If true, the wrapped function takes a ``constant`` keyword, which is
        forwarded and included in the cache key.

    Returns
    -------
    callable
        The memoizing wrapper, or a decorator if ``to_backend`` is ``None``.
    """
    # manage the case that decorator is called with args
    if to_backend is None:
        return functools.partial(to_backend_cache_wrap, constants=constants)

    if constants:

        @functools.wraps(to_backend)
        def cached_to_backend(array, constant=False):
            if not currently_sharing():
                return to_backend(array, constant=constant)
            # The key uses id(array); save the array so that id cannot be
            # reused by a new array while cached (which would be a false hit).
            _save_tensors(array)
            key = to_backend.__name__, id(array), constant
            return _memoize(key, to_backend, array, constant=constant)

    else:

        @functools.wraps(to_backend)
        def cached_to_backend(array):
            if not currently_sharing():
                return to_backend(array)
            # Same id-recycling protection as above.
            _save_tensors(array)
            key = to_backend.__name__, id(array)
            return _memoize(key, to_backend, array)

    return cached_to_backend
|