# test_ewm.py
import numpy as np
import pytest

from pandas import (
    DataFrame,
    DatetimeIndex,
    Series,
    date_range,
)
import pandas._testing as tm
  10. def test_doc_string():
  11. df = DataFrame({"B": [0, 1, 2, np.nan, 4]})
  12. df
  13. df.ewm(com=0.5).mean()
  14. def test_constructor(frame_or_series):
  15. c = frame_or_series(range(5)).ewm
  16. # valid
  17. c(com=0.5)
  18. c(span=1.5)
  19. c(alpha=0.5)
  20. c(halflife=0.75)
  21. c(com=0.5, span=None)
  22. c(alpha=0.5, com=None)
  23. c(halflife=0.75, alpha=None)
  24. # not valid: mutually exclusive
  25. msg = "comass, span, halflife, and alpha are mutually exclusive"
  26. with pytest.raises(ValueError, match=msg):
  27. c(com=0.5, alpha=0.5)
  28. with pytest.raises(ValueError, match=msg):
  29. c(span=1.5, halflife=0.75)
  30. with pytest.raises(ValueError, match=msg):
  31. c(alpha=0.5, span=1.5)
  32. # not valid: com < 0
  33. msg = "comass must satisfy: comass >= 0"
  34. with pytest.raises(ValueError, match=msg):
  35. c(com=-0.5)
  36. # not valid: span < 1
  37. msg = "span must satisfy: span >= 1"
  38. with pytest.raises(ValueError, match=msg):
  39. c(span=0.5)
  40. # not valid: halflife <= 0
  41. msg = "halflife must satisfy: halflife > 0"
  42. with pytest.raises(ValueError, match=msg):
  43. c(halflife=0)
  44. # not valid: alpha <= 0 or alpha > 1
  45. msg = "alpha must satisfy: 0 < alpha <= 1"
  46. for alpha in (-0.5, 1.5):
  47. with pytest.raises(ValueError, match=msg):
  48. c(alpha=alpha)
  49. def test_ewma_times_not_datetime_type():
  50. msg = r"times must be datetime64 dtype."
  51. with pytest.raises(ValueError, match=msg):
  52. Series(range(5)).ewm(times=np.arange(5))
  53. def test_ewma_times_not_same_length():
  54. msg = "times must be the same length as the object."
  55. with pytest.raises(ValueError, match=msg):
  56. Series(range(5)).ewm(times=np.arange(4).astype("datetime64[ns]"))
  57. def test_ewma_halflife_not_correct_type():
  58. msg = "halflife must be a timedelta convertible object"
  59. with pytest.raises(ValueError, match=msg):
  60. Series(range(5)).ewm(halflife=1, times=np.arange(5).astype("datetime64[ns]"))
  61. def test_ewma_halflife_without_times(halflife_with_times):
  62. msg = "halflife can only be a timedelta convertible argument if times is not None."
  63. with pytest.raises(ValueError, match=msg):
  64. Series(range(5)).ewm(halflife=halflife_with_times)
  65. @pytest.mark.parametrize(
  66. "times",
  67. [
  68. np.arange(10).astype("datetime64[D]").astype("datetime64[ns]"),
  69. date_range("2000", freq="D", periods=10),
  70. date_range("2000", freq="D", periods=10).tz_localize("UTC"),
  71. ],
  72. )
  73. @pytest.mark.parametrize("min_periods", [0, 2])
  74. def test_ewma_with_times_equal_spacing(halflife_with_times, times, min_periods):
  75. halflife = halflife_with_times
  76. data = np.arange(10.0)
  77. data[::2] = np.nan
  78. df = DataFrame({"A": data})
  79. result = df.ewm(halflife=halflife, min_periods=min_periods, times=times).mean()
  80. expected = df.ewm(halflife=1.0, min_periods=min_periods).mean()
  81. tm.assert_frame_equal(result, expected)
  82. def test_ewma_with_times_variable_spacing(tz_aware_fixture, unit):
  83. tz = tz_aware_fixture
  84. halflife = "23 days"
  85. times = (
  86. DatetimeIndex(["2020-01-01", "2020-01-10T00:04:05", "2020-02-23T05:00:23"])
  87. .tz_localize(tz)
  88. .as_unit(unit)
  89. )
  90. data = np.arange(3)
  91. df = DataFrame(data)
  92. result = df.ewm(halflife=halflife, times=times).mean()
  93. expected = DataFrame([0.0, 0.5674161888241773, 1.545239952073459])
  94. tm.assert_frame_equal(result, expected)
  95. def test_ewm_with_nat_raises(halflife_with_times):
  96. # GH#38535
  97. ser = Series(range(1))
  98. times = DatetimeIndex(["NaT"])
  99. with pytest.raises(ValueError, match="Cannot convert NaT values to integer"):
  100. ser.ewm(com=0.1, halflife=halflife_with_times, times=times)
  101. def test_ewm_with_times_getitem(halflife_with_times):
  102. # GH 40164
  103. halflife = halflife_with_times
  104. data = np.arange(10.0)
  105. data[::2] = np.nan
  106. times = date_range("2000", freq="D", periods=10)
  107. df = DataFrame({"A": data, "B": data})
  108. result = df.ewm(halflife=halflife, times=times)["A"].mean()
  109. expected = df.ewm(halflife=1.0)["A"].mean()
  110. tm.assert_series_equal(result, expected)
  111. @pytest.mark.parametrize("arg", ["com", "halflife", "span", "alpha"])
  112. def test_ewm_getitem_attributes_retained(arg, adjust, ignore_na):
  113. # GH 40164
  114. kwargs = {arg: 1, "adjust": adjust, "ignore_na": ignore_na}
  115. ewm = DataFrame({"A": range(1), "B": range(1)}).ewm(**kwargs)
  116. expected = {attr: getattr(ewm, attr) for attr in ewm._attributes}
  117. ewm_slice = ewm["A"]
  118. result = {attr: getattr(ewm, attr) for attr in ewm_slice._attributes}
  119. assert result == expected
  120. def test_ewma_times_adjust_false_raises():
  121. # GH 40098
  122. with pytest.raises(
  123. NotImplementedError, match="times is not supported with adjust=False."
  124. ):
  125. Series(range(1)).ewm(
  126. 0.1, adjust=False, times=date_range("2000", freq="D", periods=1)
  127. )
  128. @pytest.mark.parametrize(
  129. "func, expected",
  130. [
  131. [
  132. "mean",
  133. DataFrame(
  134. {
  135. 0: range(5),
  136. 1: range(4, 9),
  137. 2: [7.428571, 9, 10.571429, 12.142857, 13.714286],
  138. },
  139. dtype=float,
  140. ),
  141. ],
  142. [
  143. "std",
  144. DataFrame(
  145. {
  146. 0: [np.nan] * 5,
  147. 1: [4.242641] * 5,
  148. 2: [4.6291, 5.196152, 5.781745, 6.380775, 6.989788],
  149. }
  150. ),
  151. ],
  152. [
  153. "var",
  154. DataFrame(
  155. {
  156. 0: [np.nan] * 5,
  157. 1: [18.0] * 5,
  158. 2: [21.428571, 27, 33.428571, 40.714286, 48.857143],
  159. }
  160. ),
  161. ],
  162. ],
  163. )
  164. def test_float_dtype_ewma(func, expected, float_numpy_dtype):
  165. # GH#42452
  166. df = DataFrame(
  167. {0: range(5), 1: range(6, 11), 2: range(10, 20, 2)}, dtype=float_numpy_dtype
  168. )
  169. msg = "Support for axis=1 in DataFrame.ewm is deprecated"
  170. with tm.assert_produces_warning(FutureWarning, match=msg):
  171. e = df.ewm(alpha=0.5, axis=1)
  172. result = getattr(e, func)()
  173. tm.assert_frame_equal(result, expected)
  174. def test_times_string_col_raises():
  175. # GH 43265
  176. df = DataFrame(
  177. {"A": np.arange(10.0), "time_col": date_range("2000", freq="D", periods=10)}
  178. )
  179. with pytest.raises(ValueError, match="times must be datetime64"):
  180. df.ewm(halflife="1 day", min_periods=0, times="time_col")
  181. def test_ewm_sum_adjust_false_notimplemented():
  182. data = Series(range(1)).ewm(com=1, adjust=False)
  183. with pytest.raises(NotImplementedError, match="sum is not"):
  184. data.sum()
  185. @pytest.mark.parametrize(
  186. "expected_data, ignore",
  187. [[[10.0, 5.0, 2.5, 11.25], False], [[10.0, 5.0, 5.0, 12.5], True]],
  188. )
  189. def test_ewm_sum(expected_data, ignore):
  190. # xref from Numbagg tests
  191. # https://github.com/numbagg/numbagg/blob/v0.2.1/numbagg/test/test_moving.py#L50
  192. data = Series([10, 0, np.nan, 10])
  193. result = data.ewm(alpha=0.5, ignore_na=ignore).sum()
  194. expected = Series(expected_data)
  195. tm.assert_series_equal(result, expected)
  196. def test_ewma_adjust():
  197. vals = Series(np.zeros(1000))
  198. vals[5] = 1
  199. result = vals.ewm(span=100, adjust=False).mean().sum()
  200. assert np.abs(result - 1) < 1e-2
  201. def test_ewma_cases(adjust, ignore_na):
  202. # try adjust/ignore_na args matrix
  203. s = Series([1.0, 2.0, 4.0, 8.0])
  204. if adjust:
  205. expected = Series([1.0, 1.6, 2.736842, 4.923077])
  206. else:
  207. expected = Series([1.0, 1.333333, 2.222222, 4.148148])
  208. result = s.ewm(com=2.0, adjust=adjust, ignore_na=ignore_na).mean()
  209. tm.assert_series_equal(result, expected)
  210. def test_ewma_nan_handling():
  211. s = Series([1.0] + [np.nan] * 5 + [1.0])
  212. result = s.ewm(com=5).mean()
  213. tm.assert_series_equal(result, Series([1.0] * len(s)))
  214. s = Series([np.nan] * 2 + [1.0] + [np.nan] * 2 + [1.0])
  215. result = s.ewm(com=5).mean()
  216. tm.assert_series_equal(result, Series([np.nan] * 2 + [1.0] * 4))
  217. @pytest.mark.parametrize(
  218. "s, adjust, ignore_na, w",
  219. [
  220. (
  221. Series([np.nan, 1.0, 101.0]),
  222. True,
  223. False,
  224. [np.nan, (1.0 - (1.0 / (1.0 + 2.0))), 1.0],
  225. ),
  226. (
  227. Series([np.nan, 1.0, 101.0]),
  228. True,
  229. True,
  230. [np.nan, (1.0 - (1.0 / (1.0 + 2.0))), 1.0],
  231. ),
  232. (
  233. Series([np.nan, 1.0, 101.0]),
  234. False,
  235. False,
  236. [np.nan, (1.0 - (1.0 / (1.0 + 2.0))), (1.0 / (1.0 + 2.0))],
  237. ),
  238. (
  239. Series([np.nan, 1.0, 101.0]),
  240. False,
  241. True,
  242. [np.nan, (1.0 - (1.0 / (1.0 + 2.0))), (1.0 / (1.0 + 2.0))],
  243. ),
  244. (
  245. Series([1.0, np.nan, 101.0]),
  246. True,
  247. False,
  248. [(1.0 - (1.0 / (1.0 + 2.0))) ** 2, np.nan, 1.0],
  249. ),
  250. (
  251. Series([1.0, np.nan, 101.0]),
  252. True,
  253. True,
  254. [(1.0 - (1.0 / (1.0 + 2.0))), np.nan, 1.0],
  255. ),
  256. (
  257. Series([1.0, np.nan, 101.0]),
  258. False,
  259. False,
  260. [(1.0 - (1.0 / (1.0 + 2.0))) ** 2, np.nan, (1.0 / (1.0 + 2.0))],
  261. ),
  262. (
  263. Series([1.0, np.nan, 101.0]),
  264. False,
  265. True,
  266. [(1.0 - (1.0 / (1.0 + 2.0))), np.nan, (1.0 / (1.0 + 2.0))],
  267. ),
  268. (
  269. Series([np.nan, 1.0, np.nan, np.nan, 101.0, np.nan]),
  270. True,
  271. False,
  272. [np.nan, (1.0 - (1.0 / (1.0 + 2.0))) ** 3, np.nan, np.nan, 1.0, np.nan],
  273. ),
  274. (
  275. Series([np.nan, 1.0, np.nan, np.nan, 101.0, np.nan]),
  276. True,
  277. True,
  278. [np.nan, (1.0 - (1.0 / (1.0 + 2.0))), np.nan, np.nan, 1.0, np.nan],
  279. ),
  280. (
  281. Series([np.nan, 1.0, np.nan, np.nan, 101.0, np.nan]),
  282. False,
  283. False,
  284. [
  285. np.nan,
  286. (1.0 - (1.0 / (1.0 + 2.0))) ** 3,
  287. np.nan,
  288. np.nan,
  289. (1.0 / (1.0 + 2.0)),
  290. np.nan,
  291. ],
  292. ),
  293. (
  294. Series([np.nan, 1.0, np.nan, np.nan, 101.0, np.nan]),
  295. False,
  296. True,
  297. [
  298. np.nan,
  299. (1.0 - (1.0 / (1.0 + 2.0))),
  300. np.nan,
  301. np.nan,
  302. (1.0 / (1.0 + 2.0)),
  303. np.nan,
  304. ],
  305. ),
  306. (
  307. Series([1.0, np.nan, 101.0, 50.0]),
  308. True,
  309. False,
  310. [
  311. (1.0 - (1.0 / (1.0 + 2.0))) ** 3,
  312. np.nan,
  313. (1.0 - (1.0 / (1.0 + 2.0))),
  314. 1.0,
  315. ],
  316. ),
  317. (
  318. Series([1.0, np.nan, 101.0, 50.0]),
  319. True,
  320. True,
  321. [
  322. (1.0 - (1.0 / (1.0 + 2.0))) ** 2,
  323. np.nan,
  324. (1.0 - (1.0 / (1.0 + 2.0))),
  325. 1.0,
  326. ],
  327. ),
  328. (
  329. Series([1.0, np.nan, 101.0, 50.0]),
  330. False,
  331. False,
  332. [
  333. (1.0 - (1.0 / (1.0 + 2.0))) ** 3,
  334. np.nan,
  335. (1.0 - (1.0 / (1.0 + 2.0))) * (1.0 / (1.0 + 2.0)),
  336. (1.0 / (1.0 + 2.0))
  337. * ((1.0 - (1.0 / (1.0 + 2.0))) ** 2 + (1.0 / (1.0 + 2.0))),
  338. ],
  339. ),
  340. (
  341. Series([1.0, np.nan, 101.0, 50.0]),
  342. False,
  343. True,
  344. [
  345. (1.0 - (1.0 / (1.0 + 2.0))) ** 2,
  346. np.nan,
  347. (1.0 - (1.0 / (1.0 + 2.0))) * (1.0 / (1.0 + 2.0)),
  348. (1.0 / (1.0 + 2.0)),
  349. ],
  350. ),
  351. ],
  352. )
  353. def test_ewma_nan_handling_cases(s, adjust, ignore_na, w):
  354. # GH 7603
  355. expected = (s.multiply(w).cumsum() / Series(w).cumsum()).ffill()
  356. result = s.ewm(com=2.0, adjust=adjust, ignore_na=ignore_na).mean()
  357. tm.assert_series_equal(result, expected)
  358. if ignore_na is False:
  359. # check that ignore_na defaults to False
  360. result = s.ewm(com=2.0, adjust=adjust).mean()
  361. tm.assert_series_equal(result, expected)
  362. def test_ewm_alpha():
  363. # GH 10789
  364. arr = np.random.default_rng(2).standard_normal(100)
  365. locs = np.arange(20, 40)
  366. arr[locs] = np.nan
  367. s = Series(arr)
  368. a = s.ewm(alpha=0.61722699889169674).mean()
  369. b = s.ewm(com=0.62014947789973052).mean()
  370. c = s.ewm(span=2.240298955799461).mean()
  371. d = s.ewm(halflife=0.721792864318).mean()
  372. tm.assert_series_equal(a, b)
  373. tm.assert_series_equal(a, c)
  374. tm.assert_series_equal(a, d)
  375. def test_ewm_domain_checks():
  376. # GH 12492
  377. arr = np.random.default_rng(2).standard_normal(100)
  378. locs = np.arange(20, 40)
  379. arr[locs] = np.nan
  380. s = Series(arr)
  381. msg = "comass must satisfy: comass >= 0"
  382. with pytest.raises(ValueError, match=msg):
  383. s.ewm(com=-0.1)
  384. s.ewm(com=0.0)
  385. s.ewm(com=0.1)
  386. msg = "span must satisfy: span >= 1"
  387. with pytest.raises(ValueError, match=msg):
  388. s.ewm(span=-0.1)
  389. with pytest.raises(ValueError, match=msg):
  390. s.ewm(span=0.0)
  391. with pytest.raises(ValueError, match=msg):
  392. s.ewm(span=0.9)
  393. s.ewm(span=1.0)
  394. s.ewm(span=1.1)
  395. msg = "halflife must satisfy: halflife > 0"
  396. with pytest.raises(ValueError, match=msg):
  397. s.ewm(halflife=-0.1)
  398. with pytest.raises(ValueError, match=msg):
  399. s.ewm(halflife=0.0)
  400. s.ewm(halflife=0.1)
  401. msg = "alpha must satisfy: 0 < alpha <= 1"
  402. with pytest.raises(ValueError, match=msg):
  403. s.ewm(alpha=-0.1)
  404. with pytest.raises(ValueError, match=msg):
  405. s.ewm(alpha=0.0)
  406. s.ewm(alpha=0.1)
  407. s.ewm(alpha=1.0)
  408. with pytest.raises(ValueError, match=msg):
  409. s.ewm(alpha=1.1)
  410. @pytest.mark.parametrize("method", ["mean", "std", "var"])
  411. def test_ew_empty_series(method):
  412. vals = Series([], dtype=np.float64)
  413. ewm = vals.ewm(3)
  414. result = getattr(ewm, method)()
  415. tm.assert_almost_equal(result, vals)
  416. @pytest.mark.parametrize("min_periods", [0, 1])
  417. @pytest.mark.parametrize("name", ["mean", "var", "std"])
  418. def test_ew_min_periods(min_periods, name):
  419. # excluding NaNs correctly
  420. arr = np.random.default_rng(2).standard_normal(50)
  421. arr[:10] = np.nan
  422. arr[-10:] = np.nan
  423. s = Series(arr)
  424. # check min_periods
  425. # GH 7898
  426. result = getattr(s.ewm(com=50, min_periods=2), name)()
  427. assert result[:11].isna().all()
  428. assert not result[11:].isna().any()
  429. result = getattr(s.ewm(com=50, min_periods=min_periods), name)()
  430. if name == "mean":
  431. assert result[:10].isna().all()
  432. assert not result[10:].isna().any()
  433. else:
  434. # ewm.std, ewm.var (with bias=False) require at least
  435. # two values
  436. assert result[:11].isna().all()
  437. assert not result[11:].isna().any()
  438. # check series of length 0
  439. result = getattr(Series(dtype=object).ewm(com=50, min_periods=min_periods), name)()
  440. tm.assert_series_equal(result, Series(dtype="float64"))
  441. # check series of length 1
  442. result = getattr(Series([1.0]).ewm(50, min_periods=min_periods), name)()
  443. if name == "mean":
  444. tm.assert_series_equal(result, Series([1.0]))
  445. else:
  446. # ewm.std, ewm.var with bias=False require at least
  447. # two values
  448. tm.assert_series_equal(result, Series([np.nan]))
  449. # pass in ints
  450. result2 = getattr(Series(np.arange(50)).ewm(span=10), name)()
  451. assert result2.dtype == np.float64
  452. @pytest.mark.parametrize("name", ["cov", "corr"])
  453. def test_ewm_corr_cov(name):
  454. A = Series(np.random.default_rng(2).standard_normal(50), index=range(50))
  455. B = A[2:] + np.random.default_rng(2).standard_normal(48)
  456. A[:10] = np.nan
  457. B.iloc[-10:] = np.nan
  458. result = getattr(A.ewm(com=20, min_periods=5), name)(B)
  459. assert np.isnan(result.values[:14]).all()
  460. assert not np.isnan(result.values[14:]).any()
  461. @pytest.mark.parametrize("min_periods", [0, 1, 2])
  462. @pytest.mark.parametrize("name", ["cov", "corr"])
  463. def test_ewm_corr_cov_min_periods(name, min_periods):
  464. # GH 7898
  465. A = Series(np.random.default_rng(2).standard_normal(50), index=range(50))
  466. B = A[2:] + np.random.default_rng(2).standard_normal(48)
  467. A[:10] = np.nan
  468. B.iloc[-10:] = np.nan
  469. result = getattr(A.ewm(com=20, min_periods=min_periods), name)(B)
  470. # binary functions (ewmcov, ewmcorr) with bias=False require at
  471. # least two values
  472. assert np.isnan(result.values[:11]).all()
  473. assert not np.isnan(result.values[11:]).any()
  474. # check series of length 0
  475. empty = Series([], dtype=np.float64)
  476. result = getattr(empty.ewm(com=50, min_periods=min_periods), name)(empty)
  477. tm.assert_series_equal(result, empty)
  478. # check series of length 1
  479. result = getattr(Series([1.0]).ewm(com=50, min_periods=min_periods), name)(
  480. Series([1.0])
  481. )
  482. tm.assert_series_equal(result, Series([np.nan]))
  483. @pytest.mark.parametrize("name", ["cov", "corr"])
  484. def test_different_input_array_raise_exception(name):
  485. A = Series(np.random.default_rng(2).standard_normal(50), index=range(50))
  486. A[:10] = np.nan
  487. msg = "other must be a DataFrame or Series"
  488. # exception raised is Exception
  489. with pytest.raises(ValueError, match=msg):
  490. getattr(A.ewm(com=20, min_periods=5), name)(
  491. np.random.default_rng(2).standard_normal(50)
  492. )
  493. @pytest.mark.parametrize("name", ["var", "std", "mean"])
  494. def test_ewma_series(series, name):
  495. series_result = getattr(series.ewm(com=10), name)()
  496. assert isinstance(series_result, Series)
  497. @pytest.mark.parametrize("name", ["var", "std", "mean"])
  498. def test_ewma_frame(frame, name):
  499. frame_result = getattr(frame.ewm(com=10), name)()
  500. assert isinstance(frame_result, DataFrame)
  501. def test_ewma_span_com_args(series):
  502. A = series.ewm(com=9.5).mean()
  503. B = series.ewm(span=20).mean()
  504. tm.assert_almost_equal(A, B)
  505. msg = "comass, span, halflife, and alpha are mutually exclusive"
  506. with pytest.raises(ValueError, match=msg):
  507. series.ewm(com=9.5, span=20)
  508. msg = "Must pass one of comass, span, halflife, or alpha"
  509. with pytest.raises(ValueError, match=msg):
  510. series.ewm().mean()
  511. def test_ewma_halflife_arg(series):
  512. A = series.ewm(com=13.932726172912965).mean()
  513. B = series.ewm(halflife=10.0).mean()
  514. tm.assert_almost_equal(A, B)
  515. msg = "comass, span, halflife, and alpha are mutually exclusive"
  516. with pytest.raises(ValueError, match=msg):
  517. series.ewm(span=20, halflife=50)
  518. with pytest.raises(ValueError, match=msg):
  519. series.ewm(com=9.5, halflife=50)
  520. with pytest.raises(ValueError, match=msg):
  521. series.ewm(com=9.5, span=20, halflife=50)
  522. msg = "Must pass one of comass, span, halflife, or alpha"
  523. with pytest.raises(ValueError, match=msg):
  524. series.ewm()
  525. def test_ewm_alpha_arg(series):
  526. # GH 10789
  527. s = series
  528. msg = "Must pass one of comass, span, halflife, or alpha"
  529. with pytest.raises(ValueError, match=msg):
  530. s.ewm()
  531. msg = "comass, span, halflife, and alpha are mutually exclusive"
  532. with pytest.raises(ValueError, match=msg):
  533. s.ewm(com=10.0, alpha=0.5)
  534. with pytest.raises(ValueError, match=msg):
  535. s.ewm(span=10.0, alpha=0.5)
  536. with pytest.raises(ValueError, match=msg):
  537. s.ewm(halflife=10.0, alpha=0.5)
  538. @pytest.mark.parametrize("func", ["cov", "corr"])
  539. def test_ewm_pairwise_cov_corr(func, frame):
  540. result = getattr(frame.ewm(span=10, min_periods=5), func)()
  541. result = result.loc[(slice(None), 1), 5]
  542. result.index = result.index.droplevel(1)
  543. expected = getattr(frame[1].ewm(span=10, min_periods=5), func)(frame[5])
  544. tm.assert_series_equal(result, expected, check_names=False)
  545. def test_numeric_only_frame(arithmetic_win_operators, numeric_only):
  546. # GH#46560
  547. kernel = arithmetic_win_operators
  548. df = DataFrame({"a": [1], "b": 2, "c": 3})
  549. df["c"] = df["c"].astype(object)
  550. ewm = df.ewm(span=2, min_periods=1)
  551. op = getattr(ewm, kernel, None)
  552. if op is not None:
  553. result = op(numeric_only=numeric_only)
  554. columns = ["a", "b"] if numeric_only else ["a", "b", "c"]
  555. expected = df[columns].agg([kernel]).reset_index(drop=True).astype(float)
  556. assert list(expected.columns) == columns
  557. tm.assert_frame_equal(result, expected)
  558. @pytest.mark.parametrize("kernel", ["corr", "cov"])
  559. @pytest.mark.parametrize("use_arg", [True, False])
  560. def test_numeric_only_corr_cov_frame(kernel, numeric_only, use_arg):
  561. # GH#46560
  562. df = DataFrame({"a": [1, 2, 3], "b": 2, "c": 3})
  563. df["c"] = df["c"].astype(object)
  564. arg = (df,) if use_arg else ()
  565. ewm = df.ewm(span=2, min_periods=1)
  566. op = getattr(ewm, kernel)
  567. result = op(*arg, numeric_only=numeric_only)
  568. # Compare result to op using float dtypes, dropping c when numeric_only is True
  569. columns = ["a", "b"] if numeric_only else ["a", "b", "c"]
  570. df2 = df[columns].astype(float)
  571. arg2 = (df2,) if use_arg else ()
  572. ewm2 = df2.ewm(span=2, min_periods=1)
  573. op2 = getattr(ewm2, kernel)
  574. expected = op2(*arg2, numeric_only=numeric_only)
  575. tm.assert_frame_equal(result, expected)
  576. @pytest.mark.parametrize("dtype", [int, object])
  577. def test_numeric_only_series(arithmetic_win_operators, numeric_only, dtype):
  578. # GH#46560
  579. kernel = arithmetic_win_operators
  580. ser = Series([1], dtype=dtype)
  581. ewm = ser.ewm(span=2, min_periods=1)
  582. op = getattr(ewm, kernel, None)
  583. if op is None:
  584. # Nothing to test
  585. pytest.skip("No op to test")
  586. if numeric_only and dtype is object:
  587. msg = f"ExponentialMovingWindow.{kernel} does not implement numeric_only"
  588. with pytest.raises(NotImplementedError, match=msg):
  589. op(numeric_only=numeric_only)
  590. else:
  591. result = op(numeric_only=numeric_only)
  592. expected = ser.agg([kernel]).reset_index(drop=True).astype(float)
  593. tm.assert_series_equal(result, expected)
  594. @pytest.mark.parametrize("kernel", ["corr", "cov"])
  595. @pytest.mark.parametrize("use_arg", [True, False])
  596. @pytest.mark.parametrize("dtype", [int, object])
  597. def test_numeric_only_corr_cov_series(kernel, use_arg, numeric_only, dtype):
  598. # GH#46560
  599. ser = Series([1, 2, 3], dtype=dtype)
  600. arg = (ser,) if use_arg else ()
  601. ewm = ser.ewm(span=2, min_periods=1)
  602. op = getattr(ewm, kernel)
  603. if numeric_only and dtype is object:
  604. msg = f"ExponentialMovingWindow.{kernel} does not implement numeric_only"
  605. with pytest.raises(NotImplementedError, match=msg):
  606. op(*arg, numeric_only=numeric_only)
  607. else:
  608. result = op(*arg, numeric_only=numeric_only)
  609. ser2 = ser.astype(float)
  610. arg2 = (ser2,) if use_arg else ()
  611. ewm2 = ser2.ewm(span=2, min_periods=1)
  612. op2 = getattr(ewm2, kernel)
  613. expected = op2(*arg2, numeric_only=numeric_only)
  614. tm.assert_series_equal(result, expected)