common.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. """
  2. Assertion helpers for arithmetic tests.
  3. """
  4. import numpy as np
  5. import pytest
  6. from pandas import (
  7. DataFrame,
  8. Index,
  9. Series,
  10. array,
  11. )
  12. import pandas._testing as tm
  13. from pandas.core.arrays import (
  14. BooleanArray,
  15. NumpyExtensionArray,
  16. )
  def assert_cannot_add(left, right, msg="cannot add"):
      """
      Assert that ``left`` and ``right`` cannot be added in either order.

      Both ``left + right`` and ``right + left`` must raise a ``TypeError``
      whose message matches ``msg``.

      Parameters
      ----------
      left : object
          First operand.
      right : object
          Second operand.
      msg : str, default "cannot add"
          Pattern the ``TypeError`` message is matched against.
      """
      # Check the forward order first, then the reflected order.
      for first, second in ((left, right), (right, left)):
          with pytest.raises(TypeError, match=msg):
              first + second
  def assert_invalid_addsub_type(left, right, msg=None):
      """
      Assert that ``left`` and ``right`` support neither addition nor
      subtraction, in either operand order.

      Each of ``left + right``, ``right + left``, ``left - right`` and
      ``right - left`` (checked in that order) must raise a ``TypeError``.

      Parameters
      ----------
      left : object
          First operand.
      right : object
          Second operand.
      msg : str or None, default None
          Pattern the ``TypeError`` message is matched against; ``None``
          places no constraint on the message.
      """
      orders = ((left, right), (right, left))

      # Addition in both orders.
      for lhs, rhs in orders:
          with pytest.raises(TypeError, match=msg):
              lhs + rhs

      # Subtraction in both orders.
      for lhs, rhs in orders:
          with pytest.raises(TypeError, match=msg):
              lhs - rhs
  def get_upcast_box(left, right, is_cmp: bool = False):
      """
      Choose the container used to build 'expected' for an arithmetic or
      comparison operation between ``left`` and ``right``.

      Precedence is DataFrame > Series > Index > ``tm.to_array``.

      Parameters
      ----------
      left : Any
          Left operand.
      right : Any
          Right operand.
      is_cmp : bool, default False
          Whether the operation is a comparison method.

      Returns
      -------
      callable
          One of ``DataFrame``, ``Series``, ``Index``, ``np.array`` or
          ``tm.to_array``.
      """
      operands = (left, right)

      if any(isinstance(obj, DataFrame) for obj in operands):
          return DataFrame

      if any(isinstance(obj, Series) for obj in operands):
          # Index does not defer for comparisons, so an Index on the left
          # of a comparison with a Series produces a plain ndarray.
          return np.array if (is_cmp and isinstance(left, Index)) else Series

      if any(isinstance(obj, Index) for obj in operands):
          return np.array if is_cmp else Index

      return tm.to_array
  def assert_invalid_comparison(left, right, box):
      """
      Assert that comparison operations with mismatched types behave correctly.

      Equality checks must succeed in both operand orders:

      * ``==`` returns all-False.
      * ``!=`` returns all-True.

      Ordering comparisons (``<``, ``<=``, ``>``, ``>=``) must raise
      ``TypeError`` in both operand orders.

      Parameters
      ----------
      left : np.ndarray, ExtensionArray, Index, or Series
      right : object
      box : {pd.DataFrame, pd.Series, pd.Index, pd.array, tm.to_array}
          Determines the container type of the expected equality results
          (``Index`` and ``pd.array`` are mapped to ``np.array``).
      """
      # Not for tznaive-tzaware comparison

      # Note: not quite the same as how we do this for tm.box_expected
      xbox = box if box not in [Index, array] else np.array

      def xbox2(x):
          # Eventually we'd like this to be tighter, but for now we'll
          # just exclude NumpyExtensionArray[bool]
          if isinstance(x, NumpyExtensionArray):
              return x._ndarray
          if isinstance(x, BooleanArray):
              # NB: we are assuming no pd.NAs for now
              return x.astype(bool)
          return x

      # rev_box: box for the reversed (right-op-left) equality checks.
      # Index does not defer to Series for comparisons (see get_upcast_box),
      # so when ``right`` is an Index and ``left`` a Series, ``right == left``
      # yields an ndarray rather than a Series.
      rev_box = xbox
      if isinstance(right, Index) and isinstance(left, Series):
          rev_box = np.array

      result = xbox2(left == right)
      expected = xbox(np.zeros(result.shape, dtype=np.bool_))
      tm.assert_equal(result, expected)

      result = xbox2(right == left)
      tm.assert_equal(result, rev_box(expected))

      result = xbox2(left != right)
      tm.assert_equal(result, ~expected)

      result = xbox2(right != left)
      tm.assert_equal(result, rev_box(~expected))

      # Any of these messages is acceptable for the ordering comparisons.
      msg = "|".join(
          [
              "Invalid comparison between",
              "Cannot compare type",
              "not supported between",
              "invalid type promotion",
              (
                  # GH#36706 npdev 1.20.0 2020-09-28
                  r"The DTypes <class 'numpy.dtype\[datetime64\]'> and "
                  r"<class 'numpy.dtype\[int64\]'> do not have a common DType. "
                  "For example they cannot be stored in a single array unless the "
                  "dtype is `object`."
              ),
          ]
      )
      with pytest.raises(TypeError, match=msg):
          left < right
      with pytest.raises(TypeError, match=msg):
          left <= right
      with pytest.raises(TypeError, match=msg):
          left > right
      with pytest.raises(TypeError, match=msg):
          left >= right
      with pytest.raises(TypeError, match=msg):
          right < left
      with pytest.raises(TypeError, match=msg):
          right <= left
      with pytest.raises(TypeError, match=msg):
          right > left
      with pytest.raises(TypeError, match=msg):
          right >= left
  135. right >= left