test_apply.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. import numpy as np
  2. import pytest
  3. from pandas import (
  4. DataFrame,
  5. Index,
  6. MultiIndex,
  7. Series,
  8. Timestamp,
  9. concat,
  10. date_range,
  11. isna,
  12. notna,
  13. )
  14. import pandas._testing as tm
  15. from pandas.tseries import offsets
# Suppress warnings about empty slices, as we are deliberately testing
# with a 0-length Series.  ``pytestmark`` is a module-level marker, so the
# filter below is attached to every test collected from this file.
pytestmark = pytest.mark.filterwarnings(
    "ignore:.*(empty slice|0 for slice).*:RuntimeWarning"
)
  21. def f(x):
  22. return x[np.isfinite(x)].mean()
  23. @pytest.mark.parametrize("bad_raw", [None, 1, 0])
  24. def test_rolling_apply_invalid_raw(bad_raw):
  25. with pytest.raises(ValueError, match="raw parameter must be `True` or `False`"):
  26. Series(range(3)).rolling(1).apply(len, raw=bad_raw)
  27. def test_rolling_apply_out_of_bounds(engine_and_raw):
  28. # gh-1850
  29. engine, raw = engine_and_raw
  30. vals = Series([1, 2, 3, 4])
  31. result = vals.rolling(10).apply(np.sum, engine=engine, raw=raw)
  32. assert result.isna().all()
  33. result = vals.rolling(10, min_periods=1).apply(np.sum, engine=engine, raw=raw)
  34. expected = Series([1, 3, 6, 10], dtype=float)
  35. tm.assert_almost_equal(result, expected)
  36. @pytest.mark.parametrize("window", [2, "2s"])
  37. def test_rolling_apply_with_pandas_objects(window):
  38. # 5071
  39. df = DataFrame(
  40. {
  41. "A": np.random.default_rng(2).standard_normal(5),
  42. "B": np.random.default_rng(2).integers(0, 10, size=5),
  43. },
  44. index=date_range("20130101", periods=5, freq="s"),
  45. )
  46. # we have an equal spaced timeseries index
  47. # so simulate removing the first period
  48. def f(x):
  49. if x.index[0] == df.index[0]:
  50. return np.nan
  51. return x.iloc[-1]
  52. result = df.rolling(window).apply(f, raw=False)
  53. expected = df.iloc[2:].reindex_like(df)
  54. tm.assert_frame_equal(result, expected)
  55. with tm.external_error_raised(AttributeError):
  56. df.rolling(window).apply(f, raw=True)
  57. def test_rolling_apply(engine_and_raw, step):
  58. engine, raw = engine_and_raw
  59. expected = Series([], dtype="float64")
  60. result = expected.rolling(10, step=step).apply(
  61. lambda x: x.mean(), engine=engine, raw=raw
  62. )
  63. tm.assert_series_equal(result, expected)
  64. # gh-8080
  65. s = Series([None, None, None])
  66. result = s.rolling(2, min_periods=0, step=step).apply(
  67. lambda x: len(x), engine=engine, raw=raw
  68. )
  69. expected = Series([1.0, 2.0, 2.0])[::step]
  70. tm.assert_series_equal(result, expected)
  71. result = s.rolling(2, min_periods=0, step=step).apply(len, engine=engine, raw=raw)
  72. tm.assert_series_equal(result, expected)
  73. def test_all_apply(engine_and_raw):
  74. engine, raw = engine_and_raw
  75. df = (
  76. DataFrame(
  77. {"A": date_range("20130101", periods=5, freq="s"), "B": range(5)}
  78. ).set_index("A")
  79. * 2
  80. )
  81. er = df.rolling(window=1)
  82. r = df.rolling(window="1s")
  83. result = r.apply(lambda x: 1, engine=engine, raw=raw)
  84. expected = er.apply(lambda x: 1, engine=engine, raw=raw)
  85. tm.assert_frame_equal(result, expected)
  86. def test_ragged_apply(engine_and_raw):
  87. engine, raw = engine_and_raw
  88. df = DataFrame({"B": range(5)})
  89. df.index = [
  90. Timestamp("20130101 09:00:00"),
  91. Timestamp("20130101 09:00:02"),
  92. Timestamp("20130101 09:00:03"),
  93. Timestamp("20130101 09:00:05"),
  94. Timestamp("20130101 09:00:06"),
  95. ]
  96. f = lambda x: 1
  97. result = df.rolling(window="1s", min_periods=1).apply(f, engine=engine, raw=raw)
  98. expected = df.copy()
  99. expected["B"] = 1.0
  100. tm.assert_frame_equal(result, expected)
  101. result = df.rolling(window="2s", min_periods=1).apply(f, engine=engine, raw=raw)
  102. expected = df.copy()
  103. expected["B"] = 1.0
  104. tm.assert_frame_equal(result, expected)
  105. result = df.rolling(window="5s", min_periods=1).apply(f, engine=engine, raw=raw)
  106. expected = df.copy()
  107. expected["B"] = 1.0
  108. tm.assert_frame_equal(result, expected)
  109. def test_invalid_engine():
  110. with pytest.raises(ValueError, match="engine must be either 'numba' or 'cython'"):
  111. Series(range(1)).rolling(1).apply(lambda x: x, engine="foo")
  112. def test_invalid_engine_kwargs_cython():
  113. with pytest.raises(ValueError, match="cython engine does not accept engine_kwargs"):
  114. Series(range(1)).rolling(1).apply(
  115. lambda x: x, engine="cython", engine_kwargs={"nopython": False}
  116. )
  117. def test_invalid_raw_numba():
  118. with pytest.raises(
  119. ValueError, match="raw must be `True` when using the numba engine"
  120. ):
  121. Series(range(1)).rolling(1).apply(lambda x: x, raw=False, engine="numba")
  122. @pytest.mark.parametrize("args_kwargs", [[None, {"par": 10}], [(10,), None]])
  123. def test_rolling_apply_args_kwargs(args_kwargs):
  124. # GH 33433
  125. def numpysum(x, par):
  126. return np.sum(x + par)
  127. df = DataFrame({"gr": [1, 1], "a": [1, 2]})
  128. idx = Index(["gr", "a"])
  129. expected = DataFrame([[11.0, 11.0], [11.0, 12.0]], columns=idx)
  130. result = df.rolling(1).apply(numpysum, args=args_kwargs[0], kwargs=args_kwargs[1])
  131. tm.assert_frame_equal(result, expected)
  132. midx = MultiIndex.from_tuples([(1, 0), (1, 1)], names=["gr", None])
  133. expected = Series([11.0, 12.0], index=midx, name="a")
  134. gb_rolling = df.groupby("gr")["a"].rolling(1)
  135. result = gb_rolling.apply(numpysum, args=args_kwargs[0], kwargs=args_kwargs[1])
  136. tm.assert_series_equal(result, expected)
  137. def test_nans(raw):
  138. obj = Series(np.random.default_rng(2).standard_normal(50))
  139. obj[:10] = np.nan
  140. obj[-10:] = np.nan
  141. result = obj.rolling(50, min_periods=30).apply(f, raw=raw)
  142. tm.assert_almost_equal(result.iloc[-1], np.mean(obj[10:-10]))
  143. # min_periods is working correctly
  144. result = obj.rolling(20, min_periods=15).apply(f, raw=raw)
  145. assert isna(result.iloc[23])
  146. assert not isna(result.iloc[24])
  147. assert not isna(result.iloc[-6])
  148. assert isna(result.iloc[-5])
  149. obj2 = Series(np.random.default_rng(2).standard_normal(20))
  150. result = obj2.rolling(10, min_periods=5).apply(f, raw=raw)
  151. assert isna(result.iloc[3])
  152. assert notna(result.iloc[4])
  153. result0 = obj.rolling(20, min_periods=0).apply(f, raw=raw)
  154. result1 = obj.rolling(20, min_periods=1).apply(f, raw=raw)
  155. tm.assert_almost_equal(result0, result1)
  156. def test_center(raw):
  157. obj = Series(np.random.default_rng(2).standard_normal(50))
  158. obj[:10] = np.nan
  159. obj[-10:] = np.nan
  160. result = obj.rolling(20, min_periods=15, center=True).apply(f, raw=raw)
  161. expected = (
  162. concat([obj, Series([np.nan] * 9)])
  163. .rolling(20, min_periods=15)
  164. .apply(f, raw=raw)
  165. .iloc[9:]
  166. .reset_index(drop=True)
  167. )
  168. tm.assert_series_equal(result, expected)
  169. def test_series(raw, series):
  170. result = series.rolling(50).apply(f, raw=raw)
  171. assert isinstance(result, Series)
  172. tm.assert_almost_equal(result.iloc[-1], np.mean(series[-50:]))
  173. def test_frame(raw, frame):
  174. result = frame.rolling(50).apply(f, raw=raw)
  175. assert isinstance(result, DataFrame)
  176. tm.assert_series_equal(
  177. result.iloc[-1, :],
  178. frame.iloc[-50:, :].apply(np.mean, axis=0, raw=raw),
  179. check_names=False,
  180. )
  181. def test_time_rule_series(raw, series):
  182. win = 25
  183. minp = 10
  184. ser = series[::2].resample("B").mean()
  185. series_result = ser.rolling(window=win, min_periods=minp).apply(f, raw=raw)
  186. last_date = series_result.index[-1]
  187. prev_date = last_date - 24 * offsets.BDay()
  188. trunc_series = series[::2].truncate(prev_date, last_date)
  189. tm.assert_almost_equal(series_result.iloc[-1], np.mean(trunc_series))
  190. def test_time_rule_frame(raw, frame):
  191. win = 25
  192. minp = 10
  193. frm = frame[::2].resample("B").mean()
  194. frame_result = frm.rolling(window=win, min_periods=minp).apply(f, raw=raw)
  195. last_date = frame_result.index[-1]
  196. prev_date = last_date - 24 * offsets.BDay()
  197. trunc_frame = frame[::2].truncate(prev_date, last_date)
  198. tm.assert_series_equal(
  199. frame_result.xs(last_date),
  200. trunc_frame.apply(np.mean, raw=raw),
  201. check_names=False,
  202. )
  203. @pytest.mark.parametrize("minp", [0, 99, 100])
  204. def test_min_periods(raw, series, minp, step):
  205. result = series.rolling(len(series) + 1, min_periods=minp, step=step).apply(
  206. f, raw=raw
  207. )
  208. expected = series.rolling(len(series), min_periods=minp, step=step).apply(
  209. f, raw=raw
  210. )
  211. nan_mask = isna(result)
  212. tm.assert_series_equal(nan_mask, isna(expected))
  213. nan_mask = ~nan_mask
  214. tm.assert_almost_equal(result[nan_mask], expected[nan_mask])
  215. def test_center_reindex_series(raw, series):
  216. # shifter index
  217. s = [f"x{x:d}" for x in range(12)]
  218. minp = 10
  219. series_xp = (
  220. series.reindex(list(series.index) + s)
  221. .rolling(window=25, min_periods=minp)
  222. .apply(f, raw=raw)
  223. .shift(-12)
  224. .reindex(series.index)
  225. )
  226. series_rs = series.rolling(window=25, min_periods=minp, center=True).apply(
  227. f, raw=raw
  228. )
  229. tm.assert_series_equal(series_xp, series_rs)
  230. def test_center_reindex_frame(raw):
  231. # shifter index
  232. frame = DataFrame(range(100), index=date_range("2020-01-01", freq="D", periods=100))
  233. s = [f"x{x:d}" for x in range(12)]
  234. minp = 10
  235. frame_xp = (
  236. frame.reindex(list(frame.index) + s)
  237. .rolling(window=25, min_periods=minp)
  238. .apply(f, raw=raw)
  239. .shift(-12)
  240. .reindex(frame.index)
  241. )
  242. frame_rs = frame.rolling(window=25, min_periods=minp, center=True).apply(f, raw=raw)
  243. tm.assert_frame_equal(frame_xp, frame_rs)
  244. def test_axis1(raw):
  245. # GH 45912
  246. df = DataFrame([1, 2])
  247. msg = "Support for axis=1 in DataFrame.rolling is deprecated"
  248. with tm.assert_produces_warning(FutureWarning, match=msg):
  249. result = df.rolling(window=1, axis=1).apply(np.sum, raw=raw)
  250. expected = DataFrame([1.0, 2.0])
  251. tm.assert_frame_equal(result, expected)