ops.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. from __future__ import annotations
  2. from typing import final
  3. import numpy as np
  4. import pytest
  5. from pandas.core.dtypes.common import is_string_dtype
  6. import pandas as pd
  7. import pandas._testing as tm
  8. from pandas.core import ops
  9. class BaseOpsUtil:
  10. series_scalar_exc: type[Exception] | None = TypeError
  11. frame_scalar_exc: type[Exception] | None = TypeError
  12. series_array_exc: type[Exception] | None = TypeError
  13. divmod_exc: type[Exception] | None = TypeError
  14. def _get_expected_exception(
  15. self, op_name: str, obj, other
  16. ) -> type[Exception] | tuple[type[Exception], ...] | None:
  17. # Find the Exception, if any we expect to raise calling
  18. # obj.__op_name__(other)
  19. # The self.obj_bar_exc pattern isn't great in part because it can depend
  20. # on op_name or dtypes, but we use it here for backward-compatibility.
  21. if op_name in ["__divmod__", "__rdivmod__"]:
  22. result = self.divmod_exc
  23. elif isinstance(obj, pd.Series) and isinstance(other, pd.Series):
  24. result = self.series_array_exc
  25. elif isinstance(obj, pd.Series):
  26. result = self.series_scalar_exc
  27. else:
  28. result = self.frame_scalar_exc
  29. return result
  30. def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
  31. # In _check_op we check that the result of a pointwise operation
  32. # (found via _combine) matches the result of the vectorized
  33. # operation obj.__op_name__(other).
  34. # In some cases pandas dtype inference on the scalar result may not
  35. # give a matching dtype even if both operations are behaving "correctly".
  36. # In these cases, do extra required casting here.
  37. return pointwise_result
  38. def get_op_from_name(self, op_name: str):
  39. return tm.get_op_from_name(op_name)
  40. # Subclasses are not expected to need to override check_opname, _check_op,
  41. # _check_divmod_op, or _combine.
  42. # Ideally any relevant overriding can be done in _cast_pointwise_result,
  43. # get_op_from_name, and the specification of `exc`. If you find a use
  44. # case that still requires overriding _check_op or _combine, please let
  45. # us know at github.com/pandas-dev/pandas/issues
  46. @final
  47. def check_opname(self, ser: pd.Series, op_name: str, other):
  48. exc = self._get_expected_exception(op_name, ser, other)
  49. op = self.get_op_from_name(op_name)
  50. self._check_op(ser, op, other, op_name, exc)
  51. # see comment on check_opname
  52. @final
  53. def _combine(self, obj, other, op):
  54. if isinstance(obj, pd.DataFrame):
  55. if len(obj.columns) != 1:
  56. raise NotImplementedError
  57. expected = obj.iloc[:, 0].combine(other, op).to_frame()
  58. else:
  59. expected = obj.combine(other, op)
  60. return expected
  61. # see comment on check_opname
  62. @final
  63. def _check_op(
  64. self, ser: pd.Series, op, other, op_name: str, exc=NotImplementedError
  65. ):
  66. # Check that the Series/DataFrame arithmetic/comparison method matches
  67. # the pointwise result from _combine.
  68. if exc is None:
  69. result = op(ser, other)
  70. expected = self._combine(ser, other, op)
  71. expected = self._cast_pointwise_result(op_name, ser, other, expected)
  72. assert isinstance(result, type(ser))
  73. tm.assert_equal(result, expected)
  74. else:
  75. with pytest.raises(exc):
  76. op(ser, other)
  77. # see comment on check_opname
  78. @final
  79. def _check_divmod_op(self, ser: pd.Series, op, other):
  80. # check that divmod behavior matches behavior of floordiv+mod
  81. if op is divmod:
  82. exc = self._get_expected_exception("__divmod__", ser, other)
  83. else:
  84. exc = self._get_expected_exception("__rdivmod__", ser, other)
  85. if exc is None:
  86. result_div, result_mod = op(ser, other)
  87. if op is divmod:
  88. expected_div, expected_mod = ser // other, ser % other
  89. else:
  90. expected_div, expected_mod = other // ser, other % ser
  91. tm.assert_series_equal(result_div, expected_div)
  92. tm.assert_series_equal(result_mod, expected_mod)
  93. else:
  94. with pytest.raises(exc):
  95. divmod(ser, other)
  96. class BaseArithmeticOpsTests(BaseOpsUtil):
  97. """
  98. Various Series and DataFrame arithmetic ops methods.
  99. Subclasses supporting various ops should set the class variables
  100. to indicate that they support ops of that kind
  101. * series_scalar_exc = TypeError
  102. * frame_scalar_exc = TypeError
  103. * series_array_exc = TypeError
  104. * divmod_exc = TypeError
  105. """
  106. series_scalar_exc: type[Exception] | None = TypeError
  107. frame_scalar_exc: type[Exception] | None = TypeError
  108. series_array_exc: type[Exception] | None = TypeError
  109. divmod_exc: type[Exception] | None = TypeError
  110. def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
  111. # series & scalar
  112. if all_arithmetic_operators == "__rmod__" and is_string_dtype(data.dtype):
  113. pytest.skip("Skip testing Python string formatting")
  114. op_name = all_arithmetic_operators
  115. ser = pd.Series(data)
  116. self.check_opname(ser, op_name, ser.iloc[0])
  117. def test_arith_frame_with_scalar(self, data, all_arithmetic_operators):
  118. # frame & scalar
  119. if all_arithmetic_operators == "__rmod__" and is_string_dtype(data.dtype):
  120. pytest.skip("Skip testing Python string formatting")
  121. op_name = all_arithmetic_operators
  122. df = pd.DataFrame({"A": data})
  123. self.check_opname(df, op_name, data[0])
  124. def test_arith_series_with_array(self, data, all_arithmetic_operators):
  125. # ndarray & other series
  126. op_name = all_arithmetic_operators
  127. ser = pd.Series(data)
  128. self.check_opname(ser, op_name, pd.Series([ser.iloc[0]] * len(ser)))
  129. def test_divmod(self, data):
  130. ser = pd.Series(data)
  131. self._check_divmod_op(ser, divmod, 1)
  132. self._check_divmod_op(1, ops.rdivmod, ser)
  133. def test_divmod_series_array(self, data, data_for_twos):
  134. ser = pd.Series(data)
  135. self._check_divmod_op(ser, divmod, data)
  136. other = data_for_twos
  137. self._check_divmod_op(other, ops.rdivmod, ser)
  138. other = pd.Series(other)
  139. self._check_divmod_op(other, ops.rdivmod, ser)
  140. def test_add_series_with_extension_array(self, data):
  141. # Check adding an ExtensionArray to a Series of the same dtype matches
  142. # the behavior of adding the arrays directly and then wrapping in a
  143. # Series.
  144. ser = pd.Series(data)
  145. exc = self._get_expected_exception("__add__", ser, data)
  146. if exc is not None:
  147. with pytest.raises(exc):
  148. ser + data
  149. return
  150. result = ser + data
  151. expected = pd.Series(data + data)
  152. tm.assert_series_equal(result, expected)
  153. @pytest.mark.parametrize("box", [pd.Series, pd.DataFrame, pd.Index])
  154. @pytest.mark.parametrize(
  155. "op_name",
  156. [
  157. x
  158. for x in tm.arithmetic_dunder_methods + tm.comparison_dunder_methods
  159. if not x.startswith("__r")
  160. ],
  161. )
  162. def test_direct_arith_with_ndframe_returns_not_implemented(
  163. self, data, box, op_name
  164. ):
  165. # EAs should return NotImplemented for ops with Series/DataFrame/Index
  166. # Pandas takes care of unboxing the series and calling the EA's op.
  167. other = box(data)
  168. if hasattr(data, op_name):
  169. result = getattr(data, op_name)(other)
  170. assert result is NotImplemented
  171. class BaseComparisonOpsTests(BaseOpsUtil):
  172. """Various Series and DataFrame comparison ops methods."""
  173. def _compare_other(self, ser: pd.Series, data, op, other):
  174. if op.__name__ in ["eq", "ne"]:
  175. # comparison should match point-wise comparisons
  176. result = op(ser, other)
  177. expected = ser.combine(other, op)
  178. expected = self._cast_pointwise_result(op.__name__, ser, other, expected)
  179. tm.assert_series_equal(result, expected)
  180. else:
  181. exc = None
  182. try:
  183. result = op(ser, other)
  184. except Exception as err:
  185. exc = err
  186. if exc is None:
  187. # Didn't error, then should match pointwise behavior
  188. expected = ser.combine(other, op)
  189. expected = self._cast_pointwise_result(
  190. op.__name__, ser, other, expected
  191. )
  192. tm.assert_series_equal(result, expected)
  193. else:
  194. with pytest.raises(type(exc)):
  195. ser.combine(other, op)
  196. def test_compare_scalar(self, data, comparison_op):
  197. ser = pd.Series(data)
  198. self._compare_other(ser, data, comparison_op, 0)
  199. def test_compare_array(self, data, comparison_op):
  200. ser = pd.Series(data)
  201. other = pd.Series([data[0]] * len(data), dtype=data.dtype)
  202. self._compare_other(ser, data, comparison_op, other)
  203. class BaseUnaryOpsTests(BaseOpsUtil):
  204. def test_invert(self, data):
  205. ser = pd.Series(data, name="name")
  206. try:
  207. # 10 is an arbitrary choice here, just avoid iterating over
  208. # the whole array to trim test runtime
  209. [~x for x in data[:10]]
  210. except TypeError:
  211. # scalars don't support invert -> we don't expect the vectorized
  212. # operation to succeed
  213. with pytest.raises(TypeError):
  214. ~ser
  215. with pytest.raises(TypeError):
  216. ~data
  217. else:
  218. # Note we do not reuse the pointwise result to construct expected
  219. # because python semantics for negating bools are weird see GH#54569
  220. result = ~ser
  221. expected = pd.Series(~data, name="name")
  222. tm.assert_series_equal(result, expected)
  223. @pytest.mark.parametrize("ufunc", [np.positive, np.negative, np.abs])
  224. def test_unary_ufunc_dunder_equivalence(self, data, ufunc):
  225. # the dunder __pos__ works if and only if np.positive works,
  226. # same for __neg__/np.negative and __abs__/np.abs
  227. attr = {np.positive: "__pos__", np.negative: "__neg__", np.abs: "__abs__"}[
  228. ufunc
  229. ]
  230. exc = None
  231. try:
  232. result = getattr(data, attr)()
  233. except Exception as err:
  234. exc = err
  235. # if __pos__ raised, then so should the ufunc
  236. with pytest.raises((type(exc), TypeError)):
  237. ufunc(data)
  238. else:
  239. alt = ufunc(data)
  240. tm.assert_extension_array_equal(result, alt)