| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849 |
- """
- Required functions for optimized contractions of numpy arrays using jax.
- """
- import numpy as np
- from ..sharing import to_backend_cache_wrap
# Public entry points of this backend module.
__all__ = [
    "build_expression",
    "evaluate_constants",
]

# Cache for the ``(jax, to_jax)`` pair. It stays ``None`` until
# ``_get_jax_and_to_jax`` is first called, which is where jax is imported.
_JAX = None
def _get_jax_and_to_jax():
    """Return the ``(jax, to_jax)`` pair, importing jax lazily.

    On the first call this imports jax and builds ``to_jax``, a jitted
    identity function wrapped by ``to_backend_cache_wrap``. It then stores
    the pair in the module-level ``_JAX``. Subsequent calls return the
    stored pair without doing that work again.
    """
    global _JAX

    # Fast path: already initialised on an earlier call.
    if _JAX is not None:
        return _JAX

    # Deferred import so jax is only required once this backend is used.
    import jax

    # Keep the function named ``to_jax``.
    # NOTE(review): the cache wrapper from ``..sharing`` may key on the
    # wrapped function's name; confirm before renaming.
    @to_backend_cache_wrap
    @jax.jit
    def to_jax(x):
        return x

    _JAX = jax, to_jax
    return _JAX
def build_expression(_, expr):  # pragma: no cover
    """Return a callable that runs ``expr``'s contraction through ``jax.jit``.

    Parameters
    ----------
    _ : object
        Unused. NOTE(review): presumably the example arrays the expression
        was built for; confirm against the caller.
    expr : object
        Expression whose ``_contract`` method is compiled with ``jax.jit``.

    Returns
    -------
    callable
        ``jax_contract(*arrays)``. It passes the operands to the compiled
        contraction as a single tuple and converts the result with
        ``np.asarray``.
    """
    jax = _get_jax_and_to_jax()[0]
    compiled_contract = jax.jit(expr._contract)

    def jax_contract(*arrays):
        # The operands go to ``_contract`` as one tuple argument.
        result = compiled_contract(arrays)
        # Convert the jitted output into a numpy array for the caller.
        return np.asarray(result)

    return jax_contract
def evaluate_constants(const_arrays, expr):  # pragma: no cover
    """Convert constant arguments to jax arrays, and perform any possible
    constant contractions.

    Parameters
    ----------
    const_arrays : iterable
        The constant operands; each one is passed through ``to_jax``.
    expr : callable
        The contraction expression. It is called with the converted
        constants as positional arguments and with
        ``backend='jax', evaluate_constants=True``.

    Returns
    -------
    object
        Whatever ``expr`` returns for that call.
        NOTE(review): presumably the expression with its constant parts
        pre-contracted; confirm against the caller.
    """
    # Only the converter is needed here; the jax module itself is unused.
    _, to_jax = _get_jax_and_to_jax()
    jax_constants = [to_jax(x) for x in const_arrays]
    return expr(*jax_constants, backend='jax', evaluate_constants=True)
|