_statistical_functions.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. from __future__ import annotations
  2. from ._dtypes import (
  3. _real_floating_dtypes,
  4. _real_numeric_dtypes,
  5. _numeric_dtypes,
  6. )
  7. from ._array_object import Array
  8. from ._dtypes import float32, float64, complex64, complex128
  9. from typing import TYPE_CHECKING, Optional, Tuple, Union
  10. if TYPE_CHECKING:
  11. from ._typing import Dtype
  12. import numpy as np
  13. def max(
  14. x: Array,
  15. /,
  16. *,
  17. axis: Optional[Union[int, Tuple[int, ...]]] = None,
  18. keepdims: bool = False,
  19. ) -> Array:
  20. if x.dtype not in _real_numeric_dtypes:
  21. raise TypeError("Only real numeric dtypes are allowed in max")
  22. return Array._new(np.max(x._array, axis=axis, keepdims=keepdims))
  23. def mean(
  24. x: Array,
  25. /,
  26. *,
  27. axis: Optional[Union[int, Tuple[int, ...]]] = None,
  28. keepdims: bool = False,
  29. ) -> Array:
  30. if x.dtype not in _real_floating_dtypes:
  31. raise TypeError("Only real floating-point dtypes are allowed in mean")
  32. return Array._new(np.mean(x._array, axis=axis, keepdims=keepdims))
  33. def min(
  34. x: Array,
  35. /,
  36. *,
  37. axis: Optional[Union[int, Tuple[int, ...]]] = None,
  38. keepdims: bool = False,
  39. ) -> Array:
  40. if x.dtype not in _real_numeric_dtypes:
  41. raise TypeError("Only real numeric dtypes are allowed in min")
  42. return Array._new(np.min(x._array, axis=axis, keepdims=keepdims))
  43. def prod(
  44. x: Array,
  45. /,
  46. *,
  47. axis: Optional[Union[int, Tuple[int, ...]]] = None,
  48. dtype: Optional[Dtype] = None,
  49. keepdims: bool = False,
  50. ) -> Array:
  51. if x.dtype not in _numeric_dtypes:
  52. raise TypeError("Only numeric dtypes are allowed in prod")
  53. # Note: sum() and prod() always upcast for dtype=None. `np.prod` does that
  54. # for integers, but not for float32 or complex64, so we need to
  55. # special-case it here
  56. if dtype is None:
  57. if x.dtype == float32:
  58. dtype = float64
  59. elif x.dtype == complex64:
  60. dtype = complex128
  61. return Array._new(np.prod(x._array, dtype=dtype, axis=axis, keepdims=keepdims))
  62. def std(
  63. x: Array,
  64. /,
  65. *,
  66. axis: Optional[Union[int, Tuple[int, ...]]] = None,
  67. correction: Union[int, float] = 0.0,
  68. keepdims: bool = False,
  69. ) -> Array:
  70. # Note: the keyword argument correction is different here
  71. if x.dtype not in _real_floating_dtypes:
  72. raise TypeError("Only real floating-point dtypes are allowed in std")
  73. return Array._new(np.std(x._array, axis=axis, ddof=correction, keepdims=keepdims))
  74. def sum(
  75. x: Array,
  76. /,
  77. *,
  78. axis: Optional[Union[int, Tuple[int, ...]]] = None,
  79. dtype: Optional[Dtype] = None,
  80. keepdims: bool = False,
  81. ) -> Array:
  82. if x.dtype not in _numeric_dtypes:
  83. raise TypeError("Only numeric dtypes are allowed in sum")
  84. # Note: sum() and prod() always upcast for dtype=None. `np.sum` does that
  85. # for integers, but not for float32 or complex64, so we need to
  86. # special-case it here
  87. if dtype is None:
  88. if x.dtype == float32:
  89. dtype = float64
  90. elif x.dtype == complex64:
  91. dtype = complex128
  92. return Array._new(np.sum(x._array, axis=axis, dtype=dtype, keepdims=keepdims))
  93. def var(
  94. x: Array,
  95. /,
  96. *,
  97. axis: Optional[Union[int, Tuple[int, ...]]] = None,
  98. correction: Union[int, float] = 0.0,
  99. keepdims: bool = False,
  100. ) -> Array:
  101. # Note: the keyword argument correction is different here
  102. if x.dtype not in _real_floating_dtypes:
  103. raise TypeError("Only real floating-point dtypes are allowed in var")
  104. return Array._new(np.var(x._array, axis=axis, ddof=correction, keepdims=keepdims))