function.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418
  1. """
  2. For compatibility with numpy libraries, pandas functions or methods have to
  3. accept '*args' and '**kwargs' parameters to accommodate numpy arguments that
  4. are not actually used or respected in the pandas implementation.
  5. To ensure that users do not abuse these parameters, validation is performed in
  6. 'validators.py' to make sure that any extra parameters passed correspond ONLY
  7. to those in the numpy signature. Part of that validation includes whether or
  8. not the user attempted to pass in non-default values for these extraneous
  9. parameters. As we want to discourage users from relying on these parameters
  10. when calling the pandas implementation, we want them only to pass in the
  11. default values for these parameters.
  12. This module provides a set of commonly used default arguments for functions and
  13. methods that are spread throughout the codebase. This module will make it
  14. easier to adjust to future upstream changes in the analogous numpy signatures.
  15. """
  16. from __future__ import annotations
  17. from typing import (
  18. TYPE_CHECKING,
  19. Any,
  20. TypeVar,
  21. cast,
  22. overload,
  23. )
  24. import numpy as np
  25. from numpy import ndarray
  26. from pandas._libs.lib import (
  27. is_bool,
  28. is_integer,
  29. )
  30. from pandas.errors import UnsupportedFunctionCall
  31. from pandas.util._validators import (
  32. validate_args,
  33. validate_args_and_kwargs,
  34. validate_kwargs,
  35. )
  36. if TYPE_CHECKING:
  37. from pandas._typing import (
  38. Axis,
  39. AxisInt,
  40. )
  41. AxisNoneT = TypeVar("AxisNoneT", Axis, None)
  42. class CompatValidator:
  43. def __init__(
  44. self,
  45. defaults,
  46. fname=None,
  47. method: str | None = None,
  48. max_fname_arg_count=None,
  49. ) -> None:
  50. self.fname = fname
  51. self.method = method
  52. self.defaults = defaults
  53. self.max_fname_arg_count = max_fname_arg_count
  54. def __call__(
  55. self,
  56. args,
  57. kwargs,
  58. fname=None,
  59. max_fname_arg_count=None,
  60. method: str | None = None,
  61. ) -> None:
  62. if not args and not kwargs:
  63. return None
  64. fname = self.fname if fname is None else fname
  65. max_fname_arg_count = (
  66. self.max_fname_arg_count
  67. if max_fname_arg_count is None
  68. else max_fname_arg_count
  69. )
  70. method = self.method if method is None else method
  71. if method == "args":
  72. validate_args(fname, args, max_fname_arg_count, self.defaults)
  73. elif method == "kwargs":
  74. validate_kwargs(fname, kwargs, self.defaults)
  75. elif method == "both":
  76. validate_args_and_kwargs(
  77. fname, args, kwargs, max_fname_arg_count, self.defaults
  78. )
  79. else:
  80. raise ValueError(f"invalid validation method '{method}'")
  81. ARGMINMAX_DEFAULTS = {"out": None}
  82. validate_argmin = CompatValidator(
  83. ARGMINMAX_DEFAULTS, fname="argmin", method="both", max_fname_arg_count=1
  84. )
  85. validate_argmax = CompatValidator(
  86. ARGMINMAX_DEFAULTS, fname="argmax", method="both", max_fname_arg_count=1
  87. )
  88. def process_skipna(skipna: bool | ndarray | None, args) -> tuple[bool, Any]:
  89. if isinstance(skipna, ndarray) or skipna is None:
  90. args = (skipna,) + args
  91. skipna = True
  92. return skipna, args
  93. def validate_argmin_with_skipna(skipna: bool | ndarray | None, args, kwargs) -> bool:
  94. """
  95. If 'Series.argmin' is called via the 'numpy' library, the third parameter
  96. in its signature is 'out', which takes either an ndarray or 'None', so
  97. check if the 'skipna' parameter is either an instance of ndarray or is
  98. None, since 'skipna' itself should be a boolean
  99. """
  100. skipna, args = process_skipna(skipna, args)
  101. validate_argmin(args, kwargs)
  102. return skipna
  103. def validate_argmax_with_skipna(skipna: bool | ndarray | None, args, kwargs) -> bool:
  104. """
  105. If 'Series.argmax' is called via the 'numpy' library, the third parameter
  106. in its signature is 'out', which takes either an ndarray or 'None', so
  107. check if the 'skipna' parameter is either an instance of ndarray or is
  108. None, since 'skipna' itself should be a boolean
  109. """
  110. skipna, args = process_skipna(skipna, args)
  111. validate_argmax(args, kwargs)
  112. return skipna
  113. ARGSORT_DEFAULTS: dict[str, int | str | None] = {}
  114. ARGSORT_DEFAULTS["axis"] = -1
  115. ARGSORT_DEFAULTS["kind"] = "quicksort"
  116. ARGSORT_DEFAULTS["order"] = None
  117. ARGSORT_DEFAULTS["kind"] = None
  118. ARGSORT_DEFAULTS["stable"] = None
  119. validate_argsort = CompatValidator(
  120. ARGSORT_DEFAULTS, fname="argsort", max_fname_arg_count=0, method="both"
  121. )
  122. # two different signatures of argsort, this second validation for when the
  123. # `kind` param is supported
  124. ARGSORT_DEFAULTS_KIND: dict[str, int | None] = {}
  125. ARGSORT_DEFAULTS_KIND["axis"] = -1
  126. ARGSORT_DEFAULTS_KIND["order"] = None
  127. ARGSORT_DEFAULTS_KIND["stable"] = None
  128. validate_argsort_kind = CompatValidator(
  129. ARGSORT_DEFAULTS_KIND, fname="argsort", max_fname_arg_count=0, method="both"
  130. )
  131. def validate_argsort_with_ascending(ascending: bool | int | None, args, kwargs) -> bool:
  132. """
  133. If 'Categorical.argsort' is called via the 'numpy' library, the first
  134. parameter in its signature is 'axis', which takes either an integer or
  135. 'None', so check if the 'ascending' parameter has either integer type or is
  136. None, since 'ascending' itself should be a boolean
  137. """
  138. if is_integer(ascending) or ascending is None:
  139. args = (ascending,) + args
  140. ascending = True
  141. validate_argsort_kind(args, kwargs, max_fname_arg_count=3)
  142. ascending = cast(bool, ascending)
  143. return ascending
  144. CLIP_DEFAULTS: dict[str, Any] = {"out": None}
  145. validate_clip = CompatValidator(
  146. CLIP_DEFAULTS, fname="clip", method="both", max_fname_arg_count=3
  147. )
  148. @overload
  149. def validate_clip_with_axis(axis: ndarray, args, kwargs) -> None:
  150. ...
  151. @overload
  152. def validate_clip_with_axis(axis: AxisNoneT, args, kwargs) -> AxisNoneT:
  153. ...
  154. def validate_clip_with_axis(
  155. axis: ndarray | AxisNoneT, args, kwargs
  156. ) -> AxisNoneT | None:
  157. """
  158. If 'NDFrame.clip' is called via the numpy library, the third parameter in
  159. its signature is 'out', which can takes an ndarray, so check if the 'axis'
  160. parameter is an instance of ndarray, since 'axis' itself should either be
  161. an integer or None
  162. """
  163. if isinstance(axis, ndarray):
  164. args = (axis,) + args
  165. # error: Incompatible types in assignment (expression has type "None",
  166. # variable has type "Union[ndarray[Any, Any], str, int]")
  167. axis = None # type: ignore[assignment]
  168. validate_clip(args, kwargs)
  169. # error: Incompatible return value type (got "Union[ndarray[Any, Any],
  170. # str, int]", expected "Union[str, int, None]")
  171. return axis # type: ignore[return-value]
  172. CUM_FUNC_DEFAULTS: dict[str, Any] = {}
  173. CUM_FUNC_DEFAULTS["dtype"] = None
  174. CUM_FUNC_DEFAULTS["out"] = None
  175. validate_cum_func = CompatValidator(
  176. CUM_FUNC_DEFAULTS, method="both", max_fname_arg_count=1
  177. )
  178. validate_cumsum = CompatValidator(
  179. CUM_FUNC_DEFAULTS, fname="cumsum", method="both", max_fname_arg_count=1
  180. )
  181. def validate_cum_func_with_skipna(skipna: bool, args, kwargs, name) -> bool:
  182. """
  183. If this function is called via the 'numpy' library, the third parameter in
  184. its signature is 'dtype', which takes either a 'numpy' dtype or 'None', so
  185. check if the 'skipna' parameter is a boolean or not
  186. """
  187. if not is_bool(skipna):
  188. args = (skipna,) + args
  189. skipna = True
  190. elif isinstance(skipna, np.bool_):
  191. skipna = bool(skipna)
  192. validate_cum_func(args, kwargs, fname=name)
  193. return skipna
  194. ALLANY_DEFAULTS: dict[str, bool | None] = {}
  195. ALLANY_DEFAULTS["dtype"] = None
  196. ALLANY_DEFAULTS["out"] = None
  197. ALLANY_DEFAULTS["keepdims"] = False
  198. ALLANY_DEFAULTS["axis"] = None
  199. validate_all = CompatValidator(
  200. ALLANY_DEFAULTS, fname="all", method="both", max_fname_arg_count=1
  201. )
  202. validate_any = CompatValidator(
  203. ALLANY_DEFAULTS, fname="any", method="both", max_fname_arg_count=1
  204. )
  205. LOGICAL_FUNC_DEFAULTS = {"out": None, "keepdims": False}
  206. validate_logical_func = CompatValidator(LOGICAL_FUNC_DEFAULTS, method="kwargs")
  207. MINMAX_DEFAULTS = {"axis": None, "dtype": None, "out": None, "keepdims": False}
  208. validate_min = CompatValidator(
  209. MINMAX_DEFAULTS, fname="min", method="both", max_fname_arg_count=1
  210. )
  211. validate_max = CompatValidator(
  212. MINMAX_DEFAULTS, fname="max", method="both", max_fname_arg_count=1
  213. )
  214. RESHAPE_DEFAULTS: dict[str, str] = {"order": "C"}
  215. validate_reshape = CompatValidator(
  216. RESHAPE_DEFAULTS, fname="reshape", method="both", max_fname_arg_count=1
  217. )
  218. REPEAT_DEFAULTS: dict[str, Any] = {"axis": None}
  219. validate_repeat = CompatValidator(
  220. REPEAT_DEFAULTS, fname="repeat", method="both", max_fname_arg_count=1
  221. )
  222. ROUND_DEFAULTS: dict[str, Any] = {"out": None}
  223. validate_round = CompatValidator(
  224. ROUND_DEFAULTS, fname="round", method="both", max_fname_arg_count=1
  225. )
  226. SORT_DEFAULTS: dict[str, int | str | None] = {}
  227. SORT_DEFAULTS["axis"] = -1
  228. SORT_DEFAULTS["kind"] = "quicksort"
  229. SORT_DEFAULTS["order"] = None
  230. validate_sort = CompatValidator(SORT_DEFAULTS, fname="sort", method="kwargs")
  231. STAT_FUNC_DEFAULTS: dict[str, Any | None] = {}
  232. STAT_FUNC_DEFAULTS["dtype"] = None
  233. STAT_FUNC_DEFAULTS["out"] = None
  234. SUM_DEFAULTS = STAT_FUNC_DEFAULTS.copy()
  235. SUM_DEFAULTS["axis"] = None
  236. SUM_DEFAULTS["keepdims"] = False
  237. SUM_DEFAULTS["initial"] = None
  238. PROD_DEFAULTS = SUM_DEFAULTS.copy()
  239. MEAN_DEFAULTS = SUM_DEFAULTS.copy()
  240. MEDIAN_DEFAULTS = STAT_FUNC_DEFAULTS.copy()
  241. MEDIAN_DEFAULTS["overwrite_input"] = False
  242. MEDIAN_DEFAULTS["keepdims"] = False
  243. STAT_FUNC_DEFAULTS["keepdims"] = False
  244. validate_stat_func = CompatValidator(STAT_FUNC_DEFAULTS, method="kwargs")
  245. validate_sum = CompatValidator(
  246. SUM_DEFAULTS, fname="sum", method="both", max_fname_arg_count=1
  247. )
  248. validate_prod = CompatValidator(
  249. PROD_DEFAULTS, fname="prod", method="both", max_fname_arg_count=1
  250. )
  251. validate_mean = CompatValidator(
  252. MEAN_DEFAULTS, fname="mean", method="both", max_fname_arg_count=1
  253. )
  254. validate_median = CompatValidator(
  255. MEDIAN_DEFAULTS, fname="median", method="both", max_fname_arg_count=1
  256. )
  257. STAT_DDOF_FUNC_DEFAULTS: dict[str, bool | None] = {}
  258. STAT_DDOF_FUNC_DEFAULTS["dtype"] = None
  259. STAT_DDOF_FUNC_DEFAULTS["out"] = None
  260. STAT_DDOF_FUNC_DEFAULTS["keepdims"] = False
  261. validate_stat_ddof_func = CompatValidator(STAT_DDOF_FUNC_DEFAULTS, method="kwargs")
  262. TAKE_DEFAULTS: dict[str, str | None] = {}
  263. TAKE_DEFAULTS["out"] = None
  264. TAKE_DEFAULTS["mode"] = "raise"
  265. validate_take = CompatValidator(TAKE_DEFAULTS, fname="take", method="kwargs")
  266. def validate_take_with_convert(convert: ndarray | bool | None, args, kwargs) -> bool:
  267. """
  268. If this function is called via the 'numpy' library, the third parameter in
  269. its signature is 'axis', which takes either an ndarray or 'None', so check
  270. if the 'convert' parameter is either an instance of ndarray or is None
  271. """
  272. if isinstance(convert, ndarray) or convert is None:
  273. args = (convert,) + args
  274. convert = True
  275. validate_take(args, kwargs, max_fname_arg_count=3, method="both")
  276. return convert
  277. TRANSPOSE_DEFAULTS = {"axes": None}
  278. validate_transpose = CompatValidator(
  279. TRANSPOSE_DEFAULTS, fname="transpose", method="both", max_fname_arg_count=0
  280. )
  281. def validate_groupby_func(name: str, args, kwargs, allowed=None) -> None:
  282. """
  283. 'args' and 'kwargs' should be empty, except for allowed kwargs because all
  284. of their necessary parameters are explicitly listed in the function
  285. signature
  286. """
  287. if allowed is None:
  288. allowed = []
  289. kwargs = set(kwargs) - set(allowed)
  290. if len(args) + len(kwargs) > 0:
  291. raise UnsupportedFunctionCall(
  292. "numpy operations are not valid with groupby. "
  293. f"Use .groupby(...).{name}() instead"
  294. )
  295. RESAMPLER_NUMPY_OPS = ("min", "max", "sum", "prod", "mean", "std", "var")
  296. def validate_resampler_func(method: str, args, kwargs) -> None:
  297. """
  298. 'args' and 'kwargs' should be empty because all of their necessary
  299. parameters are explicitly listed in the function signature
  300. """
  301. if len(args) + len(kwargs) > 0:
  302. if method in RESAMPLER_NUMPY_OPS:
  303. raise UnsupportedFunctionCall(
  304. "numpy operations are not valid with resample. "
  305. f"Use .resample(...).{method}() instead"
  306. )
  307. raise TypeError("too many arguments passed in")
  308. def validate_minmax_axis(axis: AxisInt | None, ndim: int = 1) -> None:
  309. """
  310. Ensure that the axis argument passed to min, max, argmin, or argmax is zero
  311. or None, as otherwise it will be incorrectly ignored.
  312. Parameters
  313. ----------
  314. axis : int or None
  315. ndim : int, default 1
  316. Raises
  317. ------
  318. ValueError
  319. """
  320. if axis is None:
  321. return
  322. if axis >= ndim or (axis < 0 and ndim + axis < 0):
  323. raise ValueError(f"`axis` must be fewer than the number of dimensions ({ndim})")
  324. _validation_funcs = {
  325. "median": validate_median,
  326. "mean": validate_mean,
  327. "min": validate_min,
  328. "max": validate_max,
  329. "sum": validate_sum,
  330. "prod": validate_prod,
  331. }
  332. def validate_func(fname, args, kwargs) -> None:
  333. if fname not in _validation_funcs:
  334. return validate_stat_func(args, kwargs, fname=fname)
  335. validation_func = _validation_funcs[fname]
  336. return validation_func(args, kwargs)