test_arithmetic.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. from datetime import timedelta
  2. import numpy as np
  3. import pytest
  4. from pandas import (
  5. Interval,
  6. Timedelta,
  7. Timestamp,
  8. )
  9. import pandas._testing as tm
  10. class TestIntervalArithmetic:
  11. def test_interval_add(self, closed):
  12. interval = Interval(0, 1, closed=closed)
  13. expected = Interval(1, 2, closed=closed)
  14. result = interval + 1
  15. assert result == expected
  16. result = 1 + interval
  17. assert result == expected
  18. result = interval
  19. result += 1
  20. assert result == expected
  21. msg = r"unsupported operand type\(s\) for \+"
  22. with pytest.raises(TypeError, match=msg):
  23. interval + interval
  24. with pytest.raises(TypeError, match=msg):
  25. interval + "foo"
  26. def test_interval_sub(self, closed):
  27. interval = Interval(0, 1, closed=closed)
  28. expected = Interval(-1, 0, closed=closed)
  29. result = interval - 1
  30. assert result == expected
  31. result = interval
  32. result -= 1
  33. assert result == expected
  34. msg = r"unsupported operand type\(s\) for -"
  35. with pytest.raises(TypeError, match=msg):
  36. interval - interval
  37. with pytest.raises(TypeError, match=msg):
  38. interval - "foo"
  39. def test_interval_mult(self, closed):
  40. interval = Interval(0, 1, closed=closed)
  41. expected = Interval(0, 2, closed=closed)
  42. result = interval * 2
  43. assert result == expected
  44. result = 2 * interval
  45. assert result == expected
  46. result = interval
  47. result *= 2
  48. assert result == expected
  49. msg = r"unsupported operand type\(s\) for \*"
  50. with pytest.raises(TypeError, match=msg):
  51. interval * interval
  52. msg = r"can\'t multiply sequence by non-int"
  53. with pytest.raises(TypeError, match=msg):
  54. interval * "foo"
  55. def test_interval_div(self, closed):
  56. interval = Interval(0, 1, closed=closed)
  57. expected = Interval(0, 0.5, closed=closed)
  58. result = interval / 2.0
  59. assert result == expected
  60. result = interval
  61. result /= 2.0
  62. assert result == expected
  63. msg = r"unsupported operand type\(s\) for /"
  64. with pytest.raises(TypeError, match=msg):
  65. interval / interval
  66. with pytest.raises(TypeError, match=msg):
  67. interval / "foo"
  68. def test_interval_floordiv(self, closed):
  69. interval = Interval(1, 2, closed=closed)
  70. expected = Interval(0, 1, closed=closed)
  71. result = interval // 2
  72. assert result == expected
  73. result = interval
  74. result //= 2
  75. assert result == expected
  76. msg = r"unsupported operand type\(s\) for //"
  77. with pytest.raises(TypeError, match=msg):
  78. interval // interval
  79. with pytest.raises(TypeError, match=msg):
  80. interval // "foo"
  81. @pytest.mark.parametrize("method", ["__add__", "__sub__"])
  82. @pytest.mark.parametrize(
  83. "interval",
  84. [
  85. Interval(
  86. Timestamp("2017-01-01 00:00:00"), Timestamp("2018-01-01 00:00:00")
  87. ),
  88. Interval(Timedelta(days=7), Timedelta(days=14)),
  89. ],
  90. )
  91. @pytest.mark.parametrize(
  92. "delta", [Timedelta(days=7), timedelta(7), np.timedelta64(7, "D")]
  93. )
  94. def test_time_interval_add_subtract_timedelta(self, interval, delta, method):
  95. # https://github.com/pandas-dev/pandas/issues/32023
  96. result = getattr(interval, method)(delta)
  97. left = getattr(interval.left, method)(delta)
  98. right = getattr(interval.right, method)(delta)
  99. expected = Interval(left, right)
  100. assert result == expected
  101. @pytest.mark.parametrize("interval", [Interval(1, 2), Interval(1.0, 2.0)])
  102. @pytest.mark.parametrize(
  103. "delta", [Timedelta(days=7), timedelta(7), np.timedelta64(7, "D")]
  104. )
  105. def test_numeric_interval_add_timedelta_raises(self, interval, delta):
  106. # https://github.com/pandas-dev/pandas/issues/32023
  107. msg = "|".join(
  108. [
  109. "unsupported operand",
  110. "cannot use operands",
  111. "Only numeric, Timestamp and Timedelta endpoints are allowed",
  112. ]
  113. )
  114. with pytest.raises((TypeError, ValueError), match=msg):
  115. interval + delta
  116. with pytest.raises((TypeError, ValueError), match=msg):
  117. delta + interval
  118. @pytest.mark.parametrize("klass", [timedelta, np.timedelta64, Timedelta])
  119. def test_timedelta_add_timestamp_interval(self, klass):
  120. delta = klass(0)
  121. expected = Interval(Timestamp("2020-01-01"), Timestamp("2020-02-01"))
  122. result = delta + expected
  123. assert result == expected
  124. result = expected + delta
  125. assert result == expected
  126. class TestIntervalComparisons:
  127. def test_interval_equal(self):
  128. assert Interval(0, 1) == Interval(0, 1, closed="right")
  129. assert Interval(0, 1) != Interval(0, 1, closed="left")
  130. assert Interval(0, 1) != 0
  131. def test_interval_comparison(self):
  132. msg = (
  133. "'<' not supported between instances of "
  134. "'pandas._libs.interval.Interval' and 'int'"
  135. )
  136. with pytest.raises(TypeError, match=msg):
  137. Interval(0, 1) < 2
  138. assert Interval(0, 1) < Interval(1, 2)
  139. assert Interval(0, 1) < Interval(0, 2)
  140. assert Interval(0, 1) < Interval(0.5, 1.5)
  141. assert Interval(0, 1) <= Interval(0, 1)
  142. assert Interval(0, 1) > Interval(-1, 2)
  143. assert Interval(0, 1) >= Interval(0, 1)
  144. def test_equality_comparison_broadcasts_over_array(self):
  145. # https://github.com/pandas-dev/pandas/issues/35931
  146. interval = Interval(0, 1)
  147. arr = np.array([interval, interval])
  148. result = interval == arr
  149. expected = np.array([True, True])
  150. tm.assert_numpy_array_equal(result, expected)