# _indexing_functions.py (715 B)
from __future__ import annotations

from typing import Optional

import numpy as np

from ._array_object import Array
from ._dtypes import _integer_dtypes
  5. def take(x: Array, indices: Array, /, *, axis: Optional[int] = None) -> Array:
  6. """
  7. Array API compatible wrapper for :py:func:`np.take <numpy.take>`.
  8. See its docstring for more information.
  9. """
  10. if axis is None and x.ndim != 1:
  11. raise ValueError("axis must be specified when ndim > 1")
  12. if indices.dtype not in _integer_dtypes:
  13. raise TypeError("Only integer dtypes are allowed in indexing")
  14. if indices.ndim != 1:
  15. raise ValueError("Only 1-dim indices array is supported")
  16. return Array._new(np.take(x._array, indices._array, axis=axis))