function.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376
  """
  For compatibility with numpy libraries, pandas functions or methods have to
  accept '*args' and '**kwargs' parameters to accommodate numpy arguments that
  are not actually used or respected in the pandas implementation.

  To ensure that users do not abuse these parameters, validation is performed in
  'validators.py' to make sure that any extra parameters passed correspond ONLY
  to those in the numpy signature. Part of that validation includes whether or
  not the user attempted to pass in non-default values for these extraneous
  parameters. As we want to discourage users from relying on these parameters
  when calling the pandas implementation, we want them only to pass in the
  default values for these parameters.

  This module provides a set of commonly used default arguments for functions and
  methods that are spread throughout the codebase. This module will make it
  easier to adjust to future upstream changes in the analogous numpy signatures.
  """
  16. from __future__ import annotations
  17. from typing import (
  18. TYPE_CHECKING,
  19. Any,
  20. TypeVar,
  21. cast,
  22. overload,
  23. )
  24. import numpy as np
  25. from numpy import ndarray
  26. from pandas._libs.lib import (
  27. is_bool,
  28. is_integer,
  29. )
  30. from pandas.errors import UnsupportedFunctionCall
  31. from pandas.util._validators import (
  32. validate_args,
  33. validate_args_and_kwargs,
  34. validate_kwargs,
  35. )
  36. if TYPE_CHECKING:
  37. from pandas._typing import (
  38. Axis,
  39. AxisInt,
  40. )
  41. AxisNoneT = TypeVar("AxisNoneT", Axis, None)
  class CompatValidator:
      """
      Callable that validates numpy-compat ``args``/``kwargs`` against defaults.

      Parameters
      ----------
      defaults : dict
          Parameter names and default values handed to the validator functions.
      fname : str, optional
          Function name used when a call does not supply its own.
      method : {"args", "kwargs", "both"}, optional
          Validation mode used when a call does not supply its own.
      max_fname_arg_count : int, optional
          Value forwarded to the positional validators when a call does not
          supply its own.
      """

      def __init__(
          self,
          defaults,
          fname=None,
          method: str | None = None,
          max_fname_arg_count=None,
      ) -> None:
          self.defaults = defaults
          self.fname = fname
          self.method = method
          self.max_fname_arg_count = max_fname_arg_count

      def __call__(
          self,
          args,
          kwargs,
          fname=None,
          max_fname_arg_count=None,
          method: str | None = None,
      ) -> None:
          """
          Validate ``args`` and ``kwargs``; per-call settings override the
          ones stored on the instance.

          Raises
          ------
          ValueError
              If the resolved ``method`` is not "args", "kwargs" or "both".
          """
          # Nothing extra was passed, so there is nothing to check.
          if not (args or kwargs):
              return None

          # Fall back to the instance-level settings only for values left as None.
          if fname is None:
              fname = self.fname
          if max_fname_arg_count is None:
              max_fname_arg_count = self.max_fname_arg_count
          if method is None:
              method = self.method

          if method == "args":
              validate_args(fname, args, max_fname_arg_count, self.defaults)
              return None
          if method == "kwargs":
              validate_kwargs(fname, kwargs, self.defaults)
              return None
          if method == "both":
              validate_args_and_kwargs(
                  fname, args, kwargs, max_fname_arg_count, self.defaults
              )
              return None
          raise ValueError(f"invalid validation method '{method}'")
  # Defaults shared by the argmin and argmax compat validators.
  ARGMINMAX_DEFAULTS = {"out": None}

  validate_argmin = CompatValidator(
      ARGMINMAX_DEFAULTS,
      fname="argmin",
      method="both",
      max_fname_arg_count=1,
  )
  validate_argmax = CompatValidator(
      ARGMINMAX_DEFAULTS,
      fname="argmax",
      method="both",
      max_fname_arg_count=1,
  )
  def process_skipna(skipna: bool | ndarray | None, args) -> tuple[bool, Any]:
      """
      Separate a genuine ``skipna`` flag from a value that is not one.

      When ``skipna`` is an ndarray or None it is pushed onto the front of
      ``args`` and ``skipna`` is reported as True; otherwise both values are
      returned unchanged.
      """
      if skipna is None or isinstance(skipna, ndarray):
          return True, (skipna, *args)
      return skipna, args
  def validate_argmin_with_skipna(skipna: bool | ndarray | None, args, kwargs) -> bool:
      """
      Validate numpy-style arguments for 'argmin' and return the 'skipna' flag.

      If 'Series.argmin' is called via the 'numpy' library, the third parameter
      in its signature is 'out', which takes either an ndarray or 'None'. A
      'skipna' value that is an ndarray or None is therefore treated as a
      positional numpy argument, since 'skipna' itself should be a boolean.
      """
      resolved_skipna, positional = process_skipna(skipna, args)
      validate_argmin(positional, kwargs)
      return resolved_skipna
  def validate_argmax_with_skipna(skipna: bool | ndarray | None, args, kwargs) -> bool:
      """
      Validate numpy-style arguments for 'argmax' and return the 'skipna' flag.

      If 'Series.argmax' is called via the 'numpy' library, the third parameter
      in its signature is 'out', which takes either an ndarray or 'None'. A
      'skipna' value that is an ndarray or None is therefore treated as a
      positional numpy argument, since 'skipna' itself should be a boolean.
      """
      resolved_skipna, positional = process_skipna(skipna, args)
      validate_argmax(positional, kwargs)
      return resolved_skipna
  # Defaults for the numpy 'argsort' parameters. Key order is kept as
  # axis, kind, order, stable. 'kind' was previously assigned "quicksort" and
  # then immediately overwritten with None; only the effective value is kept.
  ARGSORT_DEFAULTS: dict[str, int | str | None] = {
      "axis": -1,
      "kind": None,
      "order": None,
      "stable": None,
  }
  # Validator for 'argsort' calls, checked against ARGSORT_DEFAULTS.
  validate_argsort = CompatValidator(
      ARGSORT_DEFAULTS,
      fname="argsort",
      method="both",
      max_fname_arg_count=0,
  )
  # There are two different signatures of argsort; this second set of defaults
  # (without a 'kind' entry) is for validation when the `kind` param is supported.
  ARGSORT_DEFAULTS_KIND: dict[str, int | None] = {
      "axis": -1,
      "order": None,
      "stable": None,
  }
  validate_argsort_kind = CompatValidator(
      ARGSORT_DEFAULTS_KIND,
      fname="argsort",
      method="both",
      max_fname_arg_count=0,
  )
  def validate_argsort_with_ascending(ascending: bool | int | None, args, kwargs) -> bool:
      """
      Validate numpy-style arguments for 'argsort' and return 'ascending'.

      If 'Categorical.argsort' is called via the 'numpy' library, the first
      parameter in its signature is 'axis', which takes either an integer or
      'None'. An 'ascending' value that is an integer or None is therefore
      moved to the front of 'args' and 'ascending' falls back to True, since
      'ascending' itself should be a boolean.
      """
      if ascending is None or is_integer(ascending):
          args = (ascending, *args)
          ascending = True

      validate_argsort_kind(args, kwargs, max_fname_arg_count=3)
      return cast(bool, ascending)
  # Defaults for the extra numpy parameter checked on 'clip' calls.
  CLIP_DEFAULTS: dict[str, Any] = {"out": None}
  validate_clip = CompatValidator(
      CLIP_DEFAULTS,
      fname="clip",
      method="both",
      max_fname_arg_count=3,
  )
  @overload
  def validate_clip_with_axis(axis: ndarray, args, kwargs) -> None: ...


  @overload
  def validate_clip_with_axis(axis: AxisNoneT, args, kwargs) -> AxisNoneT: ...


  def validate_clip_with_axis(
      axis: ndarray | AxisNoneT, args, kwargs
  ) -> AxisNoneT | None:
      """
      Validate numpy-style arguments for 'clip' and return the usable 'axis'.

      If 'NDFrame.clip' is called via the numpy library, the third parameter in
      its signature is 'out', which can take an ndarray. An 'axis' value that is
      an ndarray is therefore moved to the front of 'args' and None is returned
      as the axis, since 'axis' itself should either be an integer or None.
      """
      axis_is_ndarray = isinstance(axis, ndarray)
      if axis_is_ndarray:
          args = (axis, *args)

      validate_clip(args, kwargs)

      # The checker cannot see that the ndarray case is replaced by None here.
      return None if axis_is_ndarray else axis  # type: ignore[return-value]
  # Defaults shared by the cumulative-function validators.
  CUM_FUNC_DEFAULTS: dict[str, Any] = {"dtype": None, "out": None}

  # No fname here: callers of this validator pass the name per call.
  validate_cum_func = CompatValidator(
      CUM_FUNC_DEFAULTS,
      method="both",
      max_fname_arg_count=1,
  )
  validate_cumsum = CompatValidator(
      CUM_FUNC_DEFAULTS,
      fname="cumsum",
      method="both",
      max_fname_arg_count=1,
  )
  def validate_cum_func_with_skipna(skipna: bool, args, kwargs, name) -> bool:
      """
      Validate numpy-style arguments for a cumulative function named 'name'.

      If this function is called via the 'numpy' library, the third parameter in
      its signature is 'dtype', which takes either a 'numpy' dtype or 'None'. A
      non-boolean 'skipna' is therefore moved to the front of 'args' and
      'skipna' falls back to True. A 'np.bool_' flag is converted to a plain
      'bool' before being returned.
      """
      if is_bool(skipna):
          if isinstance(skipna, np.bool_):
              skipna = bool(skipna)
      else:
          args = (skipna, *args)
          skipna = True

      validate_cum_func(args, kwargs, fname=name)
      return skipna
  # Defaults for the 'all' / 'any' validators.
  ALLANY_DEFAULTS: dict[str, bool | None] = {
      "dtype": None,
      "out": None,
      "keepdims": False,
      "axis": None,
  }
  validate_all = CompatValidator(
      ALLANY_DEFAULTS,
      fname="all",
      method="both",
      max_fname_arg_count=1,
  )
  validate_any = CompatValidator(
      ALLANY_DEFAULTS,
      fname="any",
      method="both",
      max_fname_arg_count=1,
  )

  # Keyword-only validation for logical functions; no fname is preset.
  LOGICAL_FUNC_DEFAULTS = {"out": None, "keepdims": False}
  validate_logical_func = CompatValidator(LOGICAL_FUNC_DEFAULTS, method="kwargs")

  # Defaults for the 'min' / 'max' validators.
  MINMAX_DEFAULTS = {
      "axis": None,
      "dtype": None,
      "out": None,
      "keepdims": False,
  }
  validate_min = CompatValidator(
      MINMAX_DEFAULTS,
      fname="min",
      method="both",
      max_fname_arg_count=1,
  )
  validate_max = CompatValidator(
      MINMAX_DEFAULTS,
      fname="max",
      method="both",
      max_fname_arg_count=1,
  )
  # Defaults for the 'repeat' validator.
  REPEAT_DEFAULTS: dict[str, Any] = {"axis": None}
  validate_repeat = CompatValidator(
      REPEAT_DEFAULTS,
      fname="repeat",
      method="both",
      max_fname_arg_count=1,
  )

  # Defaults for the 'round' validator.
  ROUND_DEFAULTS: dict[str, Any] = {"out": None}
  validate_round = CompatValidator(
      ROUND_DEFAULTS,
      fname="round",
      method="both",
      max_fname_arg_count=1,
  )
  # Base defaults. 'keepdims' is appended only at the end of this section so
  # the dicts derived below start from just 'dtype' and 'out'.
  STAT_FUNC_DEFAULTS: dict[str, Any | None] = {"dtype": None, "out": None}

  # sum / prod / mean share the same entries but are independent dict objects.
  SUM_DEFAULTS = {
      **STAT_FUNC_DEFAULTS,
      "axis": None,
      "keepdims": False,
      "initial": None,
  }
  PROD_DEFAULTS = dict(SUM_DEFAULTS)
  MEAN_DEFAULTS = dict(SUM_DEFAULTS)

  MEDIAN_DEFAULTS = {
      **STAT_FUNC_DEFAULTS,
      "overwrite_input": False,
      "keepdims": False,
  }

  STAT_FUNC_DEFAULTS["keepdims"] = False
  # Keyword-only validator for stat functions; no fname is preset.
  validate_stat_func = CompatValidator(STAT_FUNC_DEFAULTS, method="kwargs")

  # Dedicated validators for the reductions that have their own defaults.
  validate_sum = CompatValidator(
      SUM_DEFAULTS,
      fname="sum",
      method="both",
      max_fname_arg_count=1,
  )
  validate_prod = CompatValidator(
      PROD_DEFAULTS,
      fname="prod",
      method="both",
      max_fname_arg_count=1,
  )
  validate_mean = CompatValidator(
      MEAN_DEFAULTS,
      fname="mean",
      method="both",
      max_fname_arg_count=1,
  )
  validate_median = CompatValidator(
      MEDIAN_DEFAULTS,
      fname="median",
      method="both",
      max_fname_arg_count=1,
  )
  # Keyword-only validation for the ddof stat functions; no fname is preset.
  STAT_DDOF_FUNC_DEFAULTS: dict[str, bool | None] = {
      "dtype": None,
      "out": None,
      "keepdims": False,
  }
  validate_stat_ddof_func = CompatValidator(STAT_DDOF_FUNC_DEFAULTS, method="kwargs")

  # Keyword-only validation for 'take'.
  TAKE_DEFAULTS: dict[str, str | None] = {"out": None, "mode": "raise"}
  validate_take = CompatValidator(TAKE_DEFAULTS, fname="take", method="kwargs")

  # Validation for 'transpose'.
  TRANSPOSE_DEFAULTS = {"axes": None}
  validate_transpose = CompatValidator(
      TRANSPOSE_DEFAULTS,
      fname="transpose",
      method="both",
      max_fname_arg_count=0,
  )
  def validate_groupby_func(name: str, args, kwargs, allowed=None) -> None:
      """
      Reject numpy-style extra arguments passed to a groupby function.

      'args' and 'kwargs' should be empty, except for allowed kwargs, because
      all of the necessary parameters are explicitly listed in the function
      signature.

      Raises
      ------
      UnsupportedFunctionCall
          If any positional argument, or any keyword not in 'allowed', is given.
      """
      permitted = [] if allowed is None else allowed
      unexpected_kwargs = set(kwargs).difference(permitted)

      if len(args) or unexpected_kwargs:
          raise UnsupportedFunctionCall(
              "numpy operations are not valid with groupby. "
              f"Use .groupby(...).{name}() instead"
          )
  def validate_minmax_axis(axis: AxisInt | None, ndim: int = 1) -> None:
      """
      Ensure that the axis argument passed to min, max, argmin, or argmax is
      within bounds for ``ndim`` (or is None), as otherwise it will be
      incorrectly ignored.

      Parameters
      ----------
      axis : int or None
      ndim : int, default 1

      Raises
      ------
      ValueError
          If ``axis`` is not None and falls outside ``[-ndim, ndim)``.
      """
      if axis is None:
          return

      within_bounds = -ndim <= axis < ndim
      if not within_bounds:
          raise ValueError(f"`axis` must be fewer than the number of dimensions ({ndim})")
  # Function names that have a dedicated validator; used by validate_func.
  _validation_funcs = dict(
      median=validate_median,
      mean=validate_mean,
      min=validate_min,
      max=validate_max,
      sum=validate_sum,
      prod=validate_prod,
  )
  def validate_func(fname, args, kwargs) -> None:
      """
      Validate ``args``/``kwargs`` with the validator registered for ``fname``.

      Names without an entry in ``_validation_funcs`` fall back to
      ``validate_stat_func``, which is given ``fname`` explicitly.
      """
      dedicated_validator = _validation_funcs.get(fname)
      if dedicated_validator is None:
          return validate_stat_func(args, kwargs, fname=fname)
      return dedicated_validator(args, kwargs)