test_expressions.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466
  1. import operator
  2. import re
  3. import numpy as np
  4. import pytest
  5. from pandas import option_context
  6. import pandas._testing as tm
  7. from pandas.core.api import (
  8. DataFrame,
  9. Index,
  10. Series,
  11. )
  12. from pandas.core.computation import expressions as expr
  13. @pytest.fixture
  14. def _frame():
  15. return DataFrame(
  16. np.random.default_rng(2).standard_normal((10001, 4)),
  17. columns=list("ABCD"),
  18. dtype="float64",
  19. )
  20. @pytest.fixture
  21. def _frame2():
  22. return DataFrame(
  23. np.random.default_rng(2).standard_normal((100, 4)),
  24. columns=list("ABCD"),
  25. dtype="float64",
  26. )
  27. @pytest.fixture
  28. def _mixed(_frame):
  29. return DataFrame(
  30. {
  31. "A": _frame["A"].copy(),
  32. "B": _frame["B"].astype("float32"),
  33. "C": _frame["C"].astype("int64"),
  34. "D": _frame["D"].astype("int32"),
  35. }
  36. )
  37. @pytest.fixture
  38. def _mixed2(_frame2):
  39. return DataFrame(
  40. {
  41. "A": _frame2["A"].copy(),
  42. "B": _frame2["B"].astype("float32"),
  43. "C": _frame2["C"].astype("int64"),
  44. "D": _frame2["D"].astype("int32"),
  45. }
  46. )
  47. @pytest.fixture
  48. def _integer():
  49. return DataFrame(
  50. np.random.default_rng(2).integers(1, 100, size=(10001, 4)),
  51. columns=list("ABCD"),
  52. dtype="int64",
  53. )
  54. @pytest.fixture
  55. def _integer_integers(_integer):
  56. # integers to get a case with zeros
  57. return _integer * np.random.default_rng(2).integers(0, 2, size=np.shape(_integer))
  58. @pytest.fixture
  59. def _integer2():
  60. return DataFrame(
  61. np.random.default_rng(2).integers(1, 100, size=(101, 4)),
  62. columns=list("ABCD"),
  63. dtype="int64",
  64. )
  65. @pytest.fixture
  66. def _array(_frame):
  67. return _frame["A"].values.copy()
  68. @pytest.fixture
  69. def _array2(_frame2):
  70. return _frame2["A"].values.copy()
  71. @pytest.fixture
  72. def _array_mixed(_mixed):
  73. return _mixed["D"].values.copy()
  74. @pytest.fixture
  75. def _array_mixed2(_mixed2):
  76. return _mixed2["D"].values.copy()
  77. @pytest.mark.skipif(not expr.USE_NUMEXPR, reason="not using numexpr")
  78. class TestExpressions:
  79. @staticmethod
  80. def call_op(df, other, flex: bool, opname: str):
  81. if flex:
  82. op = lambda x, y: getattr(x, opname)(y)
  83. op.__name__ = opname
  84. else:
  85. op = getattr(operator, opname)
  86. with option_context("compute.use_numexpr", False):
  87. expected = op(df, other)
  88. expr.get_test_result()
  89. result = op(df, other)
  90. return result, expected
  91. @pytest.mark.parametrize(
  92. "fixture",
  93. [
  94. "_integer",
  95. "_integer2",
  96. "_integer_integers",
  97. "_frame",
  98. "_frame2",
  99. "_mixed",
  100. "_mixed2",
  101. ],
  102. )
  103. @pytest.mark.parametrize("flex", [True, False])
  104. @pytest.mark.parametrize(
  105. "arith", ["add", "sub", "mul", "mod", "truediv", "floordiv"]
  106. )
  107. def test_run_arithmetic(self, request, fixture, flex, arith, monkeypatch):
  108. df = request.getfixturevalue(fixture)
  109. with monkeypatch.context() as m:
  110. m.setattr(expr, "_MIN_ELEMENTS", 0)
  111. result, expected = self.call_op(df, df, flex, arith)
  112. if arith == "truediv":
  113. assert all(x.kind == "f" for x in expected.dtypes.values)
  114. tm.assert_equal(expected, result)
  115. for i in range(len(df.columns)):
  116. result, expected = self.call_op(
  117. df.iloc[:, i], df.iloc[:, i], flex, arith
  118. )
  119. if arith == "truediv":
  120. assert expected.dtype.kind == "f"
  121. tm.assert_equal(expected, result)
  122. @pytest.mark.parametrize(
  123. "fixture",
  124. [
  125. "_integer",
  126. "_integer2",
  127. "_integer_integers",
  128. "_frame",
  129. "_frame2",
  130. "_mixed",
  131. "_mixed2",
  132. ],
  133. )
  134. @pytest.mark.parametrize("flex", [True, False])
  135. def test_run_binary(self, request, fixture, flex, comparison_op, monkeypatch):
  136. """
  137. tests solely that the result is the same whether or not numexpr is
  138. enabled. Need to test whether the function does the correct thing
  139. elsewhere.
  140. """
  141. df = request.getfixturevalue(fixture)
  142. arith = comparison_op.__name__
  143. with option_context("compute.use_numexpr", False):
  144. other = df.copy() + 1
  145. with monkeypatch.context() as m:
  146. m.setattr(expr, "_MIN_ELEMENTS", 0)
  147. expr.set_test_mode(True)
  148. result, expected = self.call_op(df, other, flex, arith)
  149. used_numexpr = expr.get_test_result()
  150. assert used_numexpr, "Did not use numexpr as expected."
  151. tm.assert_equal(expected, result)
  152. for i in range(len(df.columns)):
  153. binary_comp = other.iloc[:, i] + 1
  154. self.call_op(df.iloc[:, i], binary_comp, flex, "add")
  155. def test_invalid(self):
  156. array = np.random.default_rng(2).standard_normal(1_000_001)
  157. array2 = np.random.default_rng(2).standard_normal(100)
  158. # no op
  159. result = expr._can_use_numexpr(operator.add, None, array, array, "evaluate")
  160. assert not result
  161. # min elements
  162. result = expr._can_use_numexpr(operator.add, "+", array2, array2, "evaluate")
  163. assert not result
  164. # ok, we only check on first part of expression
  165. result = expr._can_use_numexpr(operator.add, "+", array, array2, "evaluate")
  166. assert result
  167. @pytest.mark.filterwarnings("ignore:invalid value encountered in:RuntimeWarning")
  168. @pytest.mark.parametrize(
  169. "opname,op_str",
  170. [("add", "+"), ("sub", "-"), ("mul", "*"), ("truediv", "/"), ("pow", "**")],
  171. )
  172. @pytest.mark.parametrize(
  173. "left_fix,right_fix", [("_array", "_array2"), ("_array_mixed", "_array_mixed2")]
  174. )
  175. def test_binary_ops(self, request, opname, op_str, left_fix, right_fix):
  176. left = request.getfixturevalue(left_fix)
  177. right = request.getfixturevalue(right_fix)
  178. def testit(left, right, opname, op_str):
  179. if opname == "pow":
  180. left = np.abs(left)
  181. op = getattr(operator, opname)
  182. # array has 0s
  183. result = expr.evaluate(op, left, left, use_numexpr=True)
  184. expected = expr.evaluate(op, left, left, use_numexpr=False)
  185. tm.assert_numpy_array_equal(result, expected)
  186. result = expr._can_use_numexpr(op, op_str, right, right, "evaluate")
  187. assert not result
  188. with option_context("compute.use_numexpr", False):
  189. testit(left, right, opname, op_str)
  190. expr.set_numexpr_threads(1)
  191. testit(left, right, opname, op_str)
  192. expr.set_numexpr_threads()
  193. testit(left, right, opname, op_str)
  194. @pytest.mark.parametrize(
  195. "left_fix,right_fix", [("_array", "_array2"), ("_array_mixed", "_array_mixed2")]
  196. )
  197. def test_comparison_ops(self, request, comparison_op, left_fix, right_fix):
  198. left = request.getfixturevalue(left_fix)
  199. right = request.getfixturevalue(right_fix)
  200. def testit():
  201. f12 = left + 1
  202. f22 = right + 1
  203. op = comparison_op
  204. result = expr.evaluate(op, left, f12, use_numexpr=True)
  205. expected = expr.evaluate(op, left, f12, use_numexpr=False)
  206. tm.assert_numpy_array_equal(result, expected)
  207. result = expr._can_use_numexpr(op, op, right, f22, "evaluate")
  208. assert not result
  209. with option_context("compute.use_numexpr", False):
  210. testit()
  211. expr.set_numexpr_threads(1)
  212. testit()
  213. expr.set_numexpr_threads()
  214. testit()
  215. @pytest.mark.parametrize("cond", [True, False])
  216. @pytest.mark.parametrize("fixture", ["_frame", "_frame2", "_mixed", "_mixed2"])
  217. def test_where(self, request, cond, fixture):
  218. df = request.getfixturevalue(fixture)
  219. def testit():
  220. c = np.empty(df.shape, dtype=np.bool_)
  221. c.fill(cond)
  222. result = expr.where(c, df.values, df.values + 1)
  223. expected = np.where(c, df.values, df.values + 1)
  224. tm.assert_numpy_array_equal(result, expected)
  225. with option_context("compute.use_numexpr", False):
  226. testit()
  227. expr.set_numexpr_threads(1)
  228. testit()
  229. expr.set_numexpr_threads()
  230. testit()
  231. @pytest.mark.parametrize(
  232. "op_str,opname", [("/", "truediv"), ("//", "floordiv"), ("**", "pow")]
  233. )
  234. def test_bool_ops_raise_on_arithmetic(self, op_str, opname):
  235. df = DataFrame(
  236. {
  237. "a": np.random.default_rng(2).random(10) > 0.5,
  238. "b": np.random.default_rng(2).random(10) > 0.5,
  239. }
  240. )
  241. msg = f"operator '{opname}' not implemented for bool dtypes"
  242. f = getattr(operator, opname)
  243. err_msg = re.escape(msg)
  244. with pytest.raises(NotImplementedError, match=err_msg):
  245. f(df, df)
  246. with pytest.raises(NotImplementedError, match=err_msg):
  247. f(df.a, df.b)
  248. with pytest.raises(NotImplementedError, match=err_msg):
  249. f(df.a, True)
  250. with pytest.raises(NotImplementedError, match=err_msg):
  251. f(False, df.a)
  252. with pytest.raises(NotImplementedError, match=err_msg):
  253. f(False, df)
  254. with pytest.raises(NotImplementedError, match=err_msg):
  255. f(df, True)
  256. @pytest.mark.parametrize(
  257. "op_str,opname", [("+", "add"), ("*", "mul"), ("-", "sub")]
  258. )
  259. def test_bool_ops_warn_on_arithmetic(self, op_str, opname):
  260. n = 10
  261. df = DataFrame(
  262. {
  263. "a": np.random.default_rng(2).random(n) > 0.5,
  264. "b": np.random.default_rng(2).random(n) > 0.5,
  265. }
  266. )
  267. subs = {"+": "|", "*": "&", "-": "^"}
  268. sub_funcs = {"|": "or_", "&": "and_", "^": "xor"}
  269. f = getattr(operator, opname)
  270. fe = getattr(operator, sub_funcs[subs[op_str]])
  271. if op_str == "-":
  272. # raises TypeError
  273. return
  274. with tm.use_numexpr(True, min_elements=5):
  275. with tm.assert_produces_warning():
  276. r = f(df, df)
  277. e = fe(df, df)
  278. tm.assert_frame_equal(r, e)
  279. with tm.assert_produces_warning():
  280. r = f(df.a, df.b)
  281. e = fe(df.a, df.b)
  282. tm.assert_series_equal(r, e)
  283. with tm.assert_produces_warning():
  284. r = f(df.a, True)
  285. e = fe(df.a, True)
  286. tm.assert_series_equal(r, e)
  287. with tm.assert_produces_warning():
  288. r = f(False, df.a)
  289. e = fe(False, df.a)
  290. tm.assert_series_equal(r, e)
  291. with tm.assert_produces_warning():
  292. r = f(False, df)
  293. e = fe(False, df)
  294. tm.assert_frame_equal(r, e)
  295. with tm.assert_produces_warning():
  296. r = f(df, True)
  297. e = fe(df, True)
  298. tm.assert_frame_equal(r, e)
  299. @pytest.mark.parametrize(
  300. "test_input,expected",
  301. [
  302. (
  303. DataFrame(
  304. [[0, 1, 2, "aa"], [0, 1, 2, "aa"]], columns=["a", "b", "c", "dtype"]
  305. ),
  306. DataFrame([[False, False], [False, False]], columns=["a", "dtype"]),
  307. ),
  308. (
  309. DataFrame(
  310. [[0, 3, 2, "aa"], [0, 4, 2, "aa"], [0, 1, 1, "bb"]],
  311. columns=["a", "b", "c", "dtype"],
  312. ),
  313. DataFrame(
  314. [[False, False], [False, False], [False, False]],
  315. columns=["a", "dtype"],
  316. ),
  317. ),
  318. ],
  319. )
  320. def test_bool_ops_column_name_dtype(self, test_input, expected):
  321. # GH 22383 - .ne fails if columns containing column name 'dtype'
  322. result = test_input.loc[:, ["a", "dtype"]].ne(test_input.loc[:, ["a", "dtype"]])
  323. tm.assert_frame_equal(result, expected)
  324. @pytest.mark.parametrize(
  325. "arith", ("add", "sub", "mul", "mod", "truediv", "floordiv")
  326. )
  327. @pytest.mark.parametrize("axis", (0, 1))
  328. def test_frame_series_axis(self, axis, arith, _frame, monkeypatch):
  329. # GH#26736 Dataframe.floordiv(Series, axis=1) fails
  330. df = _frame
  331. if axis == 1:
  332. other = df.iloc[0, :]
  333. else:
  334. other = df.iloc[:, 0]
  335. with monkeypatch.context() as m:
  336. m.setattr(expr, "_MIN_ELEMENTS", 0)
  337. op_func = getattr(df, arith)
  338. with option_context("compute.use_numexpr", False):
  339. expected = op_func(other, axis=axis)
  340. result = op_func(other, axis=axis)
  341. tm.assert_frame_equal(expected, result)
  342. @pytest.mark.parametrize(
  343. "op",
  344. [
  345. "__mod__",
  346. "__rmod__",
  347. "__floordiv__",
  348. "__rfloordiv__",
  349. ],
  350. )
  351. @pytest.mark.parametrize("box", [DataFrame, Series, Index])
  352. @pytest.mark.parametrize("scalar", [-5, 5])
  353. def test_python_semantics_with_numexpr_installed(
  354. self, op, box, scalar, monkeypatch
  355. ):
  356. # https://github.com/pandas-dev/pandas/issues/36047
  357. with monkeypatch.context() as m:
  358. m.setattr(expr, "_MIN_ELEMENTS", 0)
  359. data = np.arange(-50, 50)
  360. obj = box(data)
  361. method = getattr(obj, op)
  362. result = method(scalar)
  363. # compare result with numpy
  364. with option_context("compute.use_numexpr", False):
  365. expected = method(scalar)
  366. tm.assert_equal(result, expected)
  367. # compare result element-wise with Python
  368. for i, elem in enumerate(data):
  369. if box == DataFrame:
  370. scalar_result = result.iloc[i, 0]
  371. else:
  372. scalar_result = result[i]
  373. try:
  374. expected = getattr(int(elem), op)(scalar)
  375. except ZeroDivisionError:
  376. pass
  377. else:
  378. assert scalar_result == expected