cupy.py 919 B

1234567891011121314151617181920212223242526272829303132333435
  1. """
  2. Required functions for optimized contractions of numpy arrays using cupy.
  3. """
  4. import numpy as np
  5. from ..sharing import to_backend_cache_wrap
  6. __all__ = ["to_cupy", "build_expression", "evaluate_constants"]
  7. @to_backend_cache_wrap
  8. def to_cupy(array): # pragma: no cover
  9. import cupy
  10. if isinstance(array, np.ndarray):
  11. return cupy.asarray(array)
  12. return array
  13. def build_expression(_, expr): # pragma: no cover
  14. """Build a cupy function based on ``arrays`` and ``expr``.
  15. """
  16. def cupy_contract(*arrays):
  17. return expr._contract([to_cupy(x) for x in arrays], backend='cupy').get()
  18. return cupy_contract
  19. def evaluate_constants(const_arrays, expr): # pragma: no cover
  20. """Convert constant arguments to cupy arrays, and perform any possible
  21. constant contractions.
  22. """
  23. return expr(*[to_cupy(x) for x in const_arrays], backend='cupy', evaluate_constants=True)