test_interval.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. import numpy as np
  2. import pytest
  3. from pandas._libs import index as libindex
  4. import pandas as pd
  5. from pandas import (
  6. DataFrame,
  7. IntervalIndex,
  8. Series,
  9. )
  10. import pandas._testing as tm
  11. class TestIntervalIndex:
  12. @pytest.fixture
  13. def series_with_interval_index(self):
  14. return Series(np.arange(5), IntervalIndex.from_breaks(np.arange(6)))
  15. def test_getitem_with_scalar(self, series_with_interval_index, indexer_sl):
  16. ser = series_with_interval_index.copy()
  17. expected = ser.iloc[:3]
  18. tm.assert_series_equal(expected, indexer_sl(ser)[:3])
  19. tm.assert_series_equal(expected, indexer_sl(ser)[:2.5])
  20. tm.assert_series_equal(expected, indexer_sl(ser)[0.1:2.5])
  21. if indexer_sl is tm.loc:
  22. tm.assert_series_equal(expected, ser.loc[-1:3])
  23. expected = ser.iloc[1:4]
  24. tm.assert_series_equal(expected, indexer_sl(ser)[[1.5, 2.5, 3.5]])
  25. tm.assert_series_equal(expected, indexer_sl(ser)[[2, 3, 4]])
  26. tm.assert_series_equal(expected, indexer_sl(ser)[[1.5, 3, 4]])
  27. expected = ser.iloc[2:5]
  28. tm.assert_series_equal(expected, indexer_sl(ser)[ser >= 2])
  29. @pytest.mark.parametrize("direction", ["increasing", "decreasing"])
  30. def test_getitem_nonoverlapping_monotonic(self, direction, closed, indexer_sl):
  31. tpls = [(0, 1), (2, 3), (4, 5)]
  32. if direction == "decreasing":
  33. tpls = tpls[::-1]
  34. idx = IntervalIndex.from_tuples(tpls, closed=closed)
  35. ser = Series(list("abc"), idx)
  36. for key, expected in zip(idx.left, ser):
  37. if idx.closed_left:
  38. assert indexer_sl(ser)[key] == expected
  39. else:
  40. with pytest.raises(KeyError, match=str(key)):
  41. indexer_sl(ser)[key]
  42. for key, expected in zip(idx.right, ser):
  43. if idx.closed_right:
  44. assert indexer_sl(ser)[key] == expected
  45. else:
  46. with pytest.raises(KeyError, match=str(key)):
  47. indexer_sl(ser)[key]
  48. for key, expected in zip(idx.mid, ser):
  49. assert indexer_sl(ser)[key] == expected
  50. def test_getitem_non_matching(self, series_with_interval_index, indexer_sl):
  51. ser = series_with_interval_index.copy()
  52. # this is a departure from our current
  53. # indexing scheme, but simpler
  54. with pytest.raises(KeyError, match=r"\[-1\] not in index"):
  55. indexer_sl(ser)[[-1, 3, 4, 5]]
  56. with pytest.raises(KeyError, match=r"\[-1\] not in index"):
  57. indexer_sl(ser)[[-1, 3]]
  58. def test_loc_getitem_large_series(self, monkeypatch):
  59. size_cutoff = 20
  60. with monkeypatch.context():
  61. monkeypatch.setattr(libindex, "_SIZE_CUTOFF", size_cutoff)
  62. ser = Series(
  63. np.arange(size_cutoff),
  64. index=IntervalIndex.from_breaks(np.arange(size_cutoff + 1)),
  65. )
  66. result1 = ser.loc[:8]
  67. result2 = ser.loc[0:8]
  68. result3 = ser.loc[0:8:1]
  69. tm.assert_series_equal(result1, result2)
  70. tm.assert_series_equal(result1, result3)
  71. def test_loc_getitem_frame(self):
  72. # CategoricalIndex with IntervalIndex categories
  73. df = DataFrame({"A": range(10)})
  74. ser = pd.cut(df.A, 5)
  75. df["B"] = ser
  76. df = df.set_index("B")
  77. result = df.loc[4]
  78. expected = df.iloc[4:6]
  79. tm.assert_frame_equal(result, expected)
  80. with pytest.raises(KeyError, match="10"):
  81. df.loc[10]
  82. # single list-like
  83. result = df.loc[[4]]
  84. expected = df.iloc[4:6]
  85. tm.assert_frame_equal(result, expected)
  86. # non-unique
  87. result = df.loc[[4, 5]]
  88. expected = df.take([4, 5, 4, 5])
  89. tm.assert_frame_equal(result, expected)
  90. msg = (
  91. r"None of \[Index\(\[10\], dtype='object', name='B'\)\] "
  92. r"are in the \[index\]"
  93. )
  94. with pytest.raises(KeyError, match=msg):
  95. df.loc[[10]]
  96. # partial missing
  97. with pytest.raises(KeyError, match=r"\[10\] not in index"):
  98. df.loc[[10, 4]]
  99. def test_getitem_interval_with_nans(self, frame_or_series, indexer_sl):
  100. # GH#41831
  101. index = IntervalIndex([np.nan, np.nan])
  102. key = index[:-1]
  103. obj = frame_or_series(range(2), index=index)
  104. if frame_or_series is DataFrame and indexer_sl is tm.setitem:
  105. obj = obj.T
  106. result = indexer_sl(obj)[key]
  107. expected = obj
  108. tm.assert_equal(result, expected)
  109. def test_setitem_interval_with_slice(self):
  110. # GH#54722
  111. ii = IntervalIndex.from_breaks(range(4, 15))
  112. ser = Series(range(10), index=ii)
  113. orig = ser.copy()
  114. # This should be a no-op (used to raise)
  115. ser.loc[1:3] = 20
  116. tm.assert_series_equal(ser, orig)
  117. ser.loc[6:8] = 19
  118. orig.iloc[1:4] = 19
  119. tm.assert_series_equal(ser, orig)
  120. ser2 = Series(range(5), index=ii[::2])
  121. orig2 = ser2.copy()
  122. # this used to raise
  123. ser2.loc[6:8] = 22 # <- raises on main, sets on branch
  124. orig2.iloc[1] = 22
  125. tm.assert_series_equal(ser2, orig2)
  126. ser2.loc[5:7] = 21
  127. orig2.iloc[:2] = 21
  128. tm.assert_series_equal(ser2, orig2)
  129. class TestIntervalIndexInsideMultiIndex:
  130. def test_mi_intervalindex_slicing_with_scalar(self):
  131. # GH#27456
  132. ii = IntervalIndex.from_arrays(
  133. [0, 1, 10, 11, 0, 1, 10, 11], [1, 2, 11, 12, 1, 2, 11, 12], name="MP"
  134. )
  135. idx = pd.MultiIndex.from_arrays(
  136. [
  137. pd.Index(["FC", "FC", "FC", "FC", "OWNER", "OWNER", "OWNER", "OWNER"]),
  138. pd.Index(
  139. ["RID1", "RID1", "RID2", "RID2", "RID1", "RID1", "RID2", "RID2"]
  140. ),
  141. ii,
  142. ]
  143. )
  144. idx.names = ["Item", "RID", "MP"]
  145. df = DataFrame({"value": [1, 2, 3, 4, 5, 6, 7, 8]})
  146. df.index = idx
  147. query_df = DataFrame(
  148. {
  149. "Item": ["FC", "OWNER", "FC", "OWNER", "OWNER"],
  150. "RID": ["RID1", "RID1", "RID1", "RID2", "RID2"],
  151. "MP": [0.2, 1.5, 1.6, 11.1, 10.9],
  152. }
  153. )
  154. query_df = query_df.sort_index()
  155. idx = pd.MultiIndex.from_arrays([query_df.Item, query_df.RID, query_df.MP])
  156. query_df.index = idx
  157. result = df.value.loc[query_df.index]
  158. # the IntervalIndex level is indexed with floats, which map to
  159. # the intervals containing them. Matching the behavior we would get
  160. # with _only_ an IntervalIndex, we get an IntervalIndex level back.
  161. sliced_level = ii.take([0, 1, 1, 3, 2])
  162. expected_index = pd.MultiIndex.from_arrays(
  163. [idx.get_level_values(0), idx.get_level_values(1), sliced_level]
  164. )
  165. expected = Series([1, 6, 2, 8, 7], index=expected_index, name="value")
  166. tm.assert_series_equal(result, expected)
  167. @pytest.mark.parametrize(
  168. "base",
  169. [101, 1010],
  170. )
  171. def test_reindex_behavior_with_interval_index(self, base):
  172. # GH 51826
  173. ser = Series(
  174. range(base),
  175. index=IntervalIndex.from_arrays(range(base), range(1, base + 1)),
  176. )
  177. expected_result = Series([np.nan, 0], index=[np.nan, 1.0], dtype=float)
  178. result = ser.reindex(index=[np.nan, 1.0])
  179. tm.assert_series_equal(result, expected_result)