test_numba.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. import numpy as np
  2. import pytest
  3. from pandas.compat import is_platform_arm
  4. import pandas.util._test_decorators as td
  5. import pandas as pd
  6. from pandas import (
  7. DataFrame,
  8. Index,
  9. )
  10. import pandas._testing as tm
  11. from pandas.util.version import Version
  # Module-wide marks: skip everything when numba is not installed and tag the
  # tests with the ``single_cpu`` marker.
  #
  # An argument-less ``pytest.mark.skipif()`` used to be part of this list.
  # pytest treats a ``skipif`` mark that has no condition as an unconditional
  # skip, so it silently disabled every test in this module; it is removed so
  # the tests actually run whenever the remaining conditions allow it.
  pytestmark = [td.skip_if_no("numba"), pytest.mark.single_cpu]

  # Import numba (or skip the whole module) so its version can be inspected.
  numba = pytest.importorskip("numba")

  # Skip on the one numba release / platform combination named in the reason.
  pytestmark.append(
      pytest.mark.skipif(
          Version(numba.__version__) == Version("0.61") and is_platform_arm(),
          reason=f"Segfaults on ARM platforms with numba {numba.__version__}",
      )
  )
  @pytest.fixture(params=[0, 1])
  def apply_axis(request):
      """Yield each parametrized value (0, then 1) for use as an ``axis`` argument."""
      axis = request.param
      return axis
  def test_numba_vs_python_noop(float_frame, apply_axis):
      """Applying an identity function gives the same frame under both engines."""

      def identity(values):
          return values

      # Run the numba engine first, then the python engine, on the same input.
      from_numba = float_frame.apply(identity, engine="numba", axis=apply_axis)
      from_python = float_frame.apply(identity, engine="python", axis=apply_axis)

      tm.assert_frame_equal(from_numba, from_python)
  def test_numba_vs_python_string_index():
      """Identity apply agrees across engines with string-dtype index and columns."""
      # GH#56189
      row_labels = Index(["a", "b"], dtype=pd.StringDtype(na_value=np.nan))
      col_labels = Index(["x", "y"], dtype=pd.StringDtype(na_value=np.nan))
      df = DataFrame(1, index=row_labels, columns=col_labels)

      def identity(values):
          return values

      from_numba = df.apply(identity, engine="numba", axis=0)
      from_python = df.apply(identity, engine="python", axis=0)

      # The index and column type checks are switched off for this comparison.
      tm.assert_frame_equal(
          from_numba,
          from_python,
          check_index_type=False,
          check_column_type=False,
      )
  def test_numba_vs_python_indexing():
      """Label lookups inside the applied function agree across both engines."""
      frame = DataFrame(
          {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7.0, 8.0, 9.0]},
          index=Index(["A", "B", "C"]),
      )

      # (function, axis) pairs: row-wise pick column "c", then column-wise
      # pick row "A".
      cases = [
          (lambda row: row["c"], 1),
          (lambda col: col["A"], 0),
      ]

      for picker, axis in cases:
          result = frame.apply(picker, engine="numba", axis=axis)
          expected = frame.apply(picker, engine="python", axis=axis)
          tm.assert_series_equal(result, expected)
  @pytest.mark.parametrize(
      "reduction",
      [lambda x: x.mean(), lambda x: x.min(), lambda x: x.max(), lambda x: x.sum()],
  )
  def test_numba_vs_python_reductions(reduction, apply_axis):
      """Each reduction yields the same Series under the numba and python engines."""
      ones = np.ones((4, 4), dtype=np.float64)
      df = DataFrame(ones)

      from_numba = df.apply(reduction, engine="numba", axis=apply_axis)
      from_python = df.apply(reduction, engine="python", axis=apply_axis)

      tm.assert_series_equal(from_numba, from_python)
  @pytest.mark.parametrize("colnames", [[1, 2, 3], [1.0, 2.0, 3.0]])
  def test_numba_numeric_colnames(colnames):
      """Integer and float column labels can be indexed on under both engines."""
      # Check that numeric column names lower properly and can be indexed on
      values = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int64)
      df = DataFrame(values, columns=colnames)

      first_col = colnames[0]

      def take_first(row):
          # Look up the first column's label in each row.
          return row[first_col]

      from_numba = df.apply(take_first, engine="numba", axis=1)
      from_python = df.apply(take_first, engine="python", axis=1)

      tm.assert_series_equal(from_numba, from_python)
  def test_numba_parallel_unsupported(float_frame):
      """Passing ``parallel=True`` to the numba engine raises NotImplementedError."""
      expected_msg = "Parallel apply is not supported when raw=False and engine='numba'"
      parallel_kwargs = {"parallel": True}

      def identity(values):
          return values

      with pytest.raises(NotImplementedError, match=expected_msg):
          float_frame.apply(identity, engine="numba", engine_kwargs=parallel_kwargs)
  def test_numba_nonunique_unsupported(apply_axis):
      """A duplicated index label raises NotImplementedError with the numba engine."""
      expected_msg = "The index/columns must be unique when raw=False and engine='numba'"
      # The index repeats the label "a".
      duplicated = Index(["a", "a"])
      df = DataFrame({"a": [1, 2]}, index=duplicated)

      def identity(values):
          return values

      with pytest.raises(NotImplementedError, match=expected_msg):
          df.apply(identity, engine="numba", axis=apply_axis)
  def test_numba_unsupported_dtypes(apply_axis):
      """Check the errors raised for dtypes the numba engine rejects.

      Two cases are exercised, both expecting ``ValueError``:

      * a frame whose column ``b`` holds strings, and
      * a single-column frame whose column ``c`` was cast to
        ``"double[pyarrow]"``.

      The test is skipped when pyarrow cannot be imported.
      """
      pytest.importorskip("pyarrow")
      f = lambda x: x
      df = DataFrame({"a": [1, 2], "b": ["a", "b"], "c": [4, 5]})
      df["c"] = df["c"].astype("double[pyarrow]")

      # The alternation is grouped so that it only covers the dtype name.
      # Ungrouped, ``|`` split the whole pattern into two loose alternatives
      # ("...Found 'object" or "str' instead"), so the full message was never
      # actually checked.
      with pytest.raises(
          ValueError,
          match="Column b must have a numeric dtype. Found '(object|str)' instead",
      ):
          df.apply(f, engine="numba", axis=apply_axis)

      with pytest.raises(
          ValueError,
          match="Column c is backed by an extension array, "
          "which is not supported by the numba engine.",
      ):
          df["c"].to_frame().apply(f, engine="numba", axis=apply_axis)