apply.py 66 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757175817591760176117621763176417651766176717681769177017711772177317741775177617771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057
  1. from __future__ import annotations
  2. import abc
  3. from collections import defaultdict
  4. import functools
  5. from functools import partial
  6. import inspect
  7. from typing import (
  8. TYPE_CHECKING,
  9. Any,
  10. Callable,
  11. Literal,
  12. cast,
  13. )
  14. import warnings
  15. import numpy as np
  16. from pandas._config import option_context
  17. from pandas._libs import lib
  18. from pandas._libs.internals import BlockValuesRefs
  19. from pandas._typing import (
  20. AggFuncType,
  21. AggFuncTypeBase,
  22. AggFuncTypeDict,
  23. AggObjType,
  24. Axis,
  25. AxisInt,
  26. NDFrameT,
  27. npt,
  28. )
  29. from pandas.compat._optional import import_optional_dependency
  30. from pandas.errors import SpecificationError
  31. from pandas.util._decorators import cache_readonly
  32. from pandas.util._exceptions import find_stack_level
  33. from pandas.core.dtypes.cast import is_nested_object
  34. from pandas.core.dtypes.common import (
  35. is_dict_like,
  36. is_extension_array_dtype,
  37. is_list_like,
  38. is_numeric_dtype,
  39. is_sequence,
  40. )
  41. from pandas.core.dtypes.dtypes import (
  42. CategoricalDtype,
  43. ExtensionDtype,
  44. )
  45. from pandas.core.dtypes.generic import (
  46. ABCDataFrame,
  47. ABCNDFrame,
  48. ABCSeries,
  49. )
  50. from pandas.core._numba.executor import generate_apply_looper
  51. import pandas.core.common as com
  52. from pandas.core.construction import ensure_wrapped_if_datetimelike
  53. if TYPE_CHECKING:
  54. from collections.abc import (
  55. Generator,
  56. Hashable,
  57. Iterable,
  58. MutableMapping,
  59. Sequence,
  60. )
  61. from pandas import (
  62. DataFrame,
  63. Index,
  64. Series,
  65. )
  66. from pandas.core.groupby import GroupBy
  67. from pandas.core.resample import Resampler
  68. from pandas.core.window.rolling import BaseWindow
  69. ResType = dict[int, Any]
  70. def frame_apply(
  71. obj: DataFrame,
  72. func: AggFuncType,
  73. axis: Axis = 0,
  74. raw: bool = False,
  75. result_type: str | None = None,
  76. by_row: Literal[False, "compat"] = "compat",
  77. engine: str = "python",
  78. engine_kwargs: dict[str, bool] | None = None,
  79. args=None,
  80. kwargs=None,
  81. ) -> FrameApply:
  82. """construct and return a row or column based frame apply object"""
  83. axis = obj._get_axis_number(axis)
  84. klass: type[FrameApply]
  85. if axis == 0:
  86. klass = FrameRowApply
  87. elif axis == 1:
  88. klass = FrameColumnApply
  89. _, func, _, _ = reconstruct_func(func, **kwargs)
  90. assert func is not None
  91. return klass(
  92. obj,
  93. func,
  94. raw=raw,
  95. result_type=result_type,
  96. by_row=by_row,
  97. engine=engine,
  98. engine_kwargs=engine_kwargs,
  99. args=args,
  100. kwargs=kwargs,
  101. )
  102. class Apply(metaclass=abc.ABCMeta):
  103. axis: AxisInt
  104. def __init__(
  105. self,
  106. obj: AggObjType,
  107. func: AggFuncType,
  108. raw: bool,
  109. result_type: str | None,
  110. *,
  111. by_row: Literal[False, "compat", "_compat"] = "compat",
  112. engine: str = "python",
  113. engine_kwargs: dict[str, bool] | None = None,
  114. args,
  115. kwargs,
  116. ) -> None:
  117. self.obj = obj
  118. self.raw = raw
  119. assert by_row is False or by_row in ["compat", "_compat"]
  120. self.by_row = by_row
  121. self.args = args or ()
  122. self.kwargs = kwargs or {}
  123. self.engine = engine
  124. self.engine_kwargs = {} if engine_kwargs is None else engine_kwargs
  125. if result_type not in [None, "reduce", "broadcast", "expand"]:
  126. raise ValueError(
  127. "invalid value for result_type, must be one "
  128. "of {None, 'reduce', 'broadcast', 'expand'}"
  129. )
  130. self.result_type = result_type
  131. self.func = func
  132. @abc.abstractmethod
  133. def apply(self) -> DataFrame | Series:
  134. pass
  135. @abc.abstractmethod
  136. def agg_or_apply_list_like(
  137. self, op_name: Literal["agg", "apply"]
  138. ) -> DataFrame | Series:
  139. pass
  140. @abc.abstractmethod
  141. def agg_or_apply_dict_like(
  142. self, op_name: Literal["agg", "apply"]
  143. ) -> DataFrame | Series:
  144. pass
  145. def agg(self) -> DataFrame | Series | None:
  146. """
  147. Provide an implementation for the aggregators.
  148. Returns
  149. -------
  150. Result of aggregation, or None if agg cannot be performed by
  151. this method.
  152. """
  153. obj = self.obj
  154. func = self.func
  155. args = self.args
  156. kwargs = self.kwargs
  157. if isinstance(func, str):
  158. return self.apply_str()
  159. if is_dict_like(func):
  160. return self.agg_dict_like()
  161. elif is_list_like(func):
  162. # we require a list, but not a 'str'
  163. return self.agg_list_like()
  164. if callable(func):
  165. f = com.get_cython_func(func)
  166. if f and not args and not kwargs:
  167. warn_alias_replacement(obj, func, f)
  168. return getattr(obj, f)()
  169. # caller can react
  170. return None
  171. def transform(self) -> DataFrame | Series:
  172. """
  173. Transform a DataFrame or Series.
  174. Returns
  175. -------
  176. DataFrame or Series
  177. Result of applying ``func`` along the given axis of the
  178. Series or DataFrame.
  179. Raises
  180. ------
  181. ValueError
  182. If the transform function fails or does not transform.
  183. """
  184. obj = self.obj
  185. func = self.func
  186. axis = self.axis
  187. args = self.args
  188. kwargs = self.kwargs
  189. is_series = obj.ndim == 1
  190. if obj._get_axis_number(axis) == 1:
  191. assert not is_series
  192. return obj.T.transform(func, 0, *args, **kwargs).T
  193. if is_list_like(func) and not is_dict_like(func):
  194. func = cast(list[AggFuncTypeBase], func)
  195. # Convert func equivalent dict
  196. if is_series:
  197. func = {com.get_callable_name(v) or v: v for v in func}
  198. else:
  199. func = {col: func for col in obj}
  200. if is_dict_like(func):
  201. func = cast(AggFuncTypeDict, func)
  202. return self.transform_dict_like(func)
  203. # func is either str or callable
  204. func = cast(AggFuncTypeBase, func)
  205. try:
  206. result = self.transform_str_or_callable(func)
  207. except TypeError:
  208. raise
  209. except Exception as err:
  210. raise ValueError("Transform function failed") from err
  211. # Functions that transform may return empty Series/DataFrame
  212. # when the dtype is not appropriate
  213. if (
  214. isinstance(result, (ABCSeries, ABCDataFrame))
  215. and result.empty
  216. and not obj.empty
  217. ):
  218. raise ValueError("Transform function failed")
  219. # error: Argument 1 to "__get__" of "AxisProperty" has incompatible type
  220. # "Union[Series, DataFrame, GroupBy[Any], SeriesGroupBy,
  221. # DataFrameGroupBy, BaseWindow, Resampler]"; expected "Union[DataFrame,
  222. # Series]"
  223. if not isinstance(result, (ABCSeries, ABCDataFrame)) or not result.index.equals(
  224. obj.index # type: ignore[arg-type]
  225. ):
  226. raise ValueError("Function did not transform")
  227. return result
  228. def transform_dict_like(self, func) -> DataFrame:
  229. """
  230. Compute transform in the case of a dict-like func
  231. """
  232. from pandas.core.reshape.concat import concat
  233. obj = self.obj
  234. args = self.args
  235. kwargs = self.kwargs
  236. # transform is currently only for Series/DataFrame
  237. assert isinstance(obj, ABCNDFrame)
  238. if len(func) == 0:
  239. raise ValueError("No transform functions were provided")
  240. func = self.normalize_dictlike_arg("transform", obj, func)
  241. results: dict[Hashable, DataFrame | Series] = {}
  242. for name, how in func.items():
  243. colg = obj._gotitem(name, ndim=1)
  244. results[name] = colg.transform(how, 0, *args, **kwargs)
  245. return concat(results, axis=1)
  246. def transform_str_or_callable(self, func) -> DataFrame | Series:
  247. """
  248. Compute transform in the case of a string or callable func
  249. """
  250. obj = self.obj
  251. args = self.args
  252. kwargs = self.kwargs
  253. if isinstance(func, str):
  254. return self._apply_str(obj, func, *args, **kwargs)
  255. if not args and not kwargs:
  256. f = com.get_cython_func(func)
  257. if f:
  258. warn_alias_replacement(obj, func, f)
  259. return getattr(obj, f)()
  260. # Two possible ways to use a UDF - apply or call directly
  261. try:
  262. return obj.apply(func, args=args, **kwargs)
  263. except Exception:
  264. return func(obj, *args, **kwargs)
  265. def agg_list_like(self) -> DataFrame | Series:
  266. """
  267. Compute aggregation in the case of a list-like argument.
  268. Returns
  269. -------
  270. Result of aggregation.
  271. """
  272. return self.agg_or_apply_list_like(op_name="agg")
  273. def compute_list_like(
  274. self,
  275. op_name: Literal["agg", "apply"],
  276. selected_obj: Series | DataFrame,
  277. kwargs: dict[str, Any],
  278. ) -> tuple[list[Hashable] | Index, list[Any]]:
  279. """
  280. Compute agg/apply results for like-like input.
  281. Parameters
  282. ----------
  283. op_name : {"agg", "apply"}
  284. Operation being performed.
  285. selected_obj : Series or DataFrame
  286. Data to perform operation on.
  287. kwargs : dict
  288. Keyword arguments to pass to the functions.
  289. Returns
  290. -------
  291. keys : list[Hashable] or Index
  292. Index labels for result.
  293. results : list
  294. Data for result. When aggregating with a Series, this can contain any
  295. Python objects.
  296. """
  297. func = cast(list[AggFuncTypeBase], self.func)
  298. obj = self.obj
  299. results = []
  300. keys = []
  301. # degenerate case
  302. if selected_obj.ndim == 1:
  303. for a in func:
  304. colg = obj._gotitem(selected_obj.name, ndim=1, subset=selected_obj)
  305. args = (
  306. [self.axis, *self.args]
  307. if include_axis(op_name, colg)
  308. else self.args
  309. )
  310. new_res = getattr(colg, op_name)(a, *args, **kwargs)
  311. results.append(new_res)
  312. # make sure we find a good name
  313. name = com.get_callable_name(a) or a
  314. keys.append(name)
  315. else:
  316. indices = []
  317. for index, col in enumerate(selected_obj):
  318. colg = obj._gotitem(col, ndim=1, subset=selected_obj.iloc[:, index])
  319. args = (
  320. [self.axis, *self.args]
  321. if include_axis(op_name, colg)
  322. else self.args
  323. )
  324. new_res = getattr(colg, op_name)(func, *args, **kwargs)
  325. results.append(new_res)
  326. indices.append(index)
  327. # error: Incompatible types in assignment (expression has type "Any |
  328. # Index", variable has type "list[Any | Callable[..., Any] | str]")
  329. keys = selected_obj.columns.take(indices) # type: ignore[assignment]
  330. return keys, results
  331. def wrap_results_list_like(
  332. self, keys: Iterable[Hashable], results: list[Series | DataFrame]
  333. ):
  334. from pandas.core.reshape.concat import concat
  335. obj = self.obj
  336. try:
  337. return concat(results, keys=keys, axis=1, sort=False)
  338. except TypeError as err:
  339. # we are concatting non-NDFrame objects,
  340. # e.g. a list of scalars
  341. from pandas import Series
  342. result = Series(results, index=keys, name=obj.name)
  343. if is_nested_object(result):
  344. raise ValueError(
  345. "cannot combine transform and aggregation operations"
  346. ) from err
  347. return result
  348. def agg_dict_like(self) -> DataFrame | Series:
  349. """
  350. Compute aggregation in the case of a dict-like argument.
  351. Returns
  352. -------
  353. Result of aggregation.
  354. """
  355. return self.agg_or_apply_dict_like(op_name="agg")
  356. def compute_dict_like(
  357. self,
  358. op_name: Literal["agg", "apply"],
  359. selected_obj: Series | DataFrame,
  360. selection: Hashable | Sequence[Hashable],
  361. kwargs: dict[str, Any],
  362. ) -> tuple[list[Hashable], list[Any]]:
  363. """
  364. Compute agg/apply results for dict-like input.
  365. Parameters
  366. ----------
  367. op_name : {"agg", "apply"}
  368. Operation being performed.
  369. selected_obj : Series or DataFrame
  370. Data to perform operation on.
  371. selection : hashable or sequence of hashables
  372. Used by GroupBy, Window, and Resample if selection is applied to the object.
  373. kwargs : dict
  374. Keyword arguments to pass to the functions.
  375. Returns
  376. -------
  377. keys : list[hashable]
  378. Index labels for result.
  379. results : list
  380. Data for result. When aggregating with a Series, this can contain any
  381. Python object.
  382. """
  383. from pandas.core.groupby.generic import (
  384. DataFrameGroupBy,
  385. SeriesGroupBy,
  386. )
  387. obj = self.obj
  388. is_groupby = isinstance(obj, (DataFrameGroupBy, SeriesGroupBy))
  389. func = cast(AggFuncTypeDict, self.func)
  390. func = self.normalize_dictlike_arg(op_name, selected_obj, func)
  391. is_non_unique_col = (
  392. selected_obj.ndim == 2
  393. and selected_obj.columns.nunique() < len(selected_obj.columns)
  394. )
  395. if selected_obj.ndim == 1:
  396. # key only used for output
  397. colg = obj._gotitem(selection, ndim=1)
  398. results = [getattr(colg, op_name)(how, **kwargs) for _, how in func.items()]
  399. keys = list(func.keys())
  400. elif not is_groupby and is_non_unique_col:
  401. # key used for column selection and output
  402. # GH#51099
  403. results = []
  404. keys = []
  405. for key, how in func.items():
  406. indices = selected_obj.columns.get_indexer_for([key])
  407. labels = selected_obj.columns.take(indices)
  408. label_to_indices = defaultdict(list)
  409. for index, label in zip(indices, labels):
  410. label_to_indices[label].append(index)
  411. key_data = [
  412. getattr(selected_obj._ixs(indice, axis=1), op_name)(how, **kwargs)
  413. for label, indices in label_to_indices.items()
  414. for indice in indices
  415. ]
  416. keys += [key] * len(key_data)
  417. results += key_data
  418. else:
  419. # key used for column selection and output
  420. results = [
  421. getattr(obj._gotitem(key, ndim=1), op_name)(how, **kwargs)
  422. for key, how in func.items()
  423. ]
  424. keys = list(func.keys())
  425. return keys, results
  426. def wrap_results_dict_like(
  427. self,
  428. selected_obj: Series | DataFrame,
  429. result_index: list[Hashable],
  430. result_data: list,
  431. ):
  432. from pandas import Index
  433. from pandas.core.reshape.concat import concat
  434. obj = self.obj
  435. # Avoid making two isinstance calls in all and any below
  436. is_ndframe = [isinstance(r, ABCNDFrame) for r in result_data]
  437. if all(is_ndframe):
  438. results = dict(zip(result_index, result_data))
  439. keys_to_use: Iterable[Hashable]
  440. keys_to_use = [k for k in result_index if not results[k].empty]
  441. # Have to check, if at least one DataFrame is not empty.
  442. keys_to_use = keys_to_use if keys_to_use != [] else result_index
  443. if selected_obj.ndim == 2:
  444. # keys are columns, so we can preserve names
  445. ktu = Index(keys_to_use)
  446. ktu._set_names(selected_obj.columns.names)
  447. keys_to_use = ktu
  448. axis: AxisInt = 0 if isinstance(obj, ABCSeries) else 1
  449. result = concat(
  450. {k: results[k] for k in keys_to_use},
  451. axis=axis,
  452. keys=keys_to_use,
  453. )
  454. elif any(is_ndframe):
  455. # There is a mix of NDFrames and scalars
  456. raise ValueError(
  457. "cannot perform both aggregation "
  458. "and transformation operations "
  459. "simultaneously"
  460. )
  461. else:
  462. from pandas import Series
  463. # we have a list of scalars
  464. # GH 36212 use name only if obj is a series
  465. if obj.ndim == 1:
  466. obj = cast("Series", obj)
  467. name = obj.name
  468. else:
  469. name = None
  470. result = Series(result_data, index=result_index, name=name)
  471. return result
  472. def apply_str(self) -> DataFrame | Series:
  473. """
  474. Compute apply in case of a string.
  475. Returns
  476. -------
  477. result: Series or DataFrame
  478. """
  479. # Caller is responsible for checking isinstance(self.f, str)
  480. func = cast(str, self.func)
  481. obj = self.obj
  482. from pandas.core.groupby.generic import (
  483. DataFrameGroupBy,
  484. SeriesGroupBy,
  485. )
  486. # Support for `frame.transform('method')`
  487. # Some methods (shift, etc.) require the axis argument, others
  488. # don't, so inspect and insert if necessary.
  489. method = getattr(obj, func, None)
  490. if callable(method):
  491. sig = inspect.getfullargspec(method)
  492. arg_names = (*sig.args, *sig.kwonlyargs)
  493. if self.axis != 0 and (
  494. "axis" not in arg_names or func in ("corrwith", "skew")
  495. ):
  496. raise ValueError(f"Operation {func} does not support axis=1")
  497. if "axis" in arg_names:
  498. if isinstance(obj, (SeriesGroupBy, DataFrameGroupBy)):
  499. # Try to avoid FutureWarning for deprecated axis keyword;
  500. # If self.axis matches the axis we would get by not passing
  501. # axis, we safely exclude the keyword.
  502. default_axis = 0
  503. if func in ["idxmax", "idxmin"]:
  504. # DataFrameGroupBy.idxmax, idxmin axis defaults to self.axis,
  505. # whereas other axis keywords default to 0
  506. default_axis = self.obj.axis
  507. if default_axis != self.axis:
  508. self.kwargs["axis"] = self.axis
  509. else:
  510. self.kwargs["axis"] = self.axis
  511. return self._apply_str(obj, func, *self.args, **self.kwargs)
  512. def apply_list_or_dict_like(self) -> DataFrame | Series:
  513. """
  514. Compute apply in case of a list-like or dict-like.
  515. Returns
  516. -------
  517. result: Series, DataFrame, or None
  518. Result when self.func is a list-like or dict-like, None otherwise.
  519. """
  520. if self.engine == "numba":
  521. raise NotImplementedError(
  522. "The 'numba' engine doesn't support list-like/"
  523. "dict likes of callables yet."
  524. )
  525. if self.axis == 1 and isinstance(self.obj, ABCDataFrame):
  526. return self.obj.T.apply(self.func, 0, args=self.args, **self.kwargs).T
  527. func = self.func
  528. kwargs = self.kwargs
  529. if is_dict_like(func):
  530. result = self.agg_or_apply_dict_like(op_name="apply")
  531. else:
  532. result = self.agg_or_apply_list_like(op_name="apply")
  533. result = reconstruct_and_relabel_result(result, func, **kwargs)
  534. return result
  535. def normalize_dictlike_arg(
  536. self, how: str, obj: DataFrame | Series, func: AggFuncTypeDict
  537. ) -> AggFuncTypeDict:
  538. """
  539. Handler for dict-like argument.
  540. Ensures that necessary columns exist if obj is a DataFrame, and
  541. that a nested renamer is not passed. Also normalizes to all lists
  542. when values consists of a mix of list and non-lists.
  543. """
  544. assert how in ("apply", "agg", "transform")
  545. # Can't use func.values(); wouldn't work for a Series
  546. if (
  547. how == "agg"
  548. and isinstance(obj, ABCSeries)
  549. and any(is_list_like(v) for _, v in func.items())
  550. ) or (any(is_dict_like(v) for _, v in func.items())):
  551. # GH 15931 - deprecation of renaming keys
  552. raise SpecificationError("nested renamer is not supported")
  553. if obj.ndim != 1:
  554. # Check for missing columns on a frame
  555. from pandas import Index
  556. cols = Index(list(func.keys())).difference(obj.columns, sort=True)
  557. if len(cols) > 0:
  558. raise KeyError(f"Column(s) {list(cols)} do not exist")
  559. aggregator_types = (list, tuple, dict)
  560. # if we have a dict of any non-scalars
  561. # eg. {'A' : ['mean']}, normalize all to
  562. # be list-likes
  563. # Cannot use func.values() because arg may be a Series
  564. if any(isinstance(x, aggregator_types) for _, x in func.items()):
  565. new_func: AggFuncTypeDict = {}
  566. for k, v in func.items():
  567. if not isinstance(v, aggregator_types):
  568. new_func[k] = [v]
  569. else:
  570. new_func[k] = v
  571. func = new_func
  572. return func
  573. def _apply_str(self, obj, func: str, *args, **kwargs):
  574. """
  575. if arg is a string, then try to operate on it:
  576. - try to find a function (or attribute) on obj
  577. - try to find a numpy function
  578. - raise
  579. """
  580. assert isinstance(func, str)
  581. if hasattr(obj, func):
  582. f = getattr(obj, func)
  583. if callable(f):
  584. return f(*args, **kwargs)
  585. # people may aggregate on a non-callable attribute
  586. # but don't let them think they can pass args to it
  587. assert len(args) == 0
  588. assert len([kwarg for kwarg in kwargs if kwarg not in ["axis"]]) == 0
  589. return f
  590. elif hasattr(np, func) and hasattr(obj, "__array__"):
  591. # in particular exclude Window
  592. f = getattr(np, func)
  593. return f(obj, *args, **kwargs)
  594. else:
  595. msg = f"'{func}' is not a valid function for '{type(obj).__name__}' object"
  596. raise AttributeError(msg)
  597. class NDFrameApply(Apply):
  598. """
  599. Methods shared by FrameApply and SeriesApply but
  600. not GroupByApply or ResamplerWindowApply
  601. """
  602. obj: DataFrame | Series
  603. @property
  604. def index(self) -> Index:
  605. return self.obj.index
  606. @property
  607. def agg_axis(self) -> Index:
  608. return self.obj._get_agg_axis(self.axis)
  609. def agg_or_apply_list_like(
  610. self, op_name: Literal["agg", "apply"]
  611. ) -> DataFrame | Series:
  612. obj = self.obj
  613. kwargs = self.kwargs
  614. if op_name == "apply":
  615. if isinstance(self, FrameApply):
  616. by_row = self.by_row
  617. elif isinstance(self, SeriesApply):
  618. by_row = "_compat" if self.by_row else False
  619. else:
  620. by_row = False
  621. kwargs = {**kwargs, "by_row": by_row}
  622. if getattr(obj, "axis", 0) == 1:
  623. raise NotImplementedError("axis other than 0 is not supported")
  624. keys, results = self.compute_list_like(op_name, obj, kwargs)
  625. result = self.wrap_results_list_like(keys, results)
  626. return result
  627. def agg_or_apply_dict_like(
  628. self, op_name: Literal["agg", "apply"]
  629. ) -> DataFrame | Series:
  630. assert op_name in ["agg", "apply"]
  631. obj = self.obj
  632. kwargs = {}
  633. if op_name == "apply":
  634. by_row = "_compat" if self.by_row else False
  635. kwargs.update({"by_row": by_row})
  636. if getattr(obj, "axis", 0) == 1:
  637. raise NotImplementedError("axis other than 0 is not supported")
  638. selection = None
  639. result_index, result_data = self.compute_dict_like(
  640. op_name, obj, selection, kwargs
  641. )
  642. result = self.wrap_results_dict_like(obj, result_index, result_data)
  643. return result
  644. class FrameApply(NDFrameApply):
  645. obj: DataFrame
  646. def __init__(
  647. self,
  648. obj: AggObjType,
  649. func: AggFuncType,
  650. raw: bool,
  651. result_type: str | None,
  652. *,
  653. by_row: Literal[False, "compat"] = False,
  654. engine: str = "python",
  655. engine_kwargs: dict[str, bool] | None = None,
  656. args,
  657. kwargs,
  658. ) -> None:
  659. if by_row is not False and by_row != "compat":
  660. raise ValueError(f"by_row={by_row} not allowed")
  661. super().__init__(
  662. obj,
  663. func,
  664. raw,
  665. result_type,
  666. by_row=by_row,
  667. engine=engine,
  668. engine_kwargs=engine_kwargs,
  669. args=args,
  670. kwargs=kwargs,
  671. )
  672. # ---------------------------------------------------------------
  673. # Abstract Methods
  674. @property
  675. @abc.abstractmethod
  676. def result_index(self) -> Index:
  677. pass
  678. @property
  679. @abc.abstractmethod
  680. def result_columns(self) -> Index:
  681. pass
  682. @property
  683. @abc.abstractmethod
  684. def series_generator(self) -> Generator[Series, None, None]:
  685. pass
  686. @staticmethod
  687. @functools.cache
  688. @abc.abstractmethod
  689. def generate_numba_apply_func(
  690. func, nogil=True, nopython=True, parallel=False
  691. ) -> Callable[[npt.NDArray, Index, Index], dict[int, Any]]:
  692. pass
  693. @abc.abstractmethod
  694. def apply_with_numba(self):
  695. pass
  696. def validate_values_for_numba(self):
  697. # Validate column dtyps all OK
  698. for colname, dtype in self.obj.dtypes.items():
  699. if not is_numeric_dtype(dtype):
  700. raise ValueError(
  701. f"Column {colname} must have a numeric dtype. "
  702. f"Found '{dtype}' instead"
  703. )
  704. if is_extension_array_dtype(dtype):
  705. raise ValueError(
  706. f"Column {colname} is backed by an extension array, "
  707. f"which is not supported by the numba engine."
  708. )
  709. @abc.abstractmethod
  710. def wrap_results_for_axis(
  711. self, results: ResType, res_index: Index
  712. ) -> DataFrame | Series:
  713. pass
  714. # ---------------------------------------------------------------
  715. @property
  716. def res_columns(self) -> Index:
  717. return self.result_columns
  718. @property
  719. def columns(self) -> Index:
  720. return self.obj.columns
  721. @cache_readonly
  722. def values(self):
  723. return self.obj.values
  724. def apply(self) -> DataFrame | Series:
  725. """compute the results"""
  726. # dispatch to handle list-like or dict-like
  727. if is_list_like(self.func):
  728. if self.engine == "numba":
  729. raise NotImplementedError(
  730. "the 'numba' engine doesn't support lists of callables yet"
  731. )
  732. return self.apply_list_or_dict_like()
  733. # all empty
  734. if len(self.columns) == 0 and len(self.index) == 0:
  735. return self.apply_empty_result()
  736. # string dispatch
  737. if isinstance(self.func, str):
  738. if self.engine == "numba":
  739. raise NotImplementedError(
  740. "the 'numba' engine doesn't support using "
  741. "a string as the callable function"
  742. )
  743. return self.apply_str()
  744. # ufunc
  745. elif isinstance(self.func, np.ufunc):
  746. if self.engine == "numba":
  747. raise NotImplementedError(
  748. "the 'numba' engine doesn't support "
  749. "using a numpy ufunc as the callable function"
  750. )
  751. with np.errstate(all="ignore"):
  752. results = self.obj._mgr.apply("apply", func=self.func)
  753. # _constructor will retain self.index and self.columns
  754. return self.obj._constructor_from_mgr(results, axes=results.axes)
  755. # broadcasting
  756. if self.result_type == "broadcast":
  757. if self.engine == "numba":
  758. raise NotImplementedError(
  759. "the 'numba' engine doesn't support result_type='broadcast'"
  760. )
  761. return self.apply_broadcast(self.obj)
  762. # one axis empty
  763. elif not all(self.obj.shape):
  764. return self.apply_empty_result()
  765. # raw
  766. elif self.raw:
  767. return self.apply_raw(engine=self.engine, engine_kwargs=self.engine_kwargs)
  768. return self.apply_standard()
  769. def agg(self):
  770. obj = self.obj
  771. axis = self.axis
  772. # TODO: Avoid having to change state
  773. self.obj = self.obj if self.axis == 0 else self.obj.T
  774. self.axis = 0
  775. result = None
  776. try:
  777. result = super().agg()
  778. finally:
  779. self.obj = obj
  780. self.axis = axis
  781. if axis == 1:
  782. result = result.T if result is not None else result
  783. if result is None:
  784. result = self.obj.apply(self.func, axis, args=self.args, **self.kwargs)
  785. return result
  786. def apply_empty_result(self):
  787. """
  788. we have an empty result; at least 1 axis is 0
  789. we will try to apply the function to an empty
  790. series in order to see if this is a reduction function
  791. """
  792. assert callable(self.func)
  793. # we are not asked to reduce or infer reduction
  794. # so just return a copy of the existing object
  795. if self.result_type not in ["reduce", None]:
  796. return self.obj.copy()
  797. # we may need to infer
  798. should_reduce = self.result_type == "reduce"
  799. from pandas import Series
  800. if not should_reduce:
  801. try:
  802. if self.axis == 0:
  803. r = self.func(
  804. Series([], dtype=np.float64), *self.args, **self.kwargs
  805. )
  806. else:
  807. r = self.func(
  808. Series(index=self.columns, dtype=np.float64),
  809. *self.args,
  810. **self.kwargs,
  811. )
  812. except Exception:
  813. pass
  814. else:
  815. should_reduce = not isinstance(r, Series)
  816. if should_reduce:
  817. if len(self.agg_axis):
  818. r = self.func(Series([], dtype=np.float64), *self.args, **self.kwargs)
  819. else:
  820. r = np.nan
  821. return self.obj._constructor_sliced(r, index=self.agg_axis)
  822. else:
  823. return self.obj.copy()
  824. def apply_raw(self, engine="python", engine_kwargs=None):
  825. """apply to the values as a numpy array"""
  826. def wrap_function(func):
  827. """
  828. Wrap user supplied function to work around numpy issue.
  829. see https://github.com/numpy/numpy/issues/8352
  830. """
  831. def wrapper(*args, **kwargs):
  832. result = func(*args, **kwargs)
  833. if isinstance(result, str):
  834. result = np.array(result, dtype=object)
  835. return result
  836. return wrapper
  837. if engine == "numba":
  838. engine_kwargs = {} if engine_kwargs is None else engine_kwargs
  839. # error: Argument 1 to "__call__" of "_lru_cache_wrapper" has
  840. # incompatible type "Callable[..., Any] | str | list[Callable
  841. # [..., Any] | str] | dict[Hashable,Callable[..., Any] | str |
  842. # list[Callable[..., Any] | str]]"; expected "Hashable"
  843. nb_looper = generate_apply_looper(
  844. self.func, **engine_kwargs # type: ignore[arg-type]
  845. )
  846. result = nb_looper(self.values, self.axis)
  847. # If we made the result 2-D, squeeze it back to 1-D
  848. result = np.squeeze(result)
  849. else:
  850. result = np.apply_along_axis(
  851. wrap_function(self.func),
  852. self.axis,
  853. self.values,
  854. *self.args,
  855. **self.kwargs,
  856. )
  857. # TODO: mixed type case
  858. if result.ndim == 2:
  859. return self.obj._constructor(result, index=self.index, columns=self.columns)
  860. else:
  861. return self.obj._constructor_sliced(result, index=self.agg_axis)
  862. def apply_broadcast(self, target: DataFrame) -> DataFrame:
  863. assert callable(self.func)
  864. result_values = np.empty_like(target.values)
  865. # axis which we want to compare compliance
  866. result_compare = target.shape[0]
  867. for i, col in enumerate(target.columns):
  868. res = self.func(target[col], *self.args, **self.kwargs)
  869. ares = np.asarray(res).ndim
  870. # must be a scalar or 1d
  871. if ares > 1:
  872. raise ValueError("too many dims to broadcast")
  873. if ares == 1:
  874. # must match return dim
  875. if result_compare != len(res):
  876. raise ValueError("cannot broadcast result")
  877. result_values[:, i] = res
  878. # we *always* preserve the original index / columns
  879. result = self.obj._constructor(
  880. result_values, index=target.index, columns=target.columns
  881. )
  882. return result
  883. def apply_standard(self):
  884. if self.engine == "python":
  885. results, res_index = self.apply_series_generator()
  886. else:
  887. results, res_index = self.apply_series_numba()
  888. # wrap results
  889. return self.wrap_results(results, res_index)
  890. def apply_series_generator(self) -> tuple[ResType, Index]:
  891. assert callable(self.func)
  892. series_gen = self.series_generator
  893. res_index = self.result_index
  894. results = {}
  895. with option_context("mode.chained_assignment", None):
  896. for i, v in enumerate(series_gen):
  897. # ignore SettingWithCopy here in case the user mutates
  898. results[i] = self.func(v, *self.args, **self.kwargs)
  899. if isinstance(results[i], ABCSeries):
  900. # If we have a view on v, we need to make a copy because
  901. # series_generator will swap out the underlying data
  902. results[i] = results[i].copy(deep=False)
  903. return results, res_index
  904. def apply_series_numba(self):
  905. if self.engine_kwargs.get("parallel", False):
  906. raise NotImplementedError(
  907. "Parallel apply is not supported when raw=False and engine='numba'"
  908. )
  909. if not self.obj.index.is_unique or not self.columns.is_unique:
  910. raise NotImplementedError(
  911. "The index/columns must be unique when raw=False and engine='numba'"
  912. )
  913. self.validate_values_for_numba()
  914. results = self.apply_with_numba()
  915. return results, self.result_index
  916. def wrap_results(self, results: ResType, res_index: Index) -> DataFrame | Series:
  917. from pandas import Series
  918. # see if we can infer the results
  919. if len(results) > 0 and 0 in results and is_sequence(results[0]):
  920. return self.wrap_results_for_axis(results, res_index)
  921. # dict of scalars
  922. # the default dtype of an empty Series is `object`, but this
  923. # code can be hit by df.mean() where the result should have dtype
  924. # float64 even if it's an empty Series.
  925. constructor_sliced = self.obj._constructor_sliced
  926. if len(results) == 0 and constructor_sliced is Series:
  927. result = constructor_sliced(results, dtype=np.float64)
  928. else:
  929. result = constructor_sliced(results)
  930. result.index = res_index
  931. return result
  932. def apply_str(self) -> DataFrame | Series:
  933. # Caller is responsible for checking isinstance(self.func, str)
  934. # TODO: GH#39993 - Avoid special-casing by replacing with lambda
  935. if self.func == "size":
  936. # Special-cased because DataFrame.size returns a single scalar
  937. obj = self.obj
  938. value = obj.shape[self.axis]
  939. return obj._constructor_sliced(value, index=self.agg_axis)
  940. return super().apply_str()
  941. class FrameRowApply(FrameApply):
  942. axis: AxisInt = 0
  943. @property
  944. def series_generator(self) -> Generator[Series, None, None]:
  945. return (self.obj._ixs(i, axis=1) for i in range(len(self.columns)))
  946. @staticmethod
  947. @functools.cache
  948. def generate_numba_apply_func(
  949. func, nogil=True, nopython=True, parallel=False
  950. ) -> Callable[[npt.NDArray, Index, Index], dict[int, Any]]:
  951. numba = import_optional_dependency("numba")
  952. from pandas import Series
  953. # Import helper from extensions to cast string object -> np strings
  954. # Note: This also has the side effect of loading our numba extensions
  955. from pandas.core._numba.extensions import maybe_cast_str
  956. jitted_udf = numba.extending.register_jitable(func)
  957. # Currently the parallel argument doesn't get passed through here
  958. # (it's disabled) since the dicts in numba aren't thread-safe.
  959. @numba.jit(nogil=nogil, nopython=nopython, parallel=parallel)
  960. def numba_func(values, col_names, df_index):
  961. results = {}
  962. for j in range(values.shape[1]):
  963. # Create the series
  964. ser = Series(
  965. values[:, j], index=df_index, name=maybe_cast_str(col_names[j])
  966. )
  967. results[j] = jitted_udf(ser)
  968. return results
  969. return numba_func
  970. def apply_with_numba(self) -> dict[int, Any]:
  971. nb_func = self.generate_numba_apply_func(
  972. cast(Callable, self.func), **self.engine_kwargs
  973. )
  974. from pandas.core._numba.extensions import set_numba_data
  975. index = self.obj.index
  976. columns = self.obj.columns
  977. # Convert from numba dict to regular dict
  978. # Our isinstance checks in the df constructor don't pass for numbas typed dict
  979. with set_numba_data(index) as index, set_numba_data(columns) as columns:
  980. res = dict(nb_func(self.values, columns, index))
  981. return res
  982. @property
  983. def result_index(self) -> Index:
  984. return self.columns
  985. @property
  986. def result_columns(self) -> Index:
  987. return self.index
  988. def wrap_results_for_axis(
  989. self, results: ResType, res_index: Index
  990. ) -> DataFrame | Series:
  991. """return the results for the rows"""
  992. if self.result_type == "reduce":
  993. # e.g. test_apply_dict GH#8735
  994. res = self.obj._constructor_sliced(results)
  995. res.index = res_index
  996. return res
  997. elif self.result_type is None and all(
  998. isinstance(x, dict) for x in results.values()
  999. ):
  1000. # Our operation was a to_dict op e.g.
  1001. # test_apply_dict GH#8735, test_apply_reduce_to_dict GH#25196 #37544
  1002. res = self.obj._constructor_sliced(results)
  1003. res.index = res_index
  1004. return res
  1005. try:
  1006. result = self.obj._constructor(data=results)
  1007. except ValueError as err:
  1008. if "All arrays must be of the same length" in str(err):
  1009. # e.g. result = [[2, 3], [1.5], ['foo', 'bar']]
  1010. # see test_agg_listlike_result GH#29587
  1011. res = self.obj._constructor_sliced(results)
  1012. res.index = res_index
  1013. return res
  1014. else:
  1015. raise
  1016. if not isinstance(results[0], ABCSeries):
  1017. if len(result.index) == len(self.res_columns):
  1018. result.index = self.res_columns
  1019. if len(result.columns) == len(res_index):
  1020. result.columns = res_index
  1021. return result
  1022. class FrameColumnApply(FrameApply):
  1023. axis: AxisInt = 1
  1024. def apply_broadcast(self, target: DataFrame) -> DataFrame:
  1025. result = super().apply_broadcast(target.T)
  1026. return result.T
  1027. @property
  1028. def series_generator(self) -> Generator[Series, None, None]:
  1029. values = self.values
  1030. values = ensure_wrapped_if_datetimelike(values)
  1031. assert len(values) > 0
  1032. # We create one Series object, and will swap out the data inside
  1033. # of it. Kids: don't do this at home.
  1034. ser = self.obj._ixs(0, axis=0)
  1035. mgr = ser._mgr
  1036. is_view = mgr.blocks[0].refs.has_reference() # type: ignore[union-attr]
  1037. if isinstance(ser.dtype, ExtensionDtype):
  1038. # values will be incorrect for this block
  1039. # TODO(EA2D): special case would be unnecessary with 2D EAs
  1040. obj = self.obj
  1041. for i in range(len(obj)):
  1042. yield obj._ixs(i, axis=0)
  1043. else:
  1044. for arr, name in zip(values, self.index):
  1045. # GH#35462 re-pin mgr in case setitem changed it
  1046. ser._mgr = mgr
  1047. mgr.set_values(arr)
  1048. object.__setattr__(ser, "_name", name)
  1049. if not is_view:
  1050. # In apply_series_generator we store the a shallow copy of the
  1051. # result, which potentially increases the ref count of this reused
  1052. # `ser` object (depending on the result of the applied function)
  1053. # -> if that happened and `ser` is already a copy, then we reset
  1054. # the refs here to avoid triggering a unnecessary CoW inside the
  1055. # applied function (https://github.com/pandas-dev/pandas/pull/56212)
  1056. mgr.blocks[0].refs = BlockValuesRefs(mgr.blocks[0]) # type: ignore[union-attr]
  1057. yield ser
  1058. @staticmethod
  1059. @functools.cache
  1060. def generate_numba_apply_func(
  1061. func, nogil=True, nopython=True, parallel=False
  1062. ) -> Callable[[npt.NDArray, Index, Index], dict[int, Any]]:
  1063. numba = import_optional_dependency("numba")
  1064. from pandas import Series
  1065. from pandas.core._numba.extensions import maybe_cast_str
  1066. jitted_udf = numba.extending.register_jitable(func)
  1067. @numba.jit(nogil=nogil, nopython=nopython, parallel=parallel)
  1068. def numba_func(values, col_names_index, index):
  1069. results = {}
  1070. # Currently the parallel argument doesn't get passed through here
  1071. # (it's disabled) since the dicts in numba aren't thread-safe.
  1072. for i in range(values.shape[0]):
  1073. # Create the series
  1074. # TODO: values corrupted without the copy
  1075. ser = Series(
  1076. values[i].copy(),
  1077. index=col_names_index,
  1078. name=maybe_cast_str(index[i]),
  1079. )
  1080. results[i] = jitted_udf(ser)
  1081. return results
  1082. return numba_func
  1083. def apply_with_numba(self) -> dict[int, Any]:
  1084. nb_func = self.generate_numba_apply_func(
  1085. cast(Callable, self.func), **self.engine_kwargs
  1086. )
  1087. from pandas.core._numba.extensions import set_numba_data
  1088. # Convert from numba dict to regular dict
  1089. # Our isinstance checks in the df constructor don't pass for numbas typed dict
  1090. with set_numba_data(self.obj.index) as index, set_numba_data(
  1091. self.columns
  1092. ) as columns:
  1093. res = dict(nb_func(self.values, columns, index))
  1094. return res
  1095. @property
  1096. def result_index(self) -> Index:
  1097. return self.index
  1098. @property
  1099. def result_columns(self) -> Index:
  1100. return self.columns
  1101. def wrap_results_for_axis(
  1102. self, results: ResType, res_index: Index
  1103. ) -> DataFrame | Series:
  1104. """return the results for the columns"""
  1105. result: DataFrame | Series
  1106. # we have requested to expand
  1107. if self.result_type == "expand":
  1108. result = self.infer_to_same_shape(results, res_index)
  1109. # we have a non-series and don't want inference
  1110. elif not isinstance(results[0], ABCSeries):
  1111. result = self.obj._constructor_sliced(results)
  1112. result.index = res_index
  1113. # we may want to infer results
  1114. else:
  1115. result = self.infer_to_same_shape(results, res_index)
  1116. return result
  1117. def infer_to_same_shape(self, results: ResType, res_index: Index) -> DataFrame:
  1118. """infer the results to the same shape as the input object"""
  1119. result = self.obj._constructor(data=results)
  1120. result = result.T
  1121. # set the index
  1122. result.index = res_index
  1123. # infer dtypes
  1124. result = result.infer_objects(copy=False)
  1125. return result
  1126. class SeriesApply(NDFrameApply):
  1127. obj: Series
  1128. axis: AxisInt = 0
  1129. by_row: Literal[False, "compat", "_compat"] # only relevant for apply()
  1130. def __init__(
  1131. self,
  1132. obj: Series,
  1133. func: AggFuncType,
  1134. *,
  1135. convert_dtype: bool | lib.NoDefault = lib.no_default,
  1136. by_row: Literal[False, "compat", "_compat"] = "compat",
  1137. args,
  1138. kwargs,
  1139. ) -> None:
  1140. if convert_dtype is lib.no_default:
  1141. convert_dtype = True
  1142. else:
  1143. warnings.warn(
  1144. "the convert_dtype parameter is deprecated and will be removed in a "
  1145. "future version. Do ``ser.astype(object).apply()`` "
  1146. "instead if you want ``convert_dtype=False``.",
  1147. FutureWarning,
  1148. stacklevel=find_stack_level(),
  1149. )
  1150. self.convert_dtype = convert_dtype
  1151. super().__init__(
  1152. obj,
  1153. func,
  1154. raw=False,
  1155. result_type=None,
  1156. by_row=by_row,
  1157. args=args,
  1158. kwargs=kwargs,
  1159. )
  1160. def apply(self) -> DataFrame | Series:
  1161. obj = self.obj
  1162. if len(obj) == 0:
  1163. return self.apply_empty_result()
  1164. # dispatch to handle list-like or dict-like
  1165. if is_list_like(self.func):
  1166. return self.apply_list_or_dict_like()
  1167. if isinstance(self.func, str):
  1168. # if we are a string, try to dispatch
  1169. return self.apply_str()
  1170. if self.by_row == "_compat":
  1171. return self.apply_compat()
  1172. # self.func is Callable
  1173. return self.apply_standard()
  1174. def agg(self):
  1175. result = super().agg()
  1176. if result is None:
  1177. obj = self.obj
  1178. func = self.func
  1179. # string, list-like, and dict-like are entirely handled in super
  1180. assert callable(func)
  1181. # GH53325: The setup below is just to keep current behavior while emitting a
  1182. # deprecation message. In the future this will all be replaced with a simple
  1183. # `result = f(self.obj, *self.args, **self.kwargs)`.
  1184. try:
  1185. result = obj.apply(func, args=self.args, **self.kwargs)
  1186. except (ValueError, AttributeError, TypeError):
  1187. result = func(obj, *self.args, **self.kwargs)
  1188. else:
  1189. msg = (
  1190. f"using {func} in {type(obj).__name__}.agg cannot aggregate and "
  1191. f"has been deprecated. Use {type(obj).__name__}.transform to "
  1192. f"keep behavior unchanged."
  1193. )
  1194. warnings.warn(msg, FutureWarning, stacklevel=find_stack_level())
  1195. return result
  1196. def apply_empty_result(self) -> Series:
  1197. obj = self.obj
  1198. return obj._constructor(dtype=obj.dtype, index=obj.index).__finalize__(
  1199. obj, method="apply"
  1200. )
  1201. def apply_compat(self):
  1202. """compat apply method for funcs in listlikes and dictlikes.
  1203. Used for each callable when giving listlikes and dictlikes of callables to
  1204. apply. Needed for compatibility with Pandas < v2.1.
  1205. .. versionadded:: 2.1.0
  1206. """
  1207. obj = self.obj
  1208. func = self.func
  1209. if callable(func):
  1210. f = com.get_cython_func(func)
  1211. if f and not self.args and not self.kwargs:
  1212. return obj.apply(func, by_row=False)
  1213. try:
  1214. result = obj.apply(func, by_row="compat")
  1215. except (ValueError, AttributeError, TypeError):
  1216. result = obj.apply(func, by_row=False)
  1217. return result
  1218. def apply_standard(self) -> DataFrame | Series:
  1219. # caller is responsible for ensuring that f is Callable
  1220. func = cast(Callable, self.func)
  1221. obj = self.obj
  1222. if isinstance(func, np.ufunc):
  1223. with np.errstate(all="ignore"):
  1224. return func(obj, *self.args, **self.kwargs)
  1225. elif not self.by_row:
  1226. return func(obj, *self.args, **self.kwargs)
  1227. if self.args or self.kwargs:
  1228. # _map_values does not support args/kwargs
  1229. def curried(x):
  1230. return func(x, *self.args, **self.kwargs)
  1231. else:
  1232. curried = func
  1233. # row-wise access
  1234. # apply doesn't have a `na_action` keyword and for backward compat reasons
  1235. # we need to give `na_action="ignore"` for categorical data.
  1236. # TODO: remove the `na_action="ignore"` when that default has been changed in
  1237. # Categorical (GH51645).
  1238. action = "ignore" if isinstance(obj.dtype, CategoricalDtype) else None
  1239. mapped = obj._map_values(
  1240. mapper=curried, na_action=action, convert=self.convert_dtype
  1241. )
  1242. if len(mapped) and isinstance(mapped[0], ABCSeries):
  1243. # GH#43986 Need to do list(mapped) in order to get treated as nested
  1244. # See also GH#25959 regarding EA support
  1245. return obj._constructor_expanddim(list(mapped), index=obj.index)
  1246. else:
  1247. return obj._constructor(mapped, index=obj.index).__finalize__(
  1248. obj, method="apply"
  1249. )
  1250. class GroupByApply(Apply):
  1251. obj: GroupBy | Resampler | BaseWindow
  1252. def __init__(
  1253. self,
  1254. obj: GroupBy[NDFrameT],
  1255. func: AggFuncType,
  1256. *,
  1257. args,
  1258. kwargs,
  1259. ) -> None:
  1260. kwargs = kwargs.copy()
  1261. self.axis = obj.obj._get_axis_number(kwargs.get("axis", 0))
  1262. super().__init__(
  1263. obj,
  1264. func,
  1265. raw=False,
  1266. result_type=None,
  1267. args=args,
  1268. kwargs=kwargs,
  1269. )
  1270. def apply(self):
  1271. raise NotImplementedError
  1272. def transform(self):
  1273. raise NotImplementedError
  1274. def agg_or_apply_list_like(
  1275. self, op_name: Literal["agg", "apply"]
  1276. ) -> DataFrame | Series:
  1277. obj = self.obj
  1278. kwargs = self.kwargs
  1279. if op_name == "apply":
  1280. kwargs = {**kwargs, "by_row": False}
  1281. if getattr(obj, "axis", 0) == 1:
  1282. raise NotImplementedError("axis other than 0 is not supported")
  1283. if obj._selected_obj.ndim == 1:
  1284. # For SeriesGroupBy this matches _obj_with_exclusions
  1285. selected_obj = obj._selected_obj
  1286. else:
  1287. selected_obj = obj._obj_with_exclusions
  1288. # Only set as_index=True on groupby objects, not Window or Resample
  1289. # that inherit from this class.
  1290. with com.temp_setattr(
  1291. obj, "as_index", True, condition=hasattr(obj, "as_index")
  1292. ):
  1293. keys, results = self.compute_list_like(op_name, selected_obj, kwargs)
  1294. result = self.wrap_results_list_like(keys, results)
  1295. return result
  1296. def agg_or_apply_dict_like(
  1297. self, op_name: Literal["agg", "apply"]
  1298. ) -> DataFrame | Series:
  1299. from pandas.core.groupby.generic import (
  1300. DataFrameGroupBy,
  1301. SeriesGroupBy,
  1302. )
  1303. assert op_name in ["agg", "apply"]
  1304. obj = self.obj
  1305. kwargs = {}
  1306. if op_name == "apply":
  1307. by_row = "_compat" if self.by_row else False
  1308. kwargs.update({"by_row": by_row})
  1309. if getattr(obj, "axis", 0) == 1:
  1310. raise NotImplementedError("axis other than 0 is not supported")
  1311. selected_obj = obj._selected_obj
  1312. selection = obj._selection
  1313. is_groupby = isinstance(obj, (DataFrameGroupBy, SeriesGroupBy))
  1314. # Numba Groupby engine/engine-kwargs passthrough
  1315. if is_groupby:
  1316. engine = self.kwargs.get("engine", None)
  1317. engine_kwargs = self.kwargs.get("engine_kwargs", None)
  1318. kwargs.update({"engine": engine, "engine_kwargs": engine_kwargs})
  1319. with com.temp_setattr(
  1320. obj, "as_index", True, condition=hasattr(obj, "as_index")
  1321. ):
  1322. result_index, result_data = self.compute_dict_like(
  1323. op_name, selected_obj, selection, kwargs
  1324. )
  1325. result = self.wrap_results_dict_like(selected_obj, result_index, result_data)
  1326. return result
  1327. class ResamplerWindowApply(GroupByApply):
  1328. axis: AxisInt = 0
  1329. obj: Resampler | BaseWindow
  1330. def __init__(
  1331. self,
  1332. obj: Resampler | BaseWindow,
  1333. func: AggFuncType,
  1334. *,
  1335. args,
  1336. kwargs,
  1337. ) -> None:
  1338. super(GroupByApply, self).__init__(
  1339. obj,
  1340. func,
  1341. raw=False,
  1342. result_type=None,
  1343. args=args,
  1344. kwargs=kwargs,
  1345. )
  1346. def apply(self):
  1347. raise NotImplementedError
  1348. def transform(self):
  1349. raise NotImplementedError
  1350. def reconstruct_func(
  1351. func: AggFuncType | None, **kwargs
  1352. ) -> tuple[bool, AggFuncType, tuple[str, ...] | None, npt.NDArray[np.intp] | None]:
  1353. """
  1354. This is the internal function to reconstruct func given if there is relabeling
  1355. or not and also normalize the keyword to get new order of columns.
  1356. If named aggregation is applied, `func` will be None, and kwargs contains the
  1357. column and aggregation function information to be parsed;
  1358. If named aggregation is not applied, `func` is either string (e.g. 'min') or
  1359. Callable, or list of them (e.g. ['min', np.max]), or the dictionary of column name
  1360. and str/Callable/list of them (e.g. {'A': 'min'}, or {'A': [np.min, lambda x: x]})
  1361. If relabeling is True, will return relabeling, reconstructed func, column
  1362. names, and the reconstructed order of columns.
  1363. If relabeling is False, the columns and order will be None.
  1364. Parameters
  1365. ----------
  1366. func: agg function (e.g. 'min' or Callable) or list of agg functions
  1367. (e.g. ['min', np.max]) or dictionary (e.g. {'A': ['min', np.max]}).
  1368. **kwargs: dict, kwargs used in is_multi_agg_with_relabel and
  1369. normalize_keyword_aggregation function for relabelling
  1370. Returns
  1371. -------
  1372. relabelling: bool, if there is relabelling or not
  1373. func: normalized and mangled func
  1374. columns: tuple of column names
  1375. order: array of columns indices
  1376. Examples
  1377. --------
  1378. >>> reconstruct_func(None, **{"foo": ("col", "min")})
  1379. (True, defaultdict(<class 'list'>, {'col': ['min']}), ('foo',), array([0]))
  1380. >>> reconstruct_func("min")
  1381. (False, 'min', None, None)
  1382. """
  1383. relabeling = func is None and is_multi_agg_with_relabel(**kwargs)
  1384. columns: tuple[str, ...] | None = None
  1385. order: npt.NDArray[np.intp] | None = None
  1386. if not relabeling:
  1387. if isinstance(func, list) and len(func) > len(set(func)):
  1388. # GH 28426 will raise error if duplicated function names are used and
  1389. # there is no reassigned name
  1390. raise SpecificationError(
  1391. "Function names must be unique if there is no new column names "
  1392. "assigned"
  1393. )
  1394. if func is None:
  1395. # nicer error message
  1396. raise TypeError("Must provide 'func' or tuples of '(column, aggfunc).")
  1397. if relabeling:
  1398. # error: Incompatible types in assignment (expression has type
  1399. # "MutableMapping[Hashable, list[Callable[..., Any] | str]]", variable has type
  1400. # "Callable[..., Any] | str | list[Callable[..., Any] | str] |
  1401. # MutableMapping[Hashable, Callable[..., Any] | str | list[Callable[..., Any] |
  1402. # str]] | None")
  1403. func, columns, order = normalize_keyword_aggregation( # type: ignore[assignment]
  1404. kwargs
  1405. )
  1406. assert func is not None
  1407. return relabeling, func, columns, order
  1408. def is_multi_agg_with_relabel(**kwargs) -> bool:
  1409. """
  1410. Check whether kwargs passed to .agg look like multi-agg with relabeling.
  1411. Parameters
  1412. ----------
  1413. **kwargs : dict
  1414. Returns
  1415. -------
  1416. bool
  1417. Examples
  1418. --------
  1419. >>> is_multi_agg_with_relabel(a="max")
  1420. False
  1421. >>> is_multi_agg_with_relabel(a_max=("a", "max"), a_min=("a", "min"))
  1422. True
  1423. >>> is_multi_agg_with_relabel()
  1424. False
  1425. """
  1426. return all(isinstance(v, tuple) and len(v) == 2 for v in kwargs.values()) and (
  1427. len(kwargs) > 0
  1428. )
  1429. def normalize_keyword_aggregation(
  1430. kwargs: dict,
  1431. ) -> tuple[
  1432. MutableMapping[Hashable, list[AggFuncTypeBase]],
  1433. tuple[str, ...],
  1434. npt.NDArray[np.intp],
  1435. ]:
  1436. """
  1437. Normalize user-provided "named aggregation" kwargs.
  1438. Transforms from the new ``Mapping[str, NamedAgg]`` style kwargs
  1439. to the old Dict[str, List[scalar]]].
  1440. Parameters
  1441. ----------
  1442. kwargs : dict
  1443. Returns
  1444. -------
  1445. aggspec : dict
  1446. The transformed kwargs.
  1447. columns : tuple[str, ...]
  1448. The user-provided keys.
  1449. col_idx_order : List[int]
  1450. List of columns indices.
  1451. Examples
  1452. --------
  1453. >>> normalize_keyword_aggregation({"output": ("input", "sum")})
  1454. (defaultdict(<class 'list'>, {'input': ['sum']}), ('output',), array([0]))
  1455. """
  1456. from pandas.core.indexes.base import Index
  1457. # Normalize the aggregation functions as Mapping[column, List[func]],
  1458. # process normally, then fixup the names.
  1459. # TODO: aggspec type: typing.Dict[str, List[AggScalar]]
  1460. aggspec = defaultdict(list)
  1461. order = []
  1462. columns, pairs = list(zip(*kwargs.items()))
  1463. for column, aggfunc in pairs:
  1464. aggspec[column].append(aggfunc)
  1465. order.append((column, com.get_callable_name(aggfunc) or aggfunc))
  1466. # uniquify aggfunc name if duplicated in order list
  1467. uniquified_order = _make_unique_kwarg_list(order)
  1468. # GH 25719, due to aggspec will change the order of assigned columns in aggregation
  1469. # uniquified_aggspec will store uniquified order list and will compare it with order
  1470. # based on index
  1471. aggspec_order = [
  1472. (column, com.get_callable_name(aggfunc) or aggfunc)
  1473. for column, aggfuncs in aggspec.items()
  1474. for aggfunc in aggfuncs
  1475. ]
  1476. uniquified_aggspec = _make_unique_kwarg_list(aggspec_order)
  1477. # get the new index of columns by comparison
  1478. col_idx_order = Index(uniquified_aggspec).get_indexer(uniquified_order)
  1479. return aggspec, columns, col_idx_order
  1480. def _make_unique_kwarg_list(
  1481. seq: Sequence[tuple[Any, Any]]
  1482. ) -> Sequence[tuple[Any, Any]]:
  1483. """
  1484. Uniquify aggfunc name of the pairs in the order list
  1485. Examples:
  1486. --------
  1487. >>> kwarg_list = [('a', '<lambda>'), ('a', '<lambda>'), ('b', '<lambda>')]
  1488. >>> _make_unique_kwarg_list(kwarg_list)
  1489. [('a', '<lambda>_0'), ('a', '<lambda>_1'), ('b', '<lambda>')]
  1490. """
  1491. return [
  1492. (pair[0], f"{pair[1]}_{seq[:i].count(pair)}") if seq.count(pair) > 1 else pair
  1493. for i, pair in enumerate(seq)
  1494. ]
  1495. def relabel_result(
  1496. result: DataFrame | Series,
  1497. func: dict[str, list[Callable | str]],
  1498. columns: Iterable[Hashable],
  1499. order: Iterable[int],
  1500. ) -> dict[Hashable, Series]:
  1501. """
  1502. Internal function to reorder result if relabelling is True for
  1503. dataframe.agg, and return the reordered result in dict.
  1504. Parameters:
  1505. ----------
  1506. result: Result from aggregation
  1507. func: Dict of (column name, funcs)
  1508. columns: New columns name for relabelling
  1509. order: New order for relabelling
  1510. Examples
  1511. --------
  1512. >>> from pandas.core.apply import relabel_result
  1513. >>> result = pd.DataFrame(
  1514. ... {"A": [np.nan, 2, np.nan], "C": [6, np.nan, np.nan], "B": [np.nan, 4, 2.5]},
  1515. ... index=["max", "mean", "min"]
  1516. ... )
  1517. >>> funcs = {"A": ["max"], "C": ["max"], "B": ["mean", "min"]}
  1518. >>> columns = ("foo", "aab", "bar", "dat")
  1519. >>> order = [0, 1, 2, 3]
  1520. >>> result_in_dict = relabel_result(result, funcs, columns, order)
  1521. >>> pd.DataFrame(result_in_dict, index=columns)
  1522. A C B
  1523. foo 2.0 NaN NaN
  1524. aab NaN 6.0 NaN
  1525. bar NaN NaN 4.0
  1526. dat NaN NaN 2.5
  1527. """
  1528. from pandas.core.indexes.base import Index
  1529. reordered_indexes = [
  1530. pair[0] for pair in sorted(zip(columns, order), key=lambda t: t[1])
  1531. ]
  1532. reordered_result_in_dict: dict[Hashable, Series] = {}
  1533. idx = 0
  1534. reorder_mask = not isinstance(result, ABCSeries) and len(result.columns) > 1
  1535. for col, fun in func.items():
  1536. s = result[col].dropna()
  1537. # In the `_aggregate`, the callable names are obtained and used in `result`, and
  1538. # these names are ordered alphabetically. e.g.
  1539. # C2 C1
  1540. # <lambda> 1 NaN
  1541. # amax NaN 4.0
  1542. # max NaN 4.0
  1543. # sum 18.0 6.0
  1544. # Therefore, the order of functions for each column could be shuffled
  1545. # accordingly so need to get the callable name if it is not parsed names, and
  1546. # reorder the aggregated result for each column.
  1547. # e.g. if df.agg(c1=("C2", sum), c2=("C2", lambda x: min(x))), correct order is
  1548. # [sum, <lambda>], but in `result`, it will be [<lambda>, sum], and we need to
  1549. # reorder so that aggregated values map to their functions regarding the order.
  1550. # However there is only one column being used for aggregation, not need to
  1551. # reorder since the index is not sorted, and keep as is in `funcs`, e.g.
  1552. # A
  1553. # min 1.0
  1554. # mean 1.5
  1555. # mean 1.5
  1556. if reorder_mask:
  1557. fun = [
  1558. com.get_callable_name(f) if not isinstance(f, str) else f for f in fun
  1559. ]
  1560. col_idx_order = Index(s.index).get_indexer(fun)
  1561. s = s.iloc[col_idx_order]
  1562. # assign the new user-provided "named aggregation" as index names, and reindex
  1563. # it based on the whole user-provided names.
  1564. s.index = reordered_indexes[idx : idx + len(fun)]
  1565. reordered_result_in_dict[col] = s.reindex(columns, copy=False)
  1566. idx = idx + len(fun)
  1567. return reordered_result_in_dict
  1568. def reconstruct_and_relabel_result(result, func, **kwargs) -> DataFrame | Series:
  1569. from pandas import DataFrame
  1570. relabeling, func, columns, order = reconstruct_func(func, **kwargs)
  1571. if relabeling:
  1572. # This is to keep the order to columns occurrence unchanged, and also
  1573. # keep the order of new columns occurrence unchanged
  1574. # For the return values of reconstruct_func, if relabeling is
  1575. # False, columns and order will be None.
  1576. assert columns is not None
  1577. assert order is not None
  1578. result_in_dict = relabel_result(result, func, columns, order)
  1579. result = DataFrame(result_in_dict, index=columns)
  1580. return result
  1581. # TODO: Can't use, because mypy doesn't like us setting __name__
  1582. # error: "partial[Any]" has no attribute "__name__"
  1583. # the type is:
  1584. # typing.Sequence[Callable[..., ScalarResult]]
  1585. # -> typing.Sequence[Callable[..., ScalarResult]]:
  1586. def _managle_lambda_list(aggfuncs: Sequence[Any]) -> Sequence[Any]:
  1587. """
  1588. Possibly mangle a list of aggfuncs.
  1589. Parameters
  1590. ----------
  1591. aggfuncs : Sequence
  1592. Returns
  1593. -------
  1594. mangled: list-like
  1595. A new AggSpec sequence, where lambdas have been converted
  1596. to have unique names.
  1597. Notes
  1598. -----
  1599. If just one aggfunc is passed, the name will not be mangled.
  1600. """
  1601. if len(aggfuncs) <= 1:
  1602. # don't mangle for .agg([lambda x: .])
  1603. return aggfuncs
  1604. i = 0
  1605. mangled_aggfuncs = []
  1606. for aggfunc in aggfuncs:
  1607. if com.get_callable_name(aggfunc) == "<lambda>":
  1608. aggfunc = partial(aggfunc)
  1609. aggfunc.__name__ = f"<lambda_{i}>"
  1610. i += 1
  1611. mangled_aggfuncs.append(aggfunc)
  1612. return mangled_aggfuncs
  1613. def maybe_mangle_lambdas(agg_spec: Any) -> Any:
  1614. """
  1615. Make new lambdas with unique names.
  1616. Parameters
  1617. ----------
  1618. agg_spec : Any
  1619. An argument to GroupBy.agg.
  1620. Non-dict-like `agg_spec` are pass through as is.
  1621. For dict-like `agg_spec` a new spec is returned
  1622. with name-mangled lambdas.
  1623. Returns
  1624. -------
  1625. mangled : Any
  1626. Same type as the input.
  1627. Examples
  1628. --------
  1629. >>> maybe_mangle_lambdas('sum')
  1630. 'sum'
  1631. >>> maybe_mangle_lambdas([lambda: 1, lambda: 2]) # doctest: +SKIP
  1632. [<function __main__.<lambda_0>,
  1633. <function pandas...._make_lambda.<locals>.f(*args, **kwargs)>]
  1634. """
  1635. is_dict = is_dict_like(agg_spec)
  1636. if not (is_dict or is_list_like(agg_spec)):
  1637. return agg_spec
  1638. mangled_aggspec = type(agg_spec)() # dict or OrderedDict
  1639. if is_dict:
  1640. for key, aggfuncs in agg_spec.items():
  1641. if is_list_like(aggfuncs) and not is_dict_like(aggfuncs):
  1642. mangled_aggfuncs = _managle_lambda_list(aggfuncs)
  1643. else:
  1644. mangled_aggfuncs = aggfuncs
  1645. mangled_aggspec[key] = mangled_aggfuncs
  1646. else:
  1647. mangled_aggspec = _managle_lambda_list(agg_spec)
  1648. return mangled_aggspec
  1649. def validate_func_kwargs(
  1650. kwargs: dict,
  1651. ) -> tuple[list[str], list[str | Callable[..., Any]]]:
  1652. """
  1653. Validates types of user-provided "named aggregation" kwargs.
  1654. `TypeError` is raised if aggfunc is not `str` or callable.
  1655. Parameters
  1656. ----------
  1657. kwargs : dict
  1658. Returns
  1659. -------
  1660. columns : List[str]
  1661. List of user-provided keys.
  1662. func : List[Union[str, callable[...,Any]]]
  1663. List of user-provided aggfuncs
  1664. Examples
  1665. --------
  1666. >>> validate_func_kwargs({'one': 'min', 'two': 'max'})
  1667. (['one', 'two'], ['min', 'max'])
  1668. """
  1669. tuple_given_message = "func is expected but received {} in **kwargs."
  1670. columns = list(kwargs)
  1671. func = []
  1672. for col_func in kwargs.values():
  1673. if not (isinstance(col_func, str) or callable(col_func)):
  1674. raise TypeError(tuple_given_message.format(type(col_func).__name__))
  1675. func.append(col_func)
  1676. if not columns:
  1677. no_arg_message = "Must provide 'func' or named aggregation **kwargs."
  1678. raise TypeError(no_arg_message)
  1679. return columns, func
  1680. def include_axis(op_name: Literal["agg", "apply"], colg: Series | DataFrame) -> bool:
  1681. return isinstance(colg, ABCDataFrame) or (
  1682. isinstance(colg, ABCSeries) and op_name == "agg"
  1683. )
  1684. def warn_alias_replacement(
  1685. obj: AggObjType,
  1686. func: Callable,
  1687. alias: str,
  1688. ) -> None:
  1689. if alias.startswith("np."):
  1690. full_alias = alias
  1691. else:
  1692. full_alias = f"{type(obj).__name__}.{alias}"
  1693. alias = f'"{alias}"'
  1694. warnings.warn(
  1695. f"The provided callable {func} is currently using "
  1696. f"{full_alias}. In a future version of pandas, "
  1697. f"the provided callable will be used directly. To keep current "
  1698. f"behavior pass the string {alias} instead.",
  1699. category=FutureWarning,
  1700. stacklevel=find_stack_level(),
  1701. )