test_interval.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. import numpy as np
  2. import pytest
  3. import pandas as pd
  4. from pandas import (
  5. Index,
  6. Interval,
  7. IntervalIndex,
  8. Timedelta,
  9. Timestamp,
  10. date_range,
  11. timedelta_range,
  12. )
  13. import pandas._testing as tm
  14. from pandas.core.arrays import IntervalArray
  @pytest.fixture(
      params=[
          (Index([0, 2, 4]), Index([1, 3, 5])),
          (Index([0.0, 1.0, 2.0]), Index([1.0, 2.0, 3.0])),
          (
              timedelta_range("0 days", periods=3),
              timedelta_range("1 day", periods=3),
          ),
          (
              date_range("20170101", periods=3),
              date_range("20170102", periods=3),
          ),
          (
              date_range("20170101", periods=3, tz="US/Eastern"),
              date_range("20170102", periods=3, tz="US/Eastern"),
          ),
      ],
      # Label each case by the dtype of its left-hand index.
      ids=lambda pair: str(pair[0].dtype),
  )
  def left_right_dtypes(request):
      """
      Fixture providing a ``(left, right)`` pair of equal-length indexes.

      Both sides share one dtype (integer, float, timedelta64, naive
      datetime64 or tz-aware datetime64) and every ``left[i]`` lies below
      ``right[i]``, so the pair can be passed to
      ``IntervalArray.from_arrays`` directly.
      """
      pair = request.param
      return pair
  33. class TestAttributes:
  34. @pytest.mark.parametrize(
  35. "left, right",
  36. [
  37. (0, 1),
  38. (Timedelta("0 days"), Timedelta("1 day")),
  39. (Timestamp("2018-01-01"), Timestamp("2018-01-02")),
  40. (
  41. Timestamp("2018-01-01", tz="US/Eastern"),
  42. Timestamp("2018-01-02", tz="US/Eastern"),
  43. ),
  44. ],
  45. )
  46. @pytest.mark.parametrize("constructor", [IntervalArray, IntervalIndex])
  47. def test_is_empty(self, constructor, left, right, closed):
  48. # GH27219
  49. tuples = [(left, left), (left, right), np.nan]
  50. expected = np.array([closed != "both", False, False])
  51. result = constructor.from_tuples(tuples, closed=closed).is_empty
  52. tm.assert_numpy_array_equal(result, expected)
  53. class TestMethods:
  54. @pytest.mark.parametrize("new_closed", ["left", "right", "both", "neither"])
  55. def test_set_closed(self, closed, new_closed):
  56. # GH 21670
  57. array = IntervalArray.from_breaks(range(10), closed=closed)
  58. result = array.set_closed(new_closed)
  59. expected = IntervalArray.from_breaks(range(10), closed=new_closed)
  60. tm.assert_extension_array_equal(result, expected)
  61. @pytest.mark.parametrize(
  62. "other",
  63. [
  64. Interval(0, 1, closed="right"),
  65. IntervalArray.from_breaks([1, 2, 3, 4], closed="right"),
  66. ],
  67. )
  68. def test_where_raises(self, other):
  69. # GH#45768 The IntervalArray methods raises; the Series method coerces
  70. ser = pd.Series(IntervalArray.from_breaks([1, 2, 3, 4], closed="left"))
  71. mask = np.array([True, False, True])
  72. match = "'value.closed' is 'right', expected 'left'."
  73. with pytest.raises(ValueError, match=match):
  74. ser.array._where(mask, other)
  75. res = ser.where(mask, other=other)
  76. expected = ser.astype(object).where(mask, other)
  77. tm.assert_series_equal(res, expected)
  78. def test_shift(self):
  79. # https://github.com/pandas-dev/pandas/issues/31495, GH#22428, GH#31502
  80. a = IntervalArray.from_breaks([1, 2, 3])
  81. result = a.shift()
  82. # int -> float
  83. expected = IntervalArray.from_tuples([(np.nan, np.nan), (1.0, 2.0)])
  84. tm.assert_interval_array_equal(result, expected)
  85. msg = "can only insert Interval objects and NA into an IntervalArray"
  86. with pytest.raises(TypeError, match=msg):
  87. a.shift(1, fill_value=pd.NaT)
  88. def test_shift_datetime(self):
  89. # GH#31502, GH#31504
  90. a = IntervalArray.from_breaks(date_range("2000", periods=4))
  91. result = a.shift(2)
  92. expected = a.take([-1, -1, 0], allow_fill=True)
  93. tm.assert_interval_array_equal(result, expected)
  94. result = a.shift(-1)
  95. expected = a.take([1, 2, -1], allow_fill=True)
  96. tm.assert_interval_array_equal(result, expected)
  97. msg = "can only insert Interval objects and NA into an IntervalArray"
  98. with pytest.raises(TypeError, match=msg):
  99. a.shift(1, fill_value=np.timedelta64("NaT", "ns"))
  100. class TestSetitem:
  101. def test_set_na(self, left_right_dtypes):
  102. left, right = left_right_dtypes
  103. left = left.copy(deep=True)
  104. right = right.copy(deep=True)
  105. result = IntervalArray.from_arrays(left, right)
  106. if result.dtype.subtype.kind not in ["m", "M"]:
  107. msg = "'value' should be an interval type, got <.*NaTType'> instead."
  108. with pytest.raises(TypeError, match=msg):
  109. result[0] = pd.NaT
  110. if result.dtype.subtype.kind in ["i", "u"]:
  111. msg = "Cannot set float NaN to integer-backed IntervalArray"
  112. # GH#45484 TypeError, not ValueError, matches what we get with
  113. # non-NA un-holdable value.
  114. with pytest.raises(TypeError, match=msg):
  115. result[0] = np.nan
  116. return
  117. result[0] = np.nan
  118. expected_left = Index([left._na_value] + list(left[1:]))
  119. expected_right = Index([right._na_value] + list(right[1:]))
  120. expected = IntervalArray.from_arrays(expected_left, expected_right)
  121. tm.assert_extension_array_equal(result, expected)
  122. def test_setitem_mismatched_closed(self):
  123. arr = IntervalArray.from_breaks(range(4))
  124. orig = arr.copy()
  125. other = arr.set_closed("both")
  126. msg = "'value.closed' is 'both', expected 'right'"
  127. with pytest.raises(ValueError, match=msg):
  128. arr[0] = other[0]
  129. with pytest.raises(ValueError, match=msg):
  130. arr[:1] = other[:1]
  131. with pytest.raises(ValueError, match=msg):
  132. arr[:0] = other[:0]
  133. with pytest.raises(ValueError, match=msg):
  134. arr[:] = other[::-1]
  135. with pytest.raises(ValueError, match=msg):
  136. arr[:] = list(other[::-1])
  137. with pytest.raises(ValueError, match=msg):
  138. arr[:] = other[::-1].astype(object)
  139. with pytest.raises(ValueError, match=msg):
  140. arr[:] = other[::-1].astype("category")
  141. # empty list should be no-op
  142. arr[:0] = []
  143. tm.assert_interval_array_equal(arr, orig)
  144. class TestReductions:
  145. def test_min_max_invalid_axis(self, left_right_dtypes):
  146. left, right = left_right_dtypes
  147. left = left.copy(deep=True)
  148. right = right.copy(deep=True)
  149. arr = IntervalArray.from_arrays(left, right)
  150. msg = "`axis` must be fewer than the number of dimensions"
  151. for axis in [-2, 1]:
  152. with pytest.raises(ValueError, match=msg):
  153. arr.min(axis=axis)
  154. with pytest.raises(ValueError, match=msg):
  155. arr.max(axis=axis)
  156. msg = "'>=' not supported between"
  157. with pytest.raises(TypeError, match=msg):
  158. arr.min(axis="foo")
  159. with pytest.raises(TypeError, match=msg):
  160. arr.max(axis="foo")
  161. def test_min_max(self, left_right_dtypes, index_or_series_or_array):
  162. # GH#44746
  163. left, right = left_right_dtypes
  164. left = left.copy(deep=True)
  165. right = right.copy(deep=True)
  166. arr = IntervalArray.from_arrays(left, right)
  167. # The expected results below are only valid if monotonic
  168. assert left.is_monotonic_increasing
  169. assert Index(arr).is_monotonic_increasing
  170. MIN = arr[0]
  171. MAX = arr[-1]
  172. indexer = np.arange(len(arr))
  173. np.random.default_rng(2).shuffle(indexer)
  174. arr = arr.take(indexer)
  175. arr_na = arr.insert(2, np.nan)
  176. arr = index_or_series_or_array(arr)
  177. arr_na = index_or_series_or_array(arr_na)
  178. for skipna in [True, False]:
  179. res = arr.min(skipna=skipna)
  180. assert res == MIN
  181. assert type(res) == type(MIN)
  182. res = arr.max(skipna=skipna)
  183. assert res == MAX
  184. assert type(res) == type(MAX)
  185. res = arr_na.min(skipna=False)
  186. assert np.isnan(res)
  187. res = arr_na.max(skipna=False)
  188. assert np.isnan(res)
  189. res = arr_na.min(skipna=True)
  190. assert res == MIN
  191. assert type(res) == type(MIN)
  192. res = arr_na.max(skipna=True)
  193. assert res == MAX
  194. assert type(res) == type(MAX)