test_numeric_only.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532
  1. import re
  2. import numpy as np
  3. import pytest
  4. from pandas._libs import lib
  5. import pandas as pd
  6. from pandas import (
  7. DataFrame,
  8. Index,
  9. Series,
  10. Timestamp,
  11. date_range,
  12. )
  13. import pandas._testing as tm
  14. from pandas.tests.groupby import get_groupby_method_args
  15. class TestNumericOnly:
  16. # make sure that we are passing thru kwargs to our agg functions
  17. @pytest.fixture
  18. def df(self):
  19. # GH3668
  20. # GH5724
  21. df = DataFrame(
  22. {
  23. "group": [1, 1, 2],
  24. "int": [1, 2, 3],
  25. "float": [4.0, 5.0, 6.0],
  26. "string": Series(["a", "b", "c"], dtype="str"),
  27. "object": Series(["a", "b", "c"], dtype=object),
  28. "category_string": Series(list("abc")).astype("category"),
  29. "category_int": [7, 8, 9],
  30. "datetime": date_range("20130101", periods=3),
  31. "datetimetz": date_range("20130101", periods=3, tz="US/Eastern"),
  32. "timedelta": pd.timedelta_range("1 s", periods=3, freq="s"),
  33. },
  34. columns=[
  35. "group",
  36. "int",
  37. "float",
  38. "string",
  39. "object",
  40. "category_string",
  41. "category_int",
  42. "datetime",
  43. "datetimetz",
  44. "timedelta",
  45. ],
  46. )
  47. return df
  48. @pytest.mark.parametrize("method", ["mean", "median"])
  49. def test_averages(self, df, method):
  50. # mean / median
  51. expected_columns_numeric = Index(["int", "float", "category_int"])
  52. gb = df.groupby("group")
  53. expected = DataFrame(
  54. {
  55. "category_int": [7.5, 9],
  56. "float": [4.5, 6.0],
  57. "timedelta": [pd.Timedelta("1.5s"), pd.Timedelta("3s")],
  58. "int": [1.5, 3],
  59. "datetime": [
  60. Timestamp("2013-01-01 12:00:00"),
  61. Timestamp("2013-01-03 00:00:00"),
  62. ],
  63. "datetimetz": [
  64. Timestamp("2013-01-01 12:00:00", tz="US/Eastern"),
  65. Timestamp("2013-01-03 00:00:00", tz="US/Eastern"),
  66. ],
  67. },
  68. index=Index([1, 2], name="group"),
  69. columns=[
  70. "int",
  71. "float",
  72. "category_int",
  73. ],
  74. )
  75. result = getattr(gb, method)(numeric_only=True)
  76. tm.assert_frame_equal(result.reindex_like(expected), expected)
  77. expected_columns = expected.columns
  78. self._check(df, method, expected_columns, expected_columns_numeric)
  79. @pytest.mark.parametrize("method", ["min", "max"])
  80. def test_extrema(self, df, method):
  81. # TODO: min, max *should* handle
  82. # categorical (ordered) dtype
  83. expected_columns = Index(
  84. [
  85. "int",
  86. "float",
  87. "string",
  88. "category_int",
  89. "datetime",
  90. "datetimetz",
  91. "timedelta",
  92. ]
  93. )
  94. expected_columns_numeric = expected_columns
  95. self._check(df, method, expected_columns, expected_columns_numeric)
  96. @pytest.mark.parametrize("method", ["first", "last"])
  97. def test_first_last(self, df, method):
  98. expected_columns = Index(
  99. [
  100. "int",
  101. "float",
  102. "string",
  103. "object",
  104. "category_string",
  105. "category_int",
  106. "datetime",
  107. "datetimetz",
  108. "timedelta",
  109. ]
  110. )
  111. expected_columns_numeric = expected_columns
  112. self._check(df, method, expected_columns, expected_columns_numeric)
  113. @pytest.mark.parametrize("method", ["sum", "cumsum"])
  114. def test_sum_cumsum(self, df, method):
  115. expected_columns_numeric = Index(["int", "float", "category_int"])
  116. expected_columns = Index(
  117. ["int", "float", "string", "category_int", "timedelta"]
  118. )
  119. if method == "cumsum":
  120. # cumsum loses string
  121. expected_columns = Index(["int", "float", "category_int", "timedelta"])
  122. self._check(df, method, expected_columns, expected_columns_numeric)
  123. @pytest.mark.parametrize("method", ["prod", "cumprod"])
  124. def test_prod_cumprod(self, df, method):
  125. expected_columns = Index(["int", "float", "category_int"])
  126. expected_columns_numeric = expected_columns
  127. self._check(df, method, expected_columns, expected_columns_numeric)
  128. @pytest.mark.parametrize("method", ["cummin", "cummax"])
  129. def test_cummin_cummax(self, df, method):
  130. # like min, max, but don't include strings
  131. expected_columns = Index(
  132. ["int", "float", "category_int", "datetime", "datetimetz", "timedelta"]
  133. )
  134. # GH#15561: numeric_only=False set by default like min/max
  135. expected_columns_numeric = expected_columns
  136. self._check(df, method, expected_columns, expected_columns_numeric)
  137. def _check(self, df, method, expected_columns, expected_columns_numeric):
  138. gb = df.groupby("group")
  139. # object dtypes for transformations are not implemented in Cython and
  140. # have no Python fallback
  141. exception = (
  142. (NotImplementedError, TypeError) if method.startswith("cum") else TypeError
  143. )
  144. if method in ("min", "max", "cummin", "cummax", "cumsum", "cumprod"):
  145. # The methods default to numeric_only=False and raise TypeError
  146. msg = "|".join(
  147. [
  148. "Categorical is not ordered",
  149. f"Cannot perform {method} with non-ordered Categorical",
  150. re.escape(f"agg function failed [how->{method},dtype->object]"),
  151. # cumsum/cummin/cummax/cumprod
  152. "function is not implemented for this dtype",
  153. f"dtype 'str' does not support operation '{method}'",
  154. ]
  155. )
  156. with pytest.raises(exception, match=msg):
  157. getattr(gb, method)()
  158. elif method in ("sum", "mean", "median", "prod"):
  159. msg = "|".join(
  160. [
  161. "category type does not support sum operations",
  162. re.escape(f"agg function failed [how->{method},dtype->object]"),
  163. re.escape(f"agg function failed [how->{method},dtype->string]"),
  164. f"dtype 'str' does not support operation '{method}'",
  165. ]
  166. )
  167. with pytest.raises(exception, match=msg):
  168. getattr(gb, method)()
  169. else:
  170. result = getattr(gb, method)()
  171. tm.assert_index_equal(result.columns, expected_columns_numeric)
  172. if method not in ("first", "last"):
  173. msg = "|".join(
  174. [
  175. "Categorical is not ordered",
  176. "category type does not support",
  177. "function is not implemented for this dtype",
  178. f"Cannot perform {method} with non-ordered Categorical",
  179. re.escape(f"agg function failed [how->{method},dtype->object]"),
  180. re.escape(f"agg function failed [how->{method},dtype->string]"),
  181. f"dtype 'str' does not support operation '{method}'",
  182. ]
  183. )
  184. with pytest.raises(exception, match=msg):
  185. getattr(gb, method)(numeric_only=False)
  186. else:
  187. result = getattr(gb, method)(numeric_only=False)
  188. tm.assert_index_equal(result.columns, expected_columns)
  189. @pytest.mark.parametrize("numeric_only", [True, False, None])
  190. def test_axis1_numeric_only(request, groupby_func, numeric_only, using_infer_string):
  191. if groupby_func in ("idxmax", "idxmin"):
  192. pytest.skip("idxmax and idx_min tested in test_idxmin_idxmax_axis1")
  193. if groupby_func in ("corrwith", "skew"):
  194. msg = "GH#47723 groupby.corrwith and skew do not correctly implement axis=1"
  195. request.applymarker(pytest.mark.xfail(reason=msg))
  196. df = DataFrame(
  197. np.random.default_rng(2).standard_normal((10, 4)), columns=["A", "B", "C", "D"]
  198. )
  199. df["E"] = "x"
  200. groups = [1, 2, 3, 1, 2, 3, 1, 2, 3, 4]
  201. gb = df.groupby(groups)
  202. method = getattr(gb, groupby_func)
  203. args = get_groupby_method_args(groupby_func, df)
  204. kwargs = {"axis": 1}
  205. if numeric_only is not None:
  206. # when numeric_only is None we don't pass any argument
  207. kwargs["numeric_only"] = numeric_only
  208. # Functions without numeric_only and axis args
  209. no_args = ("cumprod", "cumsum", "diff", "fillna", "pct_change", "rank", "shift")
  210. # Functions with axis args
  211. has_axis = (
  212. "cumprod",
  213. "cumsum",
  214. "diff",
  215. "pct_change",
  216. "rank",
  217. "shift",
  218. "cummax",
  219. "cummin",
  220. "idxmin",
  221. "idxmax",
  222. "fillna",
  223. )
  224. warn_msg = f"DataFrameGroupBy.{groupby_func} with axis=1 is deprecated"
  225. if numeric_only is not None and groupby_func in no_args:
  226. msg = "got an unexpected keyword argument 'numeric_only'"
  227. if groupby_func in ["cumprod", "cumsum"]:
  228. with pytest.raises(TypeError, match=msg):
  229. with tm.assert_produces_warning(FutureWarning, match=warn_msg):
  230. method(*args, **kwargs)
  231. else:
  232. with pytest.raises(TypeError, match=msg):
  233. method(*args, **kwargs)
  234. elif groupby_func not in has_axis:
  235. msg = "got an unexpected keyword argument 'axis'"
  236. with pytest.raises(TypeError, match=msg):
  237. method(*args, **kwargs)
  238. # fillna and shift are successful even on object dtypes
  239. elif (numeric_only is None or not numeric_only) and groupby_func not in (
  240. "fillna",
  241. "shift",
  242. ):
  243. msgs = (
  244. # cummax, cummin, rank
  245. "not supported between instances of",
  246. # cumprod
  247. "can't multiply sequence by non-int of type 'float'",
  248. # cumsum, diff, pct_change
  249. "unsupported operand type",
  250. "has no kernel",
  251. "operation 'sub' not supported for dtype 'str' with dtype 'float64'",
  252. )
  253. if using_infer_string:
  254. pa = pytest.importorskip("pyarrow")
  255. errs = (TypeError, pa.lib.ArrowNotImplementedError)
  256. else:
  257. errs = TypeError
  258. with pytest.raises(errs, match=f"({'|'.join(msgs)})"):
  259. with tm.assert_produces_warning(FutureWarning, match=warn_msg):
  260. method(*args, **kwargs)
  261. else:
  262. with tm.assert_produces_warning(FutureWarning, match=warn_msg):
  263. result = method(*args, **kwargs)
  264. df_expected = df.drop(columns="E").T if numeric_only else df.T
  265. expected = getattr(df_expected, groupby_func)(*args).T
  266. if groupby_func == "shift" and not numeric_only:
  267. # shift with axis=1 leaves the leftmost column as numeric
  268. # but transposing for expected gives us object dtype
  269. expected = expected.astype(float)
  270. tm.assert_equal(result, expected)
  271. @pytest.mark.parametrize(
  272. "kernel, has_arg",
  273. [
  274. ("all", False),
  275. ("any", False),
  276. ("bfill", False),
  277. ("corr", True),
  278. ("corrwith", True),
  279. ("cov", True),
  280. ("cummax", True),
  281. ("cummin", True),
  282. ("cumprod", True),
  283. ("cumsum", True),
  284. ("diff", False),
  285. ("ffill", False),
  286. ("fillna", False),
  287. ("first", True),
  288. ("idxmax", True),
  289. ("idxmin", True),
  290. ("last", True),
  291. ("max", True),
  292. ("mean", True),
  293. ("median", True),
  294. ("min", True),
  295. ("nth", False),
  296. ("nunique", False),
  297. ("pct_change", False),
  298. ("prod", True),
  299. ("quantile", True),
  300. ("sem", True),
  301. ("skew", True),
  302. ("std", True),
  303. ("sum", True),
  304. ("var", True),
  305. ],
  306. )
  307. @pytest.mark.parametrize("numeric_only", [True, False, lib.no_default])
  308. @pytest.mark.parametrize("keys", [["a1"], ["a1", "a2"]])
  309. def test_numeric_only(kernel, has_arg, numeric_only, keys):
  310. # GH#46072
  311. # drops_nuisance: Whether the op drops nuisance columns even when numeric_only=False
  312. # has_arg: Whether the op has a numeric_only arg
  313. df = DataFrame({"a1": [1, 1], "a2": [2, 2], "a3": [5, 6], "b": 2 * [object]})
  314. args = get_groupby_method_args(kernel, df)
  315. kwargs = {} if numeric_only is lib.no_default else {"numeric_only": numeric_only}
  316. gb = df.groupby(keys)
  317. method = getattr(gb, kernel)
  318. if has_arg and numeric_only is True:
  319. # Cases where b does not appear in the result
  320. result = method(*args, **kwargs)
  321. assert "b" not in result.columns
  322. elif (
  323. # kernels that work on any dtype and have numeric_only arg
  324. kernel in ("first", "last")
  325. or (
  326. # kernels that work on any dtype and don't have numeric_only arg
  327. kernel in ("any", "all", "bfill", "ffill", "fillna", "nth", "nunique")
  328. and numeric_only is lib.no_default
  329. )
  330. ):
  331. warn = FutureWarning if kernel == "fillna" else None
  332. msg = "DataFrameGroupBy.fillna is deprecated"
  333. with tm.assert_produces_warning(warn, match=msg):
  334. result = method(*args, **kwargs)
  335. assert "b" in result.columns
  336. elif has_arg:
  337. assert numeric_only is not True
  338. # kernels that are successful on any dtype were above; this will fail
  339. # object dtypes for transformations are not implemented in Cython and
  340. # have no Python fallback
  341. exception = NotImplementedError if kernel.startswith("cum") else TypeError
  342. msg = "|".join(
  343. [
  344. "not allowed for this dtype",
  345. "cannot be performed against 'object' dtypes",
  346. # On PY39 message is "a number"; on PY310 and after is "a real number"
  347. "must be a string or a.* number",
  348. "unsupported operand type",
  349. "function is not implemented for this dtype",
  350. re.escape(f"agg function failed [how->{kernel},dtype->object]"),
  351. ]
  352. )
  353. if kernel == "quantile":
  354. msg = "dtype 'object' does not support operation 'quantile'"
  355. elif kernel == "idxmin":
  356. msg = "'<' not supported between instances of 'type' and 'type'"
  357. elif kernel == "idxmax":
  358. msg = "'>' not supported between instances of 'type' and 'type'"
  359. with pytest.raises(exception, match=msg):
  360. method(*args, **kwargs)
  361. elif not has_arg and numeric_only is not lib.no_default:
  362. with pytest.raises(
  363. TypeError, match="got an unexpected keyword argument 'numeric_only'"
  364. ):
  365. method(*args, **kwargs)
  366. else:
  367. assert kernel in ("diff", "pct_change")
  368. assert numeric_only is lib.no_default
  369. # Doesn't have numeric_only argument and fails on nuisance columns
  370. with pytest.raises(TypeError, match=r"unsupported operand type"):
  371. method(*args, **kwargs)
  372. @pytest.mark.filterwarnings("ignore:Downcasting object dtype arrays:FutureWarning")
  373. @pytest.mark.parametrize("dtype", [bool, int, float, object])
  374. def test_deprecate_numeric_only_series(dtype, groupby_func, request):
  375. # GH#46560
  376. grouper = [0, 0, 1]
  377. ser = Series([1, 0, 0], dtype=dtype)
  378. gb = ser.groupby(grouper)
  379. if groupby_func == "corrwith":
  380. # corrwith is not implemented on SeriesGroupBy
  381. assert not hasattr(gb, groupby_func)
  382. return
  383. method = getattr(gb, groupby_func)
  384. expected_ser = Series([1, 0, 0])
  385. expected_gb = expected_ser.groupby(grouper)
  386. expected_method = getattr(expected_gb, groupby_func)
  387. args = get_groupby_method_args(groupby_func, ser)
  388. fails_on_numeric_object = (
  389. "corr",
  390. "cov",
  391. "cummax",
  392. "cummin",
  393. "cumprod",
  394. "cumsum",
  395. "quantile",
  396. )
  397. # ops that give an object result on object input
  398. obj_result = (
  399. "first",
  400. "last",
  401. "nth",
  402. "bfill",
  403. "ffill",
  404. "shift",
  405. "sum",
  406. "diff",
  407. "pct_change",
  408. "var",
  409. "mean",
  410. "median",
  411. "min",
  412. "max",
  413. "prod",
  414. "skew",
  415. )
  416. # Test default behavior; kernels that fail may be enabled in the future but kernels
  417. # that succeed should not be allowed to fail (without deprecation, at least)
  418. if groupby_func in fails_on_numeric_object and dtype is object:
  419. if groupby_func == "quantile":
  420. msg = "dtype 'object' does not support operation 'quantile'"
  421. else:
  422. msg = "is not supported for object dtype"
  423. warn = FutureWarning if groupby_func == "fillna" else None
  424. warn_msg = "DataFrameGroupBy.fillna is deprecated"
  425. with tm.assert_produces_warning(warn, match=warn_msg):
  426. with pytest.raises(TypeError, match=msg):
  427. method(*args)
  428. elif dtype is object:
  429. warn = FutureWarning if groupby_func == "fillna" else None
  430. warn_msg = "SeriesGroupBy.fillna is deprecated"
  431. with tm.assert_produces_warning(warn, match=warn_msg):
  432. result = method(*args)
  433. with tm.assert_produces_warning(warn, match=warn_msg):
  434. expected = expected_method(*args)
  435. if groupby_func in obj_result:
  436. expected = expected.astype(object)
  437. tm.assert_series_equal(result, expected)
  438. has_numeric_only = (
  439. "first",
  440. "last",
  441. "max",
  442. "mean",
  443. "median",
  444. "min",
  445. "prod",
  446. "quantile",
  447. "sem",
  448. "skew",
  449. "std",
  450. "sum",
  451. "var",
  452. "cummax",
  453. "cummin",
  454. "cumprod",
  455. "cumsum",
  456. )
  457. if groupby_func not in has_numeric_only:
  458. msg = "got an unexpected keyword argument 'numeric_only'"
  459. with pytest.raises(TypeError, match=msg):
  460. method(*args, numeric_only=True)
  461. elif dtype is object:
  462. msg = "|".join(
  463. [
  464. "SeriesGroupBy.sem called with numeric_only=True and dtype object",
  465. "Series.skew does not allow numeric_only=True with non-numeric",
  466. "cum(sum|prod|min|max) is not supported for object dtype",
  467. r"Cannot use numeric_only=True with SeriesGroupBy\..* and non-numeric",
  468. ]
  469. )
  470. with pytest.raises(TypeError, match=msg):
  471. method(*args, numeric_only=True)
  472. elif dtype == bool and groupby_func == "quantile":
  473. msg = "Allowing bool dtype in SeriesGroupBy.quantile"
  474. with tm.assert_produces_warning(FutureWarning, match=msg):
  475. # GH#51424
  476. result = method(*args, numeric_only=True)
  477. expected = method(*args, numeric_only=False)
  478. tm.assert_series_equal(result, expected)
  479. else:
  480. result = method(*args, numeric_only=True)
  481. expected = method(*args, numeric_only=False)
  482. tm.assert_series_equal(result, expected)