test_setops.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. import numpy as np
  2. import pytest
  3. from pandas import (
  4. Index,
  5. IntervalIndex,
  6. Timestamp,
  7. interval_range,
  8. )
  9. import pandas._testing as tm
  10. def monotonic_index(start, end, dtype="int64", closed="right"):
  11. return IntervalIndex.from_breaks(np.arange(start, end, dtype=dtype), closed=closed)
  12. def empty_index(dtype="int64", closed="right"):
  13. return IntervalIndex(np.array([], dtype=dtype), closed=closed)
  14. class TestIntervalIndex:
  15. def test_union(self, closed, sort):
  16. index = monotonic_index(0, 11, closed=closed)
  17. other = monotonic_index(5, 13, closed=closed)
  18. expected = monotonic_index(0, 13, closed=closed)
  19. result = index[::-1].union(other, sort=sort)
  20. if sort in (None, True):
  21. tm.assert_index_equal(result, expected)
  22. else:
  23. tm.assert_index_equal(result.sort_values(), expected)
  24. result = other[::-1].union(index, sort=sort)
  25. if sort in (None, True):
  26. tm.assert_index_equal(result, expected)
  27. else:
  28. tm.assert_index_equal(result.sort_values(), expected)
  29. tm.assert_index_equal(index.union(index, sort=sort), index)
  30. tm.assert_index_equal(index.union(index[:1], sort=sort), index)
  31. def test_union_empty_result(self, closed, sort):
  32. # GH 19101: empty result, same dtype
  33. index = empty_index(dtype="int64", closed=closed)
  34. result = index.union(index, sort=sort)
  35. tm.assert_index_equal(result, index)
  36. # GH 19101: empty result, different numeric dtypes -> common dtype is f8
  37. other = empty_index(dtype="float64", closed=closed)
  38. result = index.union(other, sort=sort)
  39. expected = other
  40. tm.assert_index_equal(result, expected)
  41. other = index.union(index, sort=sort)
  42. tm.assert_index_equal(result, expected)
  43. other = empty_index(dtype="uint64", closed=closed)
  44. result = index.union(other, sort=sort)
  45. tm.assert_index_equal(result, expected)
  46. result = other.union(index, sort=sort)
  47. tm.assert_index_equal(result, expected)
  48. def test_intersection(self, closed, sort):
  49. index = monotonic_index(0, 11, closed=closed)
  50. other = monotonic_index(5, 13, closed=closed)
  51. expected = monotonic_index(5, 11, closed=closed)
  52. result = index[::-1].intersection(other, sort=sort)
  53. if sort in (None, True):
  54. tm.assert_index_equal(result, expected)
  55. else:
  56. tm.assert_index_equal(result.sort_values(), expected)
  57. result = other[::-1].intersection(index, sort=sort)
  58. if sort in (None, True):
  59. tm.assert_index_equal(result, expected)
  60. else:
  61. tm.assert_index_equal(result.sort_values(), expected)
  62. tm.assert_index_equal(index.intersection(index, sort=sort), index)
  63. # GH 26225: nested intervals
  64. index = IntervalIndex.from_tuples([(1, 2), (1, 3), (1, 4), (0, 2)])
  65. other = IntervalIndex.from_tuples([(1, 2), (1, 3)])
  66. expected = IntervalIndex.from_tuples([(1, 2), (1, 3)])
  67. result = index.intersection(other)
  68. tm.assert_index_equal(result, expected)
  69. # GH 26225
  70. index = IntervalIndex.from_tuples([(0, 3), (0, 2)])
  71. other = IntervalIndex.from_tuples([(0, 2), (1, 3)])
  72. expected = IntervalIndex.from_tuples([(0, 2)])
  73. result = index.intersection(other)
  74. tm.assert_index_equal(result, expected)
  75. # GH 26225: duplicate nan element
  76. index = IntervalIndex([np.nan, np.nan])
  77. other = IntervalIndex([np.nan])
  78. expected = IntervalIndex([np.nan])
  79. result = index.intersection(other)
  80. tm.assert_index_equal(result, expected)
  81. def test_intersection_empty_result(self, closed, sort):
  82. index = monotonic_index(0, 11, closed=closed)
  83. # GH 19101: empty result, same dtype
  84. other = monotonic_index(300, 314, closed=closed)
  85. expected = empty_index(dtype="int64", closed=closed)
  86. result = index.intersection(other, sort=sort)
  87. tm.assert_index_equal(result, expected)
  88. # GH 19101: empty result, different numeric dtypes -> common dtype is float64
  89. other = monotonic_index(300, 314, dtype="float64", closed=closed)
  90. result = index.intersection(other, sort=sort)
  91. expected = other[:0]
  92. tm.assert_index_equal(result, expected)
  93. other = monotonic_index(300, 314, dtype="uint64", closed=closed)
  94. result = index.intersection(other, sort=sort)
  95. tm.assert_index_equal(result, expected)
  96. def test_intersection_duplicates(self):
  97. # GH#38743
  98. index = IntervalIndex.from_tuples([(1, 2), (1, 2), (2, 3), (3, 4)])
  99. other = IntervalIndex.from_tuples([(1, 2), (2, 3)])
  100. expected = IntervalIndex.from_tuples([(1, 2), (2, 3)])
  101. result = index.intersection(other)
  102. tm.assert_index_equal(result, expected)
  103. def test_difference(self, closed, sort):
  104. index = IntervalIndex.from_arrays([1, 0, 3, 2], [1, 2, 3, 4], closed=closed)
  105. result = index.difference(index[:1], sort=sort)
  106. expected = index[1:]
  107. if sort is None:
  108. expected = expected.sort_values()
  109. tm.assert_index_equal(result, expected)
  110. # GH 19101: empty result, same dtype
  111. result = index.difference(index, sort=sort)
  112. expected = empty_index(dtype="int64", closed=closed)
  113. tm.assert_index_equal(result, expected)
  114. # GH 19101: empty result, different dtypes
  115. other = IntervalIndex.from_arrays(
  116. index.left.astype("float64"), index.right, closed=closed
  117. )
  118. result = index.difference(other, sort=sort)
  119. tm.assert_index_equal(result, expected)
  120. def test_symmetric_difference(self, closed, sort):
  121. index = monotonic_index(0, 11, closed=closed)
  122. result = index[1:].symmetric_difference(index[:-1], sort=sort)
  123. expected = IntervalIndex([index[0], index[-1]])
  124. if sort in (None, True):
  125. tm.assert_index_equal(result, expected)
  126. else:
  127. tm.assert_index_equal(result.sort_values(), expected)
  128. # GH 19101: empty result, same dtype
  129. result = index.symmetric_difference(index, sort=sort)
  130. expected = empty_index(dtype="int64", closed=closed)
  131. if sort in (None, True):
  132. tm.assert_index_equal(result, expected)
  133. else:
  134. tm.assert_index_equal(result.sort_values(), expected)
  135. # GH 19101: empty result, different dtypes
  136. other = IntervalIndex.from_arrays(
  137. index.left.astype("float64"), index.right, closed=closed
  138. )
  139. result = index.symmetric_difference(other, sort=sort)
  140. expected = empty_index(dtype="float64", closed=closed)
  141. tm.assert_index_equal(result, expected)
  142. @pytest.mark.filterwarnings("ignore:'<' not supported between:RuntimeWarning")
  143. @pytest.mark.parametrize(
  144. "op_name", ["union", "intersection", "difference", "symmetric_difference"]
  145. )
  146. def test_set_incompatible_types(self, closed, op_name, sort):
  147. index = monotonic_index(0, 11, closed=closed)
  148. set_op = getattr(index, op_name)
  149. # TODO: standardize return type of non-union setops type(self vs other)
  150. # non-IntervalIndex
  151. if op_name == "difference":
  152. expected = index
  153. else:
  154. expected = getattr(index.astype("O"), op_name)(Index([1, 2, 3]))
  155. result = set_op(Index([1, 2, 3]), sort=sort)
  156. tm.assert_index_equal(result, expected)
  157. # mixed closed -> cast to object
  158. for other_closed in {"right", "left", "both", "neither"} - {closed}:
  159. other = monotonic_index(0, 11, closed=other_closed)
  160. expected = getattr(index.astype(object), op_name)(other, sort=sort)
  161. if op_name == "difference":
  162. expected = index
  163. result = set_op(other, sort=sort)
  164. tm.assert_index_equal(result, expected)
  165. # GH 19016: incompatible dtypes -> cast to object
  166. other = interval_range(Timestamp("20180101"), periods=9, closed=closed)
  167. expected = getattr(index.astype(object), op_name)(other, sort=sort)
  168. if op_name == "difference":
  169. expected = index
  170. result = set_op(other, sort=sort)
  171. tm.assert_index_equal(result, expected)