jax.py 1.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. """
  2. Required functions for optimized contractions of numpy arrays using jax.
  3. """
  4. import numpy as np
  5. from ..sharing import to_backend_cache_wrap
  6. __all__ = ["build_expression", "evaluate_constants"]
  7. _JAX = None
  8. def _get_jax_and_to_jax():
  9. global _JAX
  10. if _JAX is None:
  11. import jax
  12. @to_backend_cache_wrap
  13. @jax.jit
  14. def to_jax(x):
  15. return x
  16. _JAX = jax, to_jax
  17. return _JAX
  18. def build_expression(_, expr): # pragma: no cover
  19. """Build a jax function based on ``arrays`` and ``expr``.
  20. """
  21. jax, _ = _get_jax_and_to_jax()
  22. jax_expr = jax.jit(expr._contract)
  23. def jax_contract(*arrays):
  24. return np.asarray(jax_expr(arrays))
  25. return jax_contract
  26. def evaluate_constants(const_arrays, expr): # pragma: no cover
  27. """Convert constant arguments to jax arrays, and perform any possible
  28. constant contractions.
  29. """
  30. jax, to_jax = _get_jax_and_to_jax()
  31. return expr(*[to_jax(x) for x in const_arrays], backend='jax', evaluate_constants=True)