# test_online.py (3.6 KB)
import numpy as np
import pytest

from pandas.compat import is_platform_arm

from pandas import (
    DataFrame,
    Series,
)
import pandas._testing as tm
from pandas.util.version import Version
  10. pytestmark = [pytest.mark.single_cpu]
  11. numba = pytest.importorskip("numba")
  12. pytestmark.append(
  13. pytest.mark.skipif(
  14. Version(numba.__version__) == Version("0.61") and is_platform_arm(),
  15. reason=f"Segfaults on ARM platforms with numba {numba.__version__}",
  16. )
  17. )
  18. @pytest.mark.filterwarnings("ignore")
  19. # Filter warnings when parallel=True and the function can't be parallelized by Numba
  20. class TestEWM:
  21. def test_invalid_update(self):
  22. df = DataFrame({"a": range(5), "b": range(5)})
  23. online_ewm = df.head(2).ewm(0.5).online()
  24. with pytest.raises(
  25. ValueError,
  26. match="Must call mean with update=None first before passing update",
  27. ):
  28. online_ewm.mean(update=df.head(1))
  29. @pytest.mark.slow
  30. @pytest.mark.parametrize(
  31. "obj", [DataFrame({"a": range(5), "b": range(5)}), Series(range(5), name="foo")]
  32. )
  33. def test_online_vs_non_online_mean(
  34. self, obj, nogil, parallel, nopython, adjust, ignore_na
  35. ):
  36. expected = obj.ewm(0.5, adjust=adjust, ignore_na=ignore_na).mean()
  37. engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
  38. online_ewm = (
  39. obj.head(2)
  40. .ewm(0.5, adjust=adjust, ignore_na=ignore_na)
  41. .online(engine_kwargs=engine_kwargs)
  42. )
  43. # Test resetting once
  44. for _ in range(2):
  45. result = online_ewm.mean()
  46. tm.assert_equal(result, expected.head(2))
  47. result = online_ewm.mean(update=obj.tail(3))
  48. tm.assert_equal(result, expected.tail(3))
  49. online_ewm.reset()
  50. @pytest.mark.xfail(raises=NotImplementedError)
  51. @pytest.mark.parametrize(
  52. "obj", [DataFrame({"a": range(5), "b": range(5)}), Series(range(5), name="foo")]
  53. )
  54. def test_update_times_mean(
  55. self, obj, nogil, parallel, nopython, adjust, ignore_na, halflife_with_times
  56. ):
  57. times = Series(
  58. np.array(
  59. ["2020-01-01", "2020-01-05", "2020-01-07", "2020-01-17", "2020-01-21"],
  60. dtype="datetime64[ns]",
  61. )
  62. )
  63. expected = obj.ewm(
  64. 0.5,
  65. adjust=adjust,
  66. ignore_na=ignore_na,
  67. times=times,
  68. halflife=halflife_with_times,
  69. ).mean()
  70. engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
  71. online_ewm = (
  72. obj.head(2)
  73. .ewm(
  74. 0.5,
  75. adjust=adjust,
  76. ignore_na=ignore_na,
  77. times=times.head(2),
  78. halflife=halflife_with_times,
  79. )
  80. .online(engine_kwargs=engine_kwargs)
  81. )
  82. # Test resetting once
  83. for _ in range(2):
  84. result = online_ewm.mean()
  85. tm.assert_equal(result, expected.head(2))
  86. result = online_ewm.mean(update=obj.tail(3), update_times=times.tail(3))
  87. tm.assert_equal(result, expected.tail(3))
  88. online_ewm.reset()
  89. @pytest.mark.parametrize("method", ["aggregate", "std", "corr", "cov", "var"])
  90. def test_ewm_notimplementederror_raises(self, method):
  91. ser = Series(range(10))
  92. kwargs = {}
  93. if method == "aggregate":
  94. kwargs["func"] = lambda x: x
  95. with pytest.raises(NotImplementedError, match=".* is not implemented."):
  96. getattr(ser.ewm(1).online(), method)(**kwargs)