dispatch.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. """
  2. Handles dispatching array operations to the correct backend library, as well
  3. as converting arrays to backend formats and then potentially storing them as
  4. constants.
  5. """
  6. import importlib
  7. import numpy
  8. from . import object_arrays
  9. from . import cupy as _cupy
  10. from . import jax as _jax
  11. from . import tensorflow as _tensorflow
  12. from . import theano as _theano
  13. from . import torch as _torch
  14. __all__ = ["get_func", "has_einsum", "has_tensordot", "build_expression", "evaluate_constants", "has_backend"]
  15. # known non top-level imports
  16. _aliases = {
  17. 'dask': 'dask.array',
  18. 'theano': 'theano.tensor',
  19. 'torch': 'opt_einsum.backends.torch',
  20. 'jax': 'jax.numpy',
  21. 'autograd': 'autograd.numpy',
  22. 'mars': 'mars.tensor',
  23. }
  24. def _import_func(func, backend, default=None):
  25. """Try and import ``{backend}.{func}``.
  26. If library is installed and func is found, return the func;
  27. otherwise if default is provided, return default;
  28. otherwise raise an error.
  29. """
  30. try:
  31. lib = importlib.import_module(_aliases.get(backend, backend))
  32. return getattr(lib, func) if default is None else getattr(lib, func, default)
  33. except AttributeError:
  34. error_msg = ("{} doesn't seem to provide the function {} - see "
  35. "https://optimized-einsum.readthedocs.io/en/latest/backends.html "
  36. "for details on which functions are required for which contractions.")
  37. raise AttributeError(error_msg.format(backend, func))
  38. # manually cache functions as python2 doesn't support functools.lru_cache
  39. # other libs will be added to this if needed, but pre-populate with numpy
  40. _cached_funcs = {
  41. ('tensordot', 'numpy'): numpy.tensordot,
  42. ('transpose', 'numpy'): numpy.transpose,
  43. ('einsum', 'numpy'): numpy.einsum,
  44. # also pre-populate with the arbitrary object backend
  45. ('tensordot', 'object'): numpy.tensordot,
  46. ('transpose', 'object'): numpy.transpose,
  47. ('einsum', 'object'): object_arrays.object_einsum,
  48. }
  49. def get_func(func, backend='numpy', default=None):
  50. """Return ``{backend}.{func}``, e.g. ``numpy.einsum``,
  51. or a default func if provided. Cache result.
  52. """
  53. try:
  54. return _cached_funcs[func, backend]
  55. except KeyError:
  56. fn = _import_func(func, backend, default)
  57. _cached_funcs[func, backend] = fn
  58. return fn
  59. # mark libs with einsum, else try to use tensordot/tranpose as much as possible
  60. _has_einsum = {}
  61. def has_einsum(backend):
  62. """Check if ``{backend}.einsum`` exists, cache result for performance.
  63. """
  64. try:
  65. return _has_einsum[backend]
  66. except KeyError:
  67. try:
  68. get_func('einsum', backend)
  69. _has_einsum[backend] = True
  70. except AttributeError:
  71. _has_einsum[backend] = False
  72. return _has_einsum[backend]
  73. _has_tensordot = {}
  74. def has_tensordot(backend):
  75. """Check if ``{backend}.tensordot`` exists, cache result for performance.
  76. """
  77. try:
  78. return _has_tensordot[backend]
  79. except KeyError:
  80. try:
  81. get_func('tensordot', backend)
  82. _has_tensordot[backend] = True
  83. except AttributeError:
  84. _has_tensordot[backend] = False
  85. return _has_tensordot[backend]
  86. # Dispatch to correct expression backend
  87. # these are the backends which support explicit to-and-from numpy conversion
  88. CONVERT_BACKENDS = {
  89. 'tensorflow': _tensorflow.build_expression,
  90. 'theano': _theano.build_expression,
  91. 'cupy': _cupy.build_expression,
  92. 'torch': _torch.build_expression,
  93. 'jax': _jax.build_expression,
  94. }
  95. EVAL_CONSTS_BACKENDS = {
  96. 'tensorflow': _tensorflow.evaluate_constants,
  97. 'theano': _theano.evaluate_constants,
  98. 'cupy': _cupy.evaluate_constants,
  99. 'torch': _torch.evaluate_constants,
  100. 'jax': _jax.evaluate_constants,
  101. }
  102. def build_expression(backend, arrays, expr):
  103. """Build an expression, based on ``expr`` and initial arrays ``arrays``,
  104. that evaluates using backend ``backend``.
  105. """
  106. return CONVERT_BACKENDS[backend](arrays, expr)
  107. def evaluate_constants(backend, arrays, expr):
  108. """Convert constant arrays to the correct backend, and perform as much of
  109. the contraction of ``expr`` with these as possible.
  110. """
  111. return EVAL_CONSTS_BACKENDS[backend](arrays, expr)
  112. def has_backend(backend):
  113. """Checks if the backend is known.
  114. """
  115. return backend.lower() in CONVERT_BACKENDS