test_setops.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. from datetime import (
  2. datetime,
  3. timedelta,
  4. )
  5. import numpy as np
  6. import pytest
  7. import pandas._testing as tm
  8. from pandas.core.indexes.api import (
  9. Index,
  10. RangeIndex,
  11. )
  12. @pytest.fixture
  13. def index_large():
  14. # large values used in TestUInt64Index where no compat needed with int64/float64
  15. large = [2**63, 2**63 + 10, 2**63 + 15, 2**63 + 20, 2**63 + 25]
  16. return Index(large, dtype=np.uint64)
  17. class TestSetOps:
  18. @pytest.mark.parametrize("dtype", ["f8", "u8", "i8"])
  19. def test_union_non_numeric(self, dtype):
  20. # corner case, non-numeric
  21. index = Index(np.arange(5, dtype=dtype), dtype=dtype)
  22. assert index.dtype == dtype
  23. other = Index([datetime.now() + timedelta(i) for i in range(4)], dtype=object)
  24. result = index.union(other)
  25. expected = Index(np.concatenate((index, other)))
  26. tm.assert_index_equal(result, expected)
  27. result = other.union(index)
  28. expected = Index(np.concatenate((other, index)))
  29. tm.assert_index_equal(result, expected)
  30. def test_intersection(self):
  31. index = Index(range(5), dtype=np.int64)
  32. other = Index([1, 2, 3, 4, 5])
  33. result = index.intersection(other)
  34. expected = Index(np.sort(np.intersect1d(index.values, other.values)))
  35. tm.assert_index_equal(result, expected)
  36. result = other.intersection(index)
  37. expected = Index(
  38. np.sort(np.asarray(np.intersect1d(index.values, other.values)))
  39. )
  40. tm.assert_index_equal(result, expected)
  41. @pytest.mark.parametrize("dtype", ["int64", "uint64"])
  42. def test_int_float_union_dtype(self, dtype):
  43. # https://github.com/pandas-dev/pandas/issues/26778
  44. # [u]int | float -> float
  45. index = Index([0, 2, 3], dtype=dtype)
  46. other = Index([0.5, 1.5], dtype=np.float64)
  47. expected = Index([0.0, 0.5, 1.5, 2.0, 3.0], dtype=np.float64)
  48. result = index.union(other)
  49. tm.assert_index_equal(result, expected)
  50. result = other.union(index)
  51. tm.assert_index_equal(result, expected)
  52. def test_range_float_union_dtype(self):
  53. # https://github.com/pandas-dev/pandas/issues/26778
  54. index = RangeIndex(start=0, stop=3)
  55. other = Index([0.5, 1.5], dtype=np.float64)
  56. result = index.union(other)
  57. expected = Index([0.0, 0.5, 1, 1.5, 2.0], dtype=np.float64)
  58. tm.assert_index_equal(result, expected)
  59. result = other.union(index)
  60. tm.assert_index_equal(result, expected)
  61. def test_range_uint64_union_dtype(self):
  62. # https://github.com/pandas-dev/pandas/issues/26778
  63. index = RangeIndex(start=0, stop=3)
  64. other = Index([0, 10], dtype=np.uint64)
  65. result = index.union(other)
  66. expected = Index([0, 1, 2, 10], dtype=object)
  67. tm.assert_index_equal(result, expected)
  68. result = other.union(index)
  69. tm.assert_index_equal(result, expected)
  70. def test_float64_index_difference(self):
  71. # https://github.com/pandas-dev/pandas/issues/35217
  72. float_index = Index([1.0, 2, 3])
  73. string_index = Index(["1", "2", "3"])
  74. result = float_index.difference(string_index)
  75. tm.assert_index_equal(result, float_index)
  76. result = string_index.difference(float_index)
  77. tm.assert_index_equal(result, string_index)
  78. def test_intersection_uint64_outside_int64_range(self, index_large):
  79. other = Index([2**63, 2**63 + 5, 2**63 + 10, 2**63 + 15, 2**63 + 20])
  80. result = index_large.intersection(other)
  81. expected = Index(np.sort(np.intersect1d(index_large.values, other.values)))
  82. tm.assert_index_equal(result, expected)
  83. result = other.intersection(index_large)
  84. expected = Index(
  85. np.sort(np.asarray(np.intersect1d(index_large.values, other.values)))
  86. )
  87. tm.assert_index_equal(result, expected)
  88. @pytest.mark.parametrize(
  89. "index2,keeps_name",
  90. [
  91. (Index([4, 7, 6, 5, 3], name="index"), True),
  92. (Index([4, 7, 6, 5, 3], name="other"), False),
  93. ],
  94. )
  95. def test_intersection_monotonic(self, index2, keeps_name, sort):
  96. index1 = Index([5, 3, 2, 4, 1], name="index")
  97. expected = Index([5, 3, 4])
  98. if keeps_name:
  99. expected.name = "index"
  100. result = index1.intersection(index2, sort=sort)
  101. if sort is None:
  102. expected = expected.sort_values()
  103. tm.assert_index_equal(result, expected)
  104. def test_symmetric_difference(self, sort):
  105. # smoke
  106. index1 = Index([5, 2, 3, 4], name="index1")
  107. index2 = Index([2, 3, 4, 1])
  108. result = index1.symmetric_difference(index2, sort=sort)
  109. expected = Index([5, 1])
  110. if sort is not None:
  111. tm.assert_index_equal(result, expected)
  112. else:
  113. tm.assert_index_equal(result, expected.sort_values())
  114. assert result.name is None
  115. if sort is None:
  116. expected = expected.sort_values()
  117. tm.assert_index_equal(result, expected)
  118. class TestSetOpsSort:
  119. @pytest.mark.parametrize("slice_", [slice(None), slice(0)])
  120. def test_union_sort_other_special(self, slice_):
  121. # https://github.com/pandas-dev/pandas/issues/24959
  122. idx = Index([1, 0, 2])
  123. # default, sort=None
  124. other = idx[slice_]
  125. tm.assert_index_equal(idx.union(other), idx)
  126. tm.assert_index_equal(other.union(idx), idx)
  127. # sort=False
  128. tm.assert_index_equal(idx.union(other, sort=False), idx)
  129. @pytest.mark.parametrize("slice_", [slice(None), slice(0)])
  130. def test_union_sort_special_true(self, slice_):
  131. idx = Index([1, 0, 2])
  132. # default, sort=None
  133. other = idx[slice_]
  134. result = idx.union(other, sort=True)
  135. expected = Index([0, 1, 2])
  136. tm.assert_index_equal(result, expected)