test_api.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398
  1. import numpy as np
  2. import pytest
  3. from pandas.errors import (
  4. DataError,
  5. SpecificationError,
  6. )
  7. from pandas import (
  8. DataFrame,
  9. Index,
  10. MultiIndex,
  11. Period,
  12. Series,
  13. Timestamp,
  14. concat,
  15. date_range,
  16. timedelta_range,
  17. )
  18. import pandas._testing as tm
  19. def test_getitem(step):
  20. frame = DataFrame(np.random.default_rng(2).standard_normal((5, 5)))
  21. r = frame.rolling(window=5, step=step)
  22. tm.assert_index_equal(r._selected_obj.columns, frame[::step].columns)
  23. r = frame.rolling(window=5, step=step)[1]
  24. assert r._selected_obj.name == frame[::step].columns[1]
  25. # technically this is allowed
  26. r = frame.rolling(window=5, step=step)[1, 3]
  27. tm.assert_index_equal(r._selected_obj.columns, frame[::step].columns[[1, 3]])
  28. r = frame.rolling(window=5, step=step)[[1, 3]]
  29. tm.assert_index_equal(r._selected_obj.columns, frame[::step].columns[[1, 3]])
  def test_select_bad_cols():
      """Selecting missing columns raises KeyError naming only the missing ones."""
      roller = DataFrame([[1, 2]], columns=["A", "B"]).rolling(window=5)

      with pytest.raises(KeyError, match="Columns not found: 'C'"):
          roller[["C"]]

      # "A" exists, so it must not be reported as a bad column. The regex
      # below has to be rethought if the error message format changes.
      with pytest.raises(KeyError, match="^[^A]+$"):
          roller[["A", "C"]]
  def test_attribute_access():
      """Columns are reachable as attributes; unknown names raise AttributeError."""
      frame = DataFrame([[1, 2]], columns=["A", "B"])
      roller = frame.rolling(window=5)

      via_attr = roller.A.sum()
      via_item = roller["A"].sum()
      tm.assert_series_equal(via_attr, via_item)

      with pytest.raises(AttributeError, match="'Rolling' object has no attribute 'F'"):
          roller.F
  46. def tests_skip_nuisance(step):
  47. df = DataFrame({"A": range(5), "B": range(5, 10), "C": "foo"})
  48. r = df.rolling(window=3, step=step)
  49. result = r[["A", "B"]].sum()
  50. expected = DataFrame(
  51. {"A": [np.nan, np.nan, 3, 6, 9], "B": [np.nan, np.nan, 18, 21, 24]},
  52. columns=list("AB"),
  53. )[::step]
  54. tm.assert_frame_equal(result, expected)
  55. def test_sum_object_str_raises(step):
  56. df = DataFrame({"A": range(5), "B": range(5, 10), "C": "foo"})
  57. r = df.rolling(window=3, step=step)
  58. with pytest.raises(
  59. DataError, match="Cannot aggregate non-numeric type: object|str"
  60. ):
  61. # GH#42738, enforced in 2.0
  62. r.sum()
  63. def test_agg(step):
  64. df = DataFrame({"A": range(5), "B": range(0, 10, 2)})
  65. r = df.rolling(window=3, step=step)
  66. a_mean = r["A"].mean()
  67. a_std = r["A"].std()
  68. a_sum = r["A"].sum()
  69. b_mean = r["B"].mean()
  70. b_std = r["B"].std()
  71. with tm.assert_produces_warning(FutureWarning, match="using Rolling.[mean|std]"):
  72. result = r.aggregate([np.mean, np.std])
  73. expected = concat([a_mean, a_std, b_mean, b_std], axis=1)
  74. expected.columns = MultiIndex.from_product([["A", "B"], ["mean", "std"]])
  75. tm.assert_frame_equal(result, expected)
  76. with tm.assert_produces_warning(FutureWarning, match="using Rolling.[mean|std]"):
  77. result = r.aggregate({"A": np.mean, "B": np.std})
  78. expected = concat([a_mean, b_std], axis=1)
  79. tm.assert_frame_equal(result, expected, check_like=True)
  80. result = r.aggregate({"A": ["mean", "std"]})
  81. expected = concat([a_mean, a_std], axis=1)
  82. expected.columns = MultiIndex.from_tuples([("A", "mean"), ("A", "std")])
  83. tm.assert_frame_equal(result, expected)
  84. result = r["A"].aggregate(["mean", "sum"])
  85. expected = concat([a_mean, a_sum], axis=1)
  86. expected.columns = ["mean", "sum"]
  87. tm.assert_frame_equal(result, expected)
  88. msg = "nested renamer is not supported"
  89. with pytest.raises(SpecificationError, match=msg):
  90. # using a dict with renaming
  91. r.aggregate({"A": {"mean": "mean", "sum": "sum"}})
  92. with pytest.raises(SpecificationError, match=msg):
  93. r.aggregate(
  94. {"A": {"mean": "mean", "sum": "sum"}, "B": {"mean2": "mean", "sum2": "sum"}}
  95. )
  96. result = r.aggregate({"A": ["mean", "std"], "B": ["mean", "std"]})
  97. expected = concat([a_mean, a_std, b_mean, b_std], axis=1)
  98. exp_cols = [("A", "mean"), ("A", "std"), ("B", "mean"), ("B", "std")]
  99. expected.columns = MultiIndex.from_tuples(exp_cols)
  100. tm.assert_frame_equal(result, expected, check_like=True)
  101. @pytest.mark.parametrize(
  102. "func", [["min"], ["mean", "max"], {"b": "sum"}, {"b": "prod", "c": "median"}]
  103. )
  104. def test_multi_axis_1_raises(func):
  105. # GH#46904
  106. df = DataFrame({"a": [1, 1, 2], "b": [3, 4, 5], "c": [6, 7, 8]})
  107. msg = "Support for axis=1 in DataFrame.rolling is deprecated"
  108. with tm.assert_produces_warning(FutureWarning, match=msg):
  109. r = df.rolling(window=3, axis=1)
  110. with pytest.raises(NotImplementedError, match="axis other than 0 is not supported"):
  111. r.agg(func)
  112. def test_agg_apply(raw):
  113. # passed lambda
  114. df = DataFrame({"A": range(5), "B": range(0, 10, 2)})
  115. r = df.rolling(window=3)
  116. a_sum = r["A"].sum()
  117. with tm.assert_produces_warning(FutureWarning, match="using Rolling.[sum|std]"):
  118. result = r.agg({"A": np.sum, "B": lambda x: np.std(x, ddof=1)})
  119. rcustom = r["B"].apply(lambda x: np.std(x, ddof=1), raw=raw)
  120. expected = concat([a_sum, rcustom], axis=1)
  121. tm.assert_frame_equal(result, expected, check_like=True)
  122. def test_agg_consistency(step):
  123. df = DataFrame({"A": range(5), "B": range(0, 10, 2)})
  124. r = df.rolling(window=3, step=step)
  125. with tm.assert_produces_warning(FutureWarning, match="using Rolling.[sum|mean]"):
  126. result = r.agg([np.sum, np.mean]).columns
  127. expected = MultiIndex.from_product([list("AB"), ["sum", "mean"]])
  128. tm.assert_index_equal(result, expected)
  129. with tm.assert_produces_warning(FutureWarning, match="using Rolling.[sum|mean]"):
  130. result = r["A"].agg([np.sum, np.mean]).columns
  131. expected = Index(["sum", "mean"])
  132. tm.assert_index_equal(result, expected)
  133. with tm.assert_produces_warning(FutureWarning, match="using Rolling.[sum|mean]"):
  134. result = r.agg({"A": [np.sum, np.mean]}).columns
  135. expected = MultiIndex.from_tuples([("A", "sum"), ("A", "mean")])
  136. tm.assert_index_equal(result, expected)
  def test_agg_nested_dicts():
      """Nested renaming dicts passed to ``agg`` raise SpecificationError.

      Covers dicts keyed by new names (``{"r1": {"A": [...]}}``) and dicts
      mapping columns to renamers (``{"A": {"ra": [...]}}``). The latter is
      checked both on an explicit column selection and on the full Rolling
      object.
      """
      df = DataFrame({"A": range(5), "B": range(0, 10, 2)})
      r = df.rolling(window=3)
      msg = "nested renamer is not supported"

      with pytest.raises(SpecificationError, match=msg):
          r.aggregate({"r1": {"A": ["mean", "sum"]}, "r2": {"B": ["mean", "sum"]}})

      with pytest.raises(SpecificationError, match=msg):
          r[["A", "B"]].agg({"A": {"ra": ["mean", "std"]}, "B": {"rb": ["mean", "std"]}})

      with pytest.raises(SpecificationError, match=msg):
          r.agg({"A": {"ra": ["mean", "std"]}, "B": {"rb": ["mean", "std"]}})
  154. def test_count_nonnumeric_types(step):
  155. # GH12541
  156. cols = [
  157. "int",
  158. "float",
  159. "string",
  160. "datetime",
  161. "timedelta",
  162. "periods",
  163. "fl_inf",
  164. "fl_nan",
  165. "str_nan",
  166. "dt_nat",
  167. "periods_nat",
  168. ]
  169. dt_nat_col = [Timestamp("20170101"), Timestamp("20170203"), Timestamp(None)]
  170. df = DataFrame(
  171. {
  172. "int": [1, 2, 3],
  173. "float": [4.0, 5.0, 6.0],
  174. "string": list("abc"),
  175. "datetime": date_range("20170101", periods=3),
  176. "timedelta": timedelta_range("1 s", periods=3, freq="s"),
  177. "periods": [
  178. Period("2012-01"),
  179. Period("2012-02"),
  180. Period("2012-03"),
  181. ],
  182. "fl_inf": [1.0, 2.0, np.inf],
  183. "fl_nan": [1.0, 2.0, np.nan],
  184. "str_nan": ["aa", "bb", np.nan],
  185. "dt_nat": dt_nat_col,
  186. "periods_nat": [
  187. Period("2012-01"),
  188. Period("2012-02"),
  189. Period(None),
  190. ],
  191. },
  192. columns=cols,
  193. )
  194. expected = DataFrame(
  195. {
  196. "int": [1.0, 2.0, 2.0],
  197. "float": [1.0, 2.0, 2.0],
  198. "string": [1.0, 2.0, 2.0],
  199. "datetime": [1.0, 2.0, 2.0],
  200. "timedelta": [1.0, 2.0, 2.0],
  201. "periods": [1.0, 2.0, 2.0],
  202. "fl_inf": [1.0, 2.0, 2.0],
  203. "fl_nan": [1.0, 2.0, 1.0],
  204. "str_nan": [1.0, 2.0, 1.0],
  205. "dt_nat": [1.0, 2.0, 1.0],
  206. "periods_nat": [1.0, 2.0, 1.0],
  207. },
  208. columns=cols,
  209. )[::step]
  210. result = df.rolling(window=2, min_periods=0, step=step).count()
  211. tm.assert_frame_equal(result, expected)
  212. result = df.rolling(1, min_periods=0, step=step).count()
  213. expected = df.notna().astype(float)[::step]
  214. tm.assert_frame_equal(result, expected)
  def test_preserve_metadata():
      """Rolling reductions keep the Series name (GH 10565)."""
      ser = Series(np.arange(100), name="foo")
      for window in (30, 20):
          assert ser.rolling(window).sum().name == "foo"
  @pytest.mark.parametrize(
      "func,window_size,expected_vals",
      [
          (
              "rolling",
              2,
              [
                  [np.nan, np.nan, np.nan, np.nan],
                  [15.0, 20.0, 25.0, 20.0],
                  [25.0, 30.0, 35.0, 30.0],
                  [np.nan, np.nan, np.nan, np.nan],
                  [20.0, 30.0, 35.0, 30.0],
                  [35.0, 40.0, 60.0, 40.0],
                  [60.0, 80.0, 85.0, 80],
              ],
          ),
          (
              "expanding",
              None,
              [
                  [10.0, 10.0, 20.0, 20.0],
                  [15.0, 20.0, 25.0, 20.0],
                  [20.0, 30.0, 30.0, 20.0],
                  [10.0, 10.0, 30.0, 30.0],
                  [20.0, 30.0, 35.0, 30.0],
                  [26.666667, 40.0, 50.0, 30.0],
                  [40.0, 80.0, 60.0, 30.0],
              ],
          ),
      ],
  )
  def test_multiple_agg_funcs(func, window_size, expected_vals):
      """Groupby window ``agg`` with different function lists per column (GH 15072).

      ``func`` names the groupby window method ("rolling" or "expanding").
      ``window_size`` is passed to it only when truthy.
      """
      rows = [
          ["A", 10, 20],
          ["A", 20, 30],
          ["A", 30, 40],
          ["B", 10, 30],
          ["B", 30, 40],
          ["B", 40, 80],
          ["B", 80, 90],
      ]
      df = DataFrame(rows, columns=["stock", "low", "high"])

      make_window = getattr(df.groupby("stock"), func)
      window = make_window(window_size) if window_size else make_window()

      expected = DataFrame(
          expected_vals,
          index=MultiIndex.from_tuples(
              [("A", 0), ("A", 1), ("A", 2), ("B", 3), ("B", 4), ("B", 5), ("B", 6)],
              names=["stock", None],
          ),
          columns=MultiIndex.from_tuples(
              [("low", "mean"), ("low", "max"), ("high", "mean"), ("high", "min")]
          ),
      )
      result = window.agg({"low": ["mean", "max"], "high": ["mean", "min"]})
      tm.assert_frame_equal(result, expected)
  282. def test_dont_modify_attributes_after_methods(
  283. arithmetic_win_operators, closed, center, min_periods, step
  284. ):
  285. # GH 39554
  286. roll_obj = Series(range(1)).rolling(
  287. 1, center=center, closed=closed, min_periods=min_periods, step=step
  288. )
  289. expected = {attr: getattr(roll_obj, attr) for attr in roll_obj._attributes}
  290. getattr(roll_obj, arithmetic_win_operators)()
  291. result = {attr: getattr(roll_obj, attr) for attr in roll_obj._attributes}
  292. assert result == expected
  293. def test_centered_axis_validation(step):
  294. # ok
  295. msg = "The 'axis' keyword in Series.rolling is deprecated"
  296. with tm.assert_produces_warning(FutureWarning, match=msg):
  297. Series(np.ones(10)).rolling(window=3, center=True, axis=0, step=step).mean()
  298. # bad axis
  299. msg = "No axis named 1 for object type Series"
  300. with pytest.raises(ValueError, match=msg):
  301. Series(np.ones(10)).rolling(window=3, center=True, axis=1, step=step).mean()
  302. # ok ok
  303. df = DataFrame(np.ones((10, 10)))
  304. msg = "The 'axis' keyword in DataFrame.rolling is deprecated"
  305. with tm.assert_produces_warning(FutureWarning, match=msg):
  306. df.rolling(window=3, center=True, axis=0, step=step).mean()
  307. msg = "Support for axis=1 in DataFrame.rolling is deprecated"
  308. with tm.assert_produces_warning(FutureWarning, match=msg):
  309. df.rolling(window=3, center=True, axis=1, step=step).mean()
  310. # bad axis
  311. msg = "No axis named 2 for object type DataFrame"
  312. with pytest.raises(ValueError, match=msg):
  313. (df.rolling(window=3, center=True, axis=2, step=step).mean())
  314. def test_rolling_min_min_periods(step):
  315. a = Series([1, 2, 3, 4, 5])
  316. result = a.rolling(window=100, min_periods=1, step=step).min()
  317. expected = Series(np.ones(len(a)))[::step]
  318. tm.assert_series_equal(result, expected)
  319. msg = "min_periods 5 must be <= window 3"
  320. with pytest.raises(ValueError, match=msg):
  321. Series([1, 2, 3]).rolling(window=3, min_periods=5, step=step).min()
  322. def test_rolling_max_min_periods(step):
  323. a = Series([1, 2, 3, 4, 5], dtype=np.float64)
  324. result = a.rolling(window=100, min_periods=1, step=step).max()
  325. expected = a[::step]
  326. tm.assert_almost_equal(result, expected)
  327. msg = "min_periods 5 must be <= window 3"
  328. with pytest.raises(ValueError, match=msg):
  329. Series([1, 2, 3]).rolling(window=3, min_periods=5, step=step).max()