array.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. from __future__ import annotations
  2. import decimal
  3. import numbers
  4. import sys
  5. from typing import TYPE_CHECKING
  6. import numpy as np
  7. from pandas.core.dtypes.base import ExtensionDtype
  8. from pandas.core.dtypes.common import (
  9. is_dtype_equal,
  10. is_float,
  11. is_integer,
  12. pandas_dtype,
  13. )
  14. import pandas as pd
  15. from pandas.api.extensions import (
  16. no_default,
  17. register_extension_dtype,
  18. )
  19. from pandas.api.types import (
  20. is_list_like,
  21. is_scalar,
  22. )
  23. from pandas.core import arraylike
  24. from pandas.core.algorithms import value_counts_internal as value_counts
  25. from pandas.core.arraylike import OpsMixin
  26. from pandas.core.arrays import (
  27. ExtensionArray,
  28. ExtensionScalarOpsMixin,
  29. )
  30. from pandas.core.indexers import check_array_indexer
  31. if TYPE_CHECKING:
  32. from pandas._typing import type_t
  @register_extension_dtype
  class DecimalDtype(ExtensionDtype):
      """
      pandas extension dtype whose scalars are :class:`decimal.Decimal` objects.

      Parameters
      ----------
      context : decimal.Context, optional
          Arithmetic context recorded on the dtype. When omitted (or falsy),
          the object returned by :func:`decimal.getcontext` is stored instead.
      """

      # Scalar type held by arrays of this dtype.
      type = decimal.Decimal
      # String alias used when the dtype is registered with pandas.
      name = "decimal"
      # Missing-value marker: a Decimal NaN.
      na_value = decimal.Decimal("NaN")
      # Attribute names that define this dtype (pandas ExtensionDtype convention).
      _metadata = ("context",)

      def __init__(self, context=None) -> None:
          if not context:
              context = decimal.getcontext()
          self.context = context

      def __repr__(self) -> str:
          return "DecimalDtype(context=%s)" % (self.context,)

      @classmethod
      def construct_array_type(cls) -> type_t[DecimalArray]:
          """
          Return the array type associated with this dtype.

          Returns
          -------
          type
          """
          return DecimalArray

      @property
      def _is_numeric(self) -> bool:
          # Report the dtype as numeric to pandas.
          return True
  class DecimalArray(OpsMixin, ExtensionScalarOpsMixin, ExtensionArray):
      """
      pandas ExtensionArray storing ``decimal.Decimal`` scalars in an object ndarray.

      Parameters
      ----------
      values : sequence
          ``Decimal``, float or integer scalars. NaN floats become
          ``Decimal("NaN")``. Other floats/integers, including numpy scalars,
          are converted with ``decimal.Decimal``. Any other element raises
          ``TypeError``. ``values`` itself is never modified.
      dtype : DecimalDtype, optional
          When given and ``context`` is None, its ``context`` is reused.
      copy : bool, default False
          Copy ``values`` even when no element needs converting. When False
          and ``values`` is already an object ndarray of Decimals, the new
          array shares memory with it.
      context : decimal.Context, optional
          Context recorded on the resulting ``DecimalDtype``.
      """

      __array_priority__ = 1000

      def __init__(self, values, dtype=None, copy=False, context=None) -> None:
          if context is None and isinstance(dtype, DecimalDtype):
              # Keep the context of an explicitly requested DecimalDtype so it
              # survives copy/take/slicing/_from_sequence round trips.
              context = dtype.context

          # Record conversions by position instead of writing them back into
          # ``values``. Writing back would mutate the caller's container,
          # break on tuples, and turn Decimals into floats for float ndarrays.
          replacements = {}
          for i, val in enumerate(values):
              if is_float(val) or is_integer(val):
                  if is_float(val) and np.isnan(val):
                      replacements[i] = DecimalDtype.na_value
                  else:
                      if isinstance(val, np.generic):
                          # decimal.Decimal rejects numpy scalars such as
                          # np.int64 / np.float32; unwrap to the Python scalar.
                          val = val.item()
                      replacements[i] = DecimalDtype.type(val)
              elif not isinstance(val, decimal.Decimal):
                  raise TypeError("All values must be of type " + str(decimal.Decimal))

          if replacements:
              # np.array always builds a fresh object array here, so filling it
              # cannot affect ``values`` and ``copy`` is already satisfied.
              data = np.array(values, dtype=object)
              for i, converted in replacements.items():
                  data[i] = converted
          else:
              data = np.asarray(values, dtype=object)
              if copy:
                  data = data.copy()

          self._data = data
          # Some aliases for common attribute names to ensure pandas supports
          # these
          self._items = self.data = self._data
          # ``_values`` / ``values`` aliases are deliberately not set: they do not
          # work due to assumptions in internal pandas code (GH-20735).
          self._dtype = DecimalDtype(context)

      @property
      def dtype(self):
          return self._dtype

      @classmethod
      def _from_sequence(cls, scalars, *, dtype=None, copy=False):
          return cls(scalars, dtype=dtype, copy=copy)

      @classmethod
      def _from_sequence_of_strings(cls, strings, dtype=None, copy=False):
          return cls._from_sequence(
              [decimal.Decimal(x) for x in strings], dtype=dtype, copy=copy
          )

      @classmethod
      def _from_factorized(cls, values, original):
          # Rebuild with the original array's dtype so its context is kept.
          return cls(values, dtype=original.dtype)

      # Operand types accepted by __array_ufunc__ (besides DecimalArray itself).
      _HANDLED_TYPES = (decimal.Decimal, numbers.Number, np.ndarray)

      def to_numpy(
          self,
          dtype=None,
          copy: bool = False,
          na_value: object = no_default,
          decimals=None,
      ) -> np.ndarray:
          """
          Convert to a numpy array, optionally rounding every element.

          ``decimals`` is an extra keyword beyond the ExtensionArray interface.
          When it is given, ``round(x, decimals)`` is applied to each converted
          element.

          NOTE(review): ``copy`` and ``na_value`` are accepted for interface
          compatibility but are not used by this implementation.
          """
          result = np.asarray(self, dtype=dtype)
          if decimals is not None:
              result = np.asarray([round(x, decimals) for x in result])
          return result

      def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
          """
          Apply a numpy ufunc, running it on the unwrapped object data.

          Returns NotImplemented if any operand is not a handled type.
          Non-scalar results are wrapped back into this array type with the
          same dtype.
          """
          if not all(
              isinstance(t, self._HANDLED_TYPES + (DecimalArray,)) for t in inputs
          ):
              return NotImplemented

          result = arraylike.maybe_dispatch_ufunc_to_dunder_op(
              self, ufunc, method, *inputs, **kwargs
          )
          if result is not NotImplemented:
              # e.g. test_array_ufunc_series_scalar_other
              return result

          if "out" in kwargs:
              return arraylike.dispatch_ufunc_with_out(
                  self, ufunc, method, *inputs, **kwargs
              )

          if method == "reduce":
              # Dispatch before unwrapping, while ``inputs`` still holds ``self``
              # (the same order pandas' NumpyExtensionArray uses). If the helper
              # declines, fall through to the plain numpy reduction below.
              # Running it after the reduction would let its NotImplemented
              # overwrite the computed result.
              result = arraylike.dispatch_reduction_ufunc(
                  self, ufunc, method, *inputs, **kwargs
              )
              if result is not NotImplemented:
                  return result

          inputs = tuple(x._data if isinstance(x, DecimalArray) else x for x in inputs)
          result = getattr(ufunc, method)(*inputs, **kwargs)

          def reconstruct(x):
              # Scalars pass through; array-like results are re-wrapped.
              if isinstance(x, (decimal.Decimal, numbers.Number)):
                  return x
              return type(self)._from_sequence(x, dtype=self.dtype)

          if ufunc.nout > 1:
              return tuple(reconstruct(x) for x in result)
          return reconstruct(result)

      def __getitem__(self, item):
          if isinstance(item, numbers.Integral):
              return self._data[item]
          # array, slice.
          item = pd.api.indexers.check_array_indexer(self, item)
          return type(self)(self._data[item], dtype=self.dtype)

      def take(self, indexer, allow_fill=False, fill_value=None):
          from pandas.api.extensions import take

          data = self._data
          if allow_fill and fill_value is None:
              fill_value = self.dtype.na_value
          result = take(data, indexer, fill_value=fill_value, allow_fill=allow_fill)
          return self._from_sequence(result, dtype=self.dtype)

      def copy(self):
          return type(self)(self._data.copy(), dtype=self.dtype)

      def astype(self, dtype, copy=True):
          if is_dtype_equal(dtype, self._dtype):
              if not copy:
                  return self
          dtype = pandas_dtype(dtype)
          if isinstance(dtype, type(self.dtype)):
              # The constructor honours ``copy``, so copy=True never shares data.
              return type(self)(self._data, copy=copy, context=dtype.context)
          return super().astype(dtype, copy=copy)

      def __setitem__(self, key, value) -> None:
          # Values are coerced with decimal.Decimal before being stored.
          if is_list_like(value):
              if is_scalar(key):
                  raise ValueError("setting an array element with a sequence.")
              value = [decimal.Decimal(v) for v in value]
          else:
              value = decimal.Decimal(value)
          key = check_array_indexer(self, key)
          self._data[key] = value

      def __len__(self) -> int:
          return len(self._data)

      def __contains__(self, item) -> bool | np.bool_:
          if not isinstance(item, decimal.Decimal):
              return False
          elif item.is_nan():
              # NaN != NaN, so membership of NaN is answered via isna().
              return self.isna().any()
          else:
              return super().__contains__(item)

      @property
      def nbytes(self) -> int:
          # Approximation: length times the shallow size of the first element.
          n = len(self)
          if n:
              return n * sys.getsizeof(self[0])
          return 0

      def isna(self):
          return np.array([x.is_nan() for x in self._data], dtype=bool)

      @property
      def _na_value(self):
          return decimal.Decimal("NaN")

      def _formatter(self, boxed=False):
          if boxed:
              return "Decimal: {}".format
          return repr

      @classmethod
      def _concat_same_type(cls, to_concat):
          data = np.concatenate([x._data for x in to_concat])
          # Keep the first array's dtype (and hence its context).
          return cls(data, dtype=to_concat[0].dtype)

      def _reduce(
          self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
      ):
          """
          Reduce with the same-named object-ndarray method (e.g. ``sum``).

          Raises NotImplementedError if ndarray has no method called ``name``.
          """
          if skipna and self.isna().any():
              # If we don't have any NAs, we can ignore skipna
              other = self[~self.isna()]
              result = other._reduce(name, **kwargs)
          elif name == "sum" and len(self) == 0:
              # GH#29630 avoid returning int 0 or np.bool_(False) on old numpy
              result = decimal.Decimal(0)
          else:
              try:
                  op = getattr(self.data, name)
              except AttributeError as err:
                  raise NotImplementedError(
                      f"decimal does not support the {name} operation"
                  ) from err
              result = op(axis=0)

          if keepdims:
              return type(self)([result], dtype=self.dtype)
          else:
              return result

      def _cmp_method(self, other, op):
          """
          Element-wise comparison for OpsMixin, returning a bool ndarray.

          Scalars are compared against every element. List-likes must have
          the same length as ``self``; otherwise ValueError is raised instead
          of silently truncating the result.
          """
          if isinstance(other, ExtensionArray) or is_list_like(other):
              rvalues = other
              if not hasattr(rvalues, "__len__"):
                  # Materialise one-shot iterables (e.g. generators) so they can be sized.
                  rvalues = list(rvalues)
              if len(rvalues) != len(self):
                  raise ValueError("Lengths must match to compare")
          else:
              # Assume it's an object
              rvalues = [other] * len(self)

          # If the operator is not defined for the underlying objects,
          # a TypeError should be raised
          res = [op(a, b) for (a, b) in zip(self, rvalues)]

          return np.asarray(res, dtype=bool)

      def value_counts(self, dropna: bool = True):
          return value_counts(self.to_numpy(), dropna=dropna)

      # We override fillna here to simulate a 3rd party EA that has done so. This
      #  lets us test the deprecation telling authors to implement _pad_or_backfill
      # Simulate a 3rd-party EA that has not yet updated to include a "copy"
      #  keyword in its fillna method.
      # error: Signature of "fillna" incompatible with supertype "ExtensionArray"
      def fillna(  # type: ignore[override]
          self,
          value=None,
          method=None,
          limit: int | None = None,
      ):
          return super().fillna(value=value, method=method, limit=limit, copy=True)
  def to_decimal(values, context=None):
      """
      Build a DecimalArray by passing every element of ``values`` through
      ``decimal.Decimal``.

      ``context`` is forwarded unchanged to the DecimalArray constructor.
      """
      scalars = list(map(decimal.Decimal, values))
      return DecimalArray(scalars, context=context)
  def make_data(*, size: int = 100, seed: int = 2) -> list[decimal.Decimal]:
      """
      Return ``size`` reproducible ``decimal.Decimal`` values in [0, 1).

      Parameters
      ----------
      size : int, default 100
          Number of values to generate.
      seed : int, default 2
          Seed for ``np.random.default_rng``. The same seed always yields
          the same data.

      Returns
      -------
      list of decimal.Decimal
          Exact Decimal conversions of the generated floats.
      """
      rng = np.random.default_rng(seed)
      return [decimal.Decimal(val) for val in rng.random(size)]
  # Install the arithmetic operator methods provided by the
  # ExtensionScalarOpsMixin base class onto DecimalArray. Presumably these
  # apply the operator scalar-by-scalar on the stored Decimals (per the pandas
  # ExtensionScalarOpsMixin docs) — confirm against the pandas version in use.
  DecimalArray._add_arithmetic_ops()