_manipulation_functions.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. from __future__ import annotations
  2. from ._array_object import Array
  3. from ._data_type_functions import result_type
  4. from typing import List, Optional, Tuple, Union
  5. import numpy as np
  6. # Note: the function name is different here
  7. def concat(
  8. arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0
  9. ) -> Array:
  10. """
  11. Array API compatible wrapper for :py:func:`np.concatenate <numpy.concatenate>`.
  12. See its docstring for more information.
  13. """
  14. # Note: Casting rules here are different from the np.concatenate default
  15. # (no for scalars with axis=None, no cross-kind casting)
  16. dtype = result_type(*arrays)
  17. arrays = tuple(a._array for a in arrays)
  18. return Array._new(np.concatenate(arrays, axis=axis, dtype=dtype))
  19. def expand_dims(x: Array, /, *, axis: int) -> Array:
  20. """
  21. Array API compatible wrapper for :py:func:`np.expand_dims <numpy.expand_dims>`.
  22. See its docstring for more information.
  23. """
  24. return Array._new(np.expand_dims(x._array, axis))
  25. def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array:
  26. """
  27. Array API compatible wrapper for :py:func:`np.flip <numpy.flip>`.
  28. See its docstring for more information.
  29. """
  30. return Array._new(np.flip(x._array, axis=axis))
  31. # Note: The function name is different here (see also matrix_transpose).
  32. # Unlike transpose(), the axes argument is required.
  33. def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:
  34. """
  35. Array API compatible wrapper for :py:func:`np.transpose <numpy.transpose>`.
  36. See its docstring for more information.
  37. """
  38. return Array._new(np.transpose(x._array, axes))
  39. # Note: the optional argument is called 'shape', not 'newshape'
  40. def reshape(x: Array,
  41. /,
  42. shape: Tuple[int, ...],
  43. *,
  44. copy: Optional[Bool] = None) -> Array:
  45. """
  46. Array API compatible wrapper for :py:func:`np.reshape <numpy.reshape>`.
  47. See its docstring for more information.
  48. """
  49. data = x._array
  50. if copy:
  51. data = np.copy(data)
  52. reshaped = np.reshape(data, shape)
  53. if copy is False and not np.shares_memory(data, reshaped):
  54. raise AttributeError("Incompatible shape for in-place modification.")
  55. return Array._new(reshaped)
  56. def roll(
  57. x: Array,
  58. /,
  59. shift: Union[int, Tuple[int, ...]],
  60. *,
  61. axis: Optional[Union[int, Tuple[int, ...]]] = None,
  62. ) -> Array:
  63. """
  64. Array API compatible wrapper for :py:func:`np.roll <numpy.roll>`.
  65. See its docstring for more information.
  66. """
  67. return Array._new(np.roll(x._array, shift, axis=axis))
  68. def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array:
  69. """
  70. Array API compatible wrapper for :py:func:`np.squeeze <numpy.squeeze>`.
  71. See its docstring for more information.
  72. """
  73. return Array._new(np.squeeze(x._array, axis=axis))
  74. def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) -> Array:
  75. """
  76. Array API compatible wrapper for :py:func:`np.stack <numpy.stack>`.
  77. See its docstring for more information.
  78. """
  79. # Call result type here just to raise on disallowed type combinations
  80. result_type(*arrays)
  81. arrays = tuple(a._array for a in arrays)
  82. return Array._new(np.stack(arrays, axis=axis))