test_numba.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465
  1. import numpy as np
  2. import pytest
  3. from pandas.compat import is_platform_arm
  4. from pandas.errors import NumbaUtilError
  5. import pandas.util._test_decorators as td
  6. from pandas import (
  7. DataFrame,
  8. Series,
  9. option_context,
  10. to_datetime,
  11. )
  12. import pandas._testing as tm
  13. from pandas.util.version import Version
  14. pytestmark = [pytest.mark.single_cpu]
  15. numba = pytest.importorskip("numba")
  16. pytestmark.append(
  17. pytest.mark.skipif(
  18. Version(numba.__version__) == Version("0.61") and is_platform_arm(),
  19. reason=f"Segfaults on ARM platforms with numba {numba.__version__}",
  20. )
  21. )
  22. @pytest.fixture(params=["single", "table"])
  23. def method(request):
  24. """method keyword in rolling/expanding/ewm constructor"""
  25. return request.param
  26. @pytest.fixture(
  27. params=[
  28. ["sum", {}],
  29. ["mean", {}],
  30. ["median", {}],
  31. ["max", {}],
  32. ["min", {}],
  33. ["var", {}],
  34. ["var", {"ddof": 0}],
  35. ["std", {}],
  36. ["std", {"ddof": 0}],
  37. ]
  38. )
  39. def arithmetic_numba_supported_operators(request):
  40. return request.param
  41. @td.skip_if_no("numba")
  42. @pytest.mark.filterwarnings("ignore")
  43. # Filter warnings when parallel=True and the function can't be parallelized by Numba
  44. class TestEngine:
  45. @pytest.mark.parametrize("jit", [True, False])
  46. def test_numba_vs_cython_apply(self, jit, nogil, parallel, nopython, center, step):
  47. def f(x, *args):
  48. arg_sum = 0
  49. for arg in args:
  50. arg_sum += arg
  51. return np.mean(x) + arg_sum
  52. if jit:
  53. import numba
  54. f = numba.jit(f)
  55. engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
  56. args = (2,)
  57. s = Series(range(10))
  58. result = s.rolling(2, center=center, step=step).apply(
  59. f, args=args, engine="numba", engine_kwargs=engine_kwargs, raw=True
  60. )
  61. expected = s.rolling(2, center=center, step=step).apply(
  62. f, engine="cython", args=args, raw=True
  63. )
  64. tm.assert_series_equal(result, expected)
  65. @pytest.mark.parametrize(
  66. "data",
  67. [
  68. DataFrame(np.eye(5)),
  69. DataFrame(
  70. [
  71. [5, 7, 7, 7, np.nan, np.inf, 4, 3, 3, 3],
  72. [5, 7, 7, 7, np.nan, np.inf, 7, 3, 3, 3],
  73. [np.nan, np.nan, 5, 6, 7, 5, 5, 5, 5, 5],
  74. ]
  75. ).T,
  76. Series(range(5), name="foo"),
  77. Series([20, 10, 10, np.inf, 1, 1, 2, 3]),
  78. Series([20, 10, 10, np.nan, 10, 1, 2, 3]),
  79. ],
  80. )
  81. def test_numba_vs_cython_rolling_methods(
  82. self,
  83. data,
  84. nogil,
  85. parallel,
  86. nopython,
  87. arithmetic_numba_supported_operators,
  88. step,
  89. ):
  90. method, kwargs = arithmetic_numba_supported_operators
  91. engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
  92. roll = data.rolling(3, step=step)
  93. result = getattr(roll, method)(
  94. engine="numba", engine_kwargs=engine_kwargs, **kwargs
  95. )
  96. expected = getattr(roll, method)(engine="cython", **kwargs)
  97. tm.assert_equal(result, expected)
  98. @pytest.mark.parametrize(
  99. "data", [DataFrame(np.eye(5)), Series(range(5), name="foo")]
  100. )
  101. def test_numba_vs_cython_expanding_methods(
  102. self, data, nogil, parallel, nopython, arithmetic_numba_supported_operators
  103. ):
  104. method, kwargs = arithmetic_numba_supported_operators
  105. engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
  106. data = DataFrame(np.eye(5))
  107. expand = data.expanding()
  108. result = getattr(expand, method)(
  109. engine="numba", engine_kwargs=engine_kwargs, **kwargs
  110. )
  111. expected = getattr(expand, method)(engine="cython", **kwargs)
  112. tm.assert_equal(result, expected)
  113. @pytest.mark.parametrize("jit", [True, False])
  114. def test_cache_apply(self, jit, nogil, parallel, nopython, step):
  115. # Test that the functions are cached correctly if we switch functions
  116. def func_1(x):
  117. return np.mean(x) + 4
  118. def func_2(x):
  119. return np.std(x) * 5
  120. if jit:
  121. import numba
  122. func_1 = numba.jit(func_1)
  123. func_2 = numba.jit(func_2)
  124. engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
  125. roll = Series(range(10)).rolling(2, step=step)
  126. result = roll.apply(
  127. func_1, engine="numba", engine_kwargs=engine_kwargs, raw=True
  128. )
  129. expected = roll.apply(func_1, engine="cython", raw=True)
  130. tm.assert_series_equal(result, expected)
  131. result = roll.apply(
  132. func_2, engine="numba", engine_kwargs=engine_kwargs, raw=True
  133. )
  134. expected = roll.apply(func_2, engine="cython", raw=True)
  135. tm.assert_series_equal(result, expected)
  136. # This run should use the cached func_1
  137. result = roll.apply(
  138. func_1, engine="numba", engine_kwargs=engine_kwargs, raw=True
  139. )
  140. expected = roll.apply(func_1, engine="cython", raw=True)
  141. tm.assert_series_equal(result, expected)
  142. @pytest.mark.parametrize(
  143. "window,window_kwargs",
  144. [
  145. ["rolling", {"window": 3, "min_periods": 0}],
  146. ["expanding", {}],
  147. ],
  148. )
  149. def test_dont_cache_args(
  150. self, window, window_kwargs, nogil, parallel, nopython, method
  151. ):
  152. # GH 42287
  153. def add(values, x):
  154. return np.sum(values) + x
  155. engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel}
  156. df = DataFrame({"value": [0, 0, 0]})
  157. result = getattr(df, window)(method=method, **window_kwargs).apply(
  158. add, raw=True, engine="numba", engine_kwargs=engine_kwargs, args=(1,)
  159. )
  160. expected = DataFrame({"value": [1.0, 1.0, 1.0]})
  161. tm.assert_frame_equal(result, expected)
  162. result = getattr(df, window)(method=method, **window_kwargs).apply(
  163. add, raw=True, engine="numba", engine_kwargs=engine_kwargs, args=(2,)
  164. )
  165. expected = DataFrame({"value": [2.0, 2.0, 2.0]})
  166. tm.assert_frame_equal(result, expected)
  167. def test_dont_cache_engine_kwargs(self):
  168. # If the user passes a different set of engine_kwargs don't return the same
  169. # jitted function
  170. nogil = False
  171. parallel = True
  172. nopython = True
  173. def func(x):
  174. return nogil + parallel + nopython
  175. engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel}
  176. df = DataFrame({"value": [0, 0, 0]})
  177. result = df.rolling(1).apply(
  178. func, raw=True, engine="numba", engine_kwargs=engine_kwargs
  179. )
  180. expected = DataFrame({"value": [2.0, 2.0, 2.0]})
  181. tm.assert_frame_equal(result, expected)
  182. parallel = False
  183. engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel}
  184. result = df.rolling(1).apply(
  185. func, raw=True, engine="numba", engine_kwargs=engine_kwargs
  186. )
  187. expected = DataFrame({"value": [1.0, 1.0, 1.0]})
  188. tm.assert_frame_equal(result, expected)
  189. @td.skip_if_no("numba")
  190. class TestEWM:
  191. @pytest.mark.parametrize(
  192. "grouper", [lambda x: x, lambda x: x.groupby("A")], ids=["None", "groupby"]
  193. )
  194. @pytest.mark.parametrize("method", ["mean", "sum"])
  195. def test_invalid_engine(self, grouper, method):
  196. df = DataFrame({"A": ["a", "b", "a", "b"], "B": range(4)})
  197. with pytest.raises(ValueError, match="engine must be either"):
  198. getattr(grouper(df).ewm(com=1.0), method)(engine="foo")
  199. @pytest.mark.parametrize(
  200. "grouper", [lambda x: x, lambda x: x.groupby("A")], ids=["None", "groupby"]
  201. )
  202. @pytest.mark.parametrize("method", ["mean", "sum"])
  203. def test_invalid_engine_kwargs(self, grouper, method):
  204. df = DataFrame({"A": ["a", "b", "a", "b"], "B": range(4)})
  205. with pytest.raises(ValueError, match="cython engine does not"):
  206. getattr(grouper(df).ewm(com=1.0), method)(
  207. engine="cython", engine_kwargs={"nopython": True}
  208. )
  209. @pytest.mark.parametrize("grouper", ["None", "groupby"])
  210. @pytest.mark.parametrize("method", ["mean", "sum"])
  211. def test_cython_vs_numba(
  212. self, grouper, method, nogil, parallel, nopython, ignore_na, adjust
  213. ):
  214. df = DataFrame({"B": range(4)})
  215. if grouper == "None":
  216. grouper = lambda x: x
  217. else:
  218. df["A"] = ["a", "b", "a", "b"]
  219. grouper = lambda x: x.groupby("A")
  220. if method == "sum":
  221. adjust = True
  222. ewm = grouper(df).ewm(com=1.0, adjust=adjust, ignore_na=ignore_na)
  223. engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
  224. result = getattr(ewm, method)(engine="numba", engine_kwargs=engine_kwargs)
  225. expected = getattr(ewm, method)(engine="cython")
  226. tm.assert_frame_equal(result, expected)
  227. @pytest.mark.parametrize("grouper", ["None", "groupby"])
  228. def test_cython_vs_numba_times(self, grouper, nogil, parallel, nopython, ignore_na):
  229. # GH 40951
  230. df = DataFrame({"B": [0, 0, 1, 1, 2, 2]})
  231. if grouper == "None":
  232. grouper = lambda x: x
  233. else:
  234. grouper = lambda x: x.groupby("A")
  235. df["A"] = ["a", "b", "a", "b", "b", "a"]
  236. halflife = "23 days"
  237. times = to_datetime(
  238. [
  239. "2020-01-01",
  240. "2020-01-01",
  241. "2020-01-02",
  242. "2020-01-10",
  243. "2020-02-23",
  244. "2020-01-03",
  245. ]
  246. )
  247. ewm = grouper(df).ewm(
  248. halflife=halflife, adjust=True, ignore_na=ignore_na, times=times
  249. )
  250. engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
  251. result = ewm.mean(engine="numba", engine_kwargs=engine_kwargs)
  252. expected = ewm.mean(engine="cython")
  253. tm.assert_frame_equal(result, expected)
  254. @td.skip_if_no("numba")
  255. def test_use_global_config():
  256. def f(x):
  257. return np.mean(x) + 2
  258. s = Series(range(10))
  259. with option_context("compute.use_numba", True):
  260. result = s.rolling(2).apply(f, engine=None, raw=True)
  261. expected = s.rolling(2).apply(f, engine="numba", raw=True)
  262. tm.assert_series_equal(expected, result)
  263. @td.skip_if_no("numba")
  264. def test_invalid_kwargs_nopython():
  265. with pytest.raises(NumbaUtilError, match="numba does not support kwargs with"):
  266. Series(range(1)).rolling(1).apply(
  267. lambda x: x, kwargs={"a": 1}, engine="numba", raw=True
  268. )
  269. @td.skip_if_no("numba")
  270. @pytest.mark.slow
  271. @pytest.mark.filterwarnings("ignore")
  272. # Filter warnings when parallel=True and the function can't be parallelized by Numba
  273. class TestTableMethod:
  274. def test_table_series_valueerror(self):
  275. def f(x):
  276. return np.sum(x, axis=0) + 1
  277. with pytest.raises(
  278. ValueError, match="method='table' not applicable for Series objects."
  279. ):
  280. Series(range(1)).rolling(1, method="table").apply(
  281. f, engine="numba", raw=True
  282. )
  283. def test_table_method_rolling_methods(
  284. self,
  285. axis,
  286. nogil,
  287. parallel,
  288. nopython,
  289. arithmetic_numba_supported_operators,
  290. step,
  291. ):
  292. method, kwargs = arithmetic_numba_supported_operators
  293. engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
  294. df = DataFrame(np.eye(3))
  295. roll_table = df.rolling(2, method="table", axis=axis, min_periods=0, step=step)
  296. if method in ("var", "std"):
  297. with pytest.raises(NotImplementedError, match=f"{method} not supported"):
  298. getattr(roll_table, method)(
  299. engine_kwargs=engine_kwargs, engine="numba", **kwargs
  300. )
  301. else:
  302. roll_single = df.rolling(
  303. 2, method="single", axis=axis, min_periods=0, step=step
  304. )
  305. result = getattr(roll_table, method)(
  306. engine_kwargs=engine_kwargs, engine="numba", **kwargs
  307. )
  308. expected = getattr(roll_single, method)(
  309. engine_kwargs=engine_kwargs, engine="numba", **kwargs
  310. )
  311. tm.assert_frame_equal(result, expected)
  312. def test_table_method_rolling_apply(self, axis, nogil, parallel, nopython, step):
  313. engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
  314. def f(x):
  315. return np.sum(x, axis=0) + 1
  316. df = DataFrame(np.eye(3))
  317. result = df.rolling(
  318. 2, method="table", axis=axis, min_periods=0, step=step
  319. ).apply(f, raw=True, engine_kwargs=engine_kwargs, engine="numba")
  320. expected = df.rolling(
  321. 2, method="single", axis=axis, min_periods=0, step=step
  322. ).apply(f, raw=True, engine_kwargs=engine_kwargs, engine="numba")
  323. tm.assert_frame_equal(result, expected)
  324. def test_table_method_rolling_weighted_mean(self, step):
  325. def weighted_mean(x):
  326. arr = np.ones((1, x.shape[1]))
  327. arr[:, :2] = (x[:, :2] * x[:, 2]).sum(axis=0) / x[:, 2].sum()
  328. return arr
  329. df = DataFrame([[1, 2, 0.6], [2, 3, 0.4], [3, 4, 0.2], [4, 5, 0.7]])
  330. result = df.rolling(2, method="table", min_periods=0, step=step).apply(
  331. weighted_mean, raw=True, engine="numba"
  332. )
  333. expected = DataFrame(
  334. [
  335. [1.0, 2.0, 1.0],
  336. [1.8, 2.0, 1.0],
  337. [3.333333, 2.333333, 1.0],
  338. [1.555556, 7, 1.0],
  339. ]
  340. )[::step]
  341. tm.assert_frame_equal(result, expected)
  342. def test_table_method_expanding_apply(self, axis, nogil, parallel, nopython):
  343. engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
  344. def f(x):
  345. return np.sum(x, axis=0) + 1
  346. df = DataFrame(np.eye(3))
  347. result = df.expanding(method="table", axis=axis).apply(
  348. f, raw=True, engine_kwargs=engine_kwargs, engine="numba"
  349. )
  350. expected = df.expanding(method="single", axis=axis).apply(
  351. f, raw=True, engine_kwargs=engine_kwargs, engine="numba"
  352. )
  353. tm.assert_frame_equal(result, expected)
  354. def test_table_method_expanding_methods(
  355. self, axis, nogil, parallel, nopython, arithmetic_numba_supported_operators
  356. ):
  357. method, kwargs = arithmetic_numba_supported_operators
  358. engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
  359. df = DataFrame(np.eye(3))
  360. expand_table = df.expanding(method="table", axis=axis)
  361. if method in ("var", "std"):
  362. with pytest.raises(NotImplementedError, match=f"{method} not supported"):
  363. getattr(expand_table, method)(
  364. engine_kwargs=engine_kwargs, engine="numba", **kwargs
  365. )
  366. else:
  367. expand_single = df.expanding(method="single", axis=axis)
  368. result = getattr(expand_table, method)(
  369. engine_kwargs=engine_kwargs, engine="numba", **kwargs
  370. )
  371. expected = getattr(expand_single, method)(
  372. engine_kwargs=engine_kwargs, engine="numba", **kwargs
  373. )
  374. tm.assert_frame_equal(result, expected)
  375. @pytest.mark.parametrize("data", [np.eye(3), np.ones((2, 3)), np.ones((3, 2))])
  376. @pytest.mark.parametrize("method", ["mean", "sum"])
  377. def test_table_method_ewm(self, data, method, axis, nogil, parallel, nopython):
  378. engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
  379. df = DataFrame(data)
  380. result = getattr(df.ewm(com=1, method="table", axis=axis), method)(
  381. engine_kwargs=engine_kwargs, engine="numba"
  382. )
  383. expected = getattr(df.ewm(com=1, method="single", axis=axis), method)(
  384. engine_kwargs=engine_kwargs, engine="numba"
  385. )
  386. tm.assert_frame_equal(result, expected)
  387. @td.skip_if_no("numba")
  388. def test_npfunc_no_warnings():
  389. df = DataFrame({"col1": [1, 2, 3, 4, 5]})
  390. with tm.assert_produces_warning(False):
  391. df.col1.rolling(2).apply(np.prod, raw=True, engine="numba")