# test_setops.py - tests for MultiIndex set operations (union, intersection, difference, symmetric_difference)
  1. import numpy as np
  2. import pytest
  3. import pandas as pd
  4. from pandas import (
  5. CategoricalIndex,
  6. DataFrame,
  7. Index,
  8. IntervalIndex,
  9. MultiIndex,
  10. Series,
  11. )
  12. import pandas._testing as tm
  13. from pandas.api.types import (
  14. is_float_dtype,
  15. is_unsigned_integer_dtype,
  16. )
  17. @pytest.mark.parametrize("case", [0.5, "xxx"])
  18. @pytest.mark.parametrize(
  19. "method", ["intersection", "union", "difference", "symmetric_difference"]
  20. )
  21. def test_set_ops_error_cases(idx, case, sort, method):
  22. # non-iterable input
  23. msg = "Input must be Index or array-like"
  24. with pytest.raises(TypeError, match=msg):
  25. getattr(idx, method)(case, sort=sort)
  26. @pytest.mark.parametrize("klass", [MultiIndex, np.array, Series, list])
  27. def test_intersection_base(idx, sort, klass):
  28. first = idx[2::-1] # first 3 elements reversed
  29. second = idx[:5]
  30. if klass is not MultiIndex:
  31. second = klass(second.values)
  32. intersect = first.intersection(second, sort=sort)
  33. if sort is None:
  34. expected = first.sort_values()
  35. else:
  36. expected = first
  37. tm.assert_index_equal(intersect, expected)
  38. msg = "other must be a MultiIndex or a list of tuples"
  39. with pytest.raises(TypeError, match=msg):
  40. first.intersection([1, 2, 3], sort=sort)
  41. @pytest.mark.arm_slow
  42. @pytest.mark.parametrize("klass", [MultiIndex, np.array, Series, list])
  43. def test_union_base(idx, sort, klass):
  44. first = idx[::-1]
  45. second = idx[:5]
  46. if klass is not MultiIndex:
  47. second = klass(second.values)
  48. union = first.union(second, sort=sort)
  49. if sort is None:
  50. expected = first.sort_values()
  51. else:
  52. expected = first
  53. tm.assert_index_equal(union, expected)
  54. msg = "other must be a MultiIndex or a list of tuples"
  55. with pytest.raises(TypeError, match=msg):
  56. first.union([1, 2, 3], sort=sort)
  57. def test_difference_base(idx, sort):
  58. second = idx[4:]
  59. answer = idx[:4]
  60. result = idx.difference(second, sort=sort)
  61. if sort is None:
  62. answer = answer.sort_values()
  63. assert result.equals(answer)
  64. tm.assert_index_equal(result, answer)
  65. # GH 10149
  66. cases = [klass(second.values) for klass in [np.array, Series, list]]
  67. for case in cases:
  68. result = idx.difference(case, sort=sort)
  69. tm.assert_index_equal(result, answer)
  70. msg = "other must be a MultiIndex or a list of tuples"
  71. with pytest.raises(TypeError, match=msg):
  72. idx.difference([1, 2, 3], sort=sort)
  73. def test_symmetric_difference(idx, sort):
  74. first = idx[1:]
  75. second = idx[:-1]
  76. answer = idx[[-1, 0]]
  77. result = first.symmetric_difference(second, sort=sort)
  78. if sort is None:
  79. answer = answer.sort_values()
  80. tm.assert_index_equal(result, answer)
  81. # GH 10149
  82. cases = [klass(second.values) for klass in [np.array, Series, list]]
  83. for case in cases:
  84. result = first.symmetric_difference(case, sort=sort)
  85. tm.assert_index_equal(result, answer)
  86. msg = "other must be a MultiIndex or a list of tuples"
  87. with pytest.raises(TypeError, match=msg):
  88. first.symmetric_difference([1, 2, 3], sort=sort)
  89. def test_multiindex_symmetric_difference():
  90. # GH 13490
  91. idx = MultiIndex.from_product([["a", "b"], ["A", "B"]], names=["a", "b"])
  92. result = idx.symmetric_difference(idx)
  93. assert result.names == idx.names
  94. idx2 = idx.copy().rename(["A", "B"])
  95. result = idx.symmetric_difference(idx2)
  96. assert result.names == [None, None]
  97. def test_empty(idx):
  98. # GH 15270
  99. assert not idx.empty
  100. assert idx[:0].empty
  101. def test_difference(idx, sort):
  102. first = idx
  103. result = first.difference(idx[-3:], sort=sort)
  104. vals = idx[:-3].values
  105. if sort is None:
  106. vals = sorted(vals)
  107. expected = MultiIndex.from_tuples(vals, sortorder=0, names=idx.names)
  108. assert isinstance(result, MultiIndex)
  109. assert result.equals(expected)
  110. assert result.names == idx.names
  111. tm.assert_index_equal(result, expected)
  112. # empty difference: reflexive
  113. result = idx.difference(idx, sort=sort)
  114. expected = idx[:0]
  115. assert result.equals(expected)
  116. assert result.names == idx.names
  117. # empty difference: superset
  118. result = idx[-3:].difference(idx, sort=sort)
  119. expected = idx[:0]
  120. assert result.equals(expected)
  121. assert result.names == idx.names
  122. # empty difference: degenerate
  123. result = idx[:0].difference(idx, sort=sort)
  124. expected = idx[:0]
  125. assert result.equals(expected)
  126. assert result.names == idx.names
  127. # names not the same
  128. chunklet = idx[-3:]
  129. chunklet.names = ["foo", "baz"]
  130. result = first.difference(chunklet, sort=sort)
  131. assert result.names == (None, None)
  132. # empty, but non-equal
  133. result = idx.difference(idx.sortlevel(1)[0], sort=sort)
  134. assert len(result) == 0
  135. # raise Exception called with non-MultiIndex
  136. result = first.difference(first.values, sort=sort)
  137. assert result.equals(first[:0])
  138. # name from empty array
  139. result = first.difference([], sort=sort)
  140. assert first.equals(result)
  141. assert first.names == result.names
  142. # name from non-empty array
  143. result = first.difference([("foo", "one")], sort=sort)
  144. expected = MultiIndex.from_tuples(
  145. [("bar", "one"), ("baz", "two"), ("foo", "two"), ("qux", "one"), ("qux", "two")]
  146. )
  147. expected.names = first.names
  148. assert first.names == result.names
  149. msg = "other must be a MultiIndex or a list of tuples"
  150. with pytest.raises(TypeError, match=msg):
  151. first.difference([1, 2, 3, 4, 5], sort=sort)
  152. def test_difference_sort_special():
  153. # GH-24959
  154. idx = MultiIndex.from_product([[1, 0], ["a", "b"]])
  155. # sort=None, the default
  156. result = idx.difference([])
  157. tm.assert_index_equal(result, idx)
  158. def test_difference_sort_special_true():
  159. idx = MultiIndex.from_product([[1, 0], ["a", "b"]])
  160. result = idx.difference([], sort=True)
  161. expected = MultiIndex.from_product([[0, 1], ["a", "b"]])
  162. tm.assert_index_equal(result, expected)
  163. def test_difference_sort_incomparable():
  164. # GH-24959
  165. idx = MultiIndex.from_product([[1, pd.Timestamp("2000"), 2], ["a", "b"]])
  166. other = MultiIndex.from_product([[3, pd.Timestamp("2000"), 4], ["c", "d"]])
  167. # sort=None, the default
  168. msg = "sort order is undefined for incomparable objects"
  169. with tm.assert_produces_warning(RuntimeWarning, match=msg):
  170. result = idx.difference(other)
  171. tm.assert_index_equal(result, idx)
  172. # sort=False
  173. result = idx.difference(other, sort=False)
  174. tm.assert_index_equal(result, idx)
  175. def test_difference_sort_incomparable_true():
  176. idx = MultiIndex.from_product([[1, pd.Timestamp("2000"), 2], ["a", "b"]])
  177. other = MultiIndex.from_product([[3, pd.Timestamp("2000"), 4], ["c", "d"]])
  178. # TODO: this is raising in constructing a Categorical when calling
  179. # algos.safe_sort. Should we catch and re-raise with a better message?
  180. msg = "'values' is not ordered, please explicitly specify the categories order "
  181. with pytest.raises(TypeError, match=msg):
  182. idx.difference(other, sort=True)
  183. def test_union(idx, sort):
  184. piece1 = idx[:5][::-1]
  185. piece2 = idx[3:]
  186. the_union = piece1.union(piece2, sort=sort)
  187. if sort in (None, False):
  188. tm.assert_index_equal(the_union.sort_values(), idx.sort_values())
  189. else:
  190. tm.assert_index_equal(the_union, idx)
  191. # corner case, pass self or empty thing:
  192. the_union = idx.union(idx, sort=sort)
  193. tm.assert_index_equal(the_union, idx)
  194. the_union = idx.union(idx[:0], sort=sort)
  195. tm.assert_index_equal(the_union, idx)
  196. tuples = idx.values
  197. result = idx[:4].union(tuples[4:], sort=sort)
  198. if sort is None:
  199. tm.assert_index_equal(result.sort_values(), idx.sort_values())
  200. else:
  201. assert result.equals(idx)
  202. def test_union_with_regular_index(idx, using_infer_string):
  203. other = Index(["A", "B", "C"])
  204. result = other.union(idx)
  205. assert ("foo", "one") in result
  206. assert "B" in result
  207. if using_infer_string:
  208. with pytest.raises(NotImplementedError, match="Can only union"):
  209. idx.union(other)
  210. else:
  211. msg = "The values in the array are unorderable"
  212. with tm.assert_produces_warning(RuntimeWarning, match=msg):
  213. result2 = idx.union(other)
  214. # This is more consistent now, if sorting fails then we don't sort at all
  215. # in the MultiIndex case.
  216. assert not result.equals(result2)
  217. def test_intersection(idx, sort):
  218. piece1 = idx[:5][::-1]
  219. piece2 = idx[3:]
  220. the_int = piece1.intersection(piece2, sort=sort)
  221. if sort in (None, True):
  222. tm.assert_index_equal(the_int, idx[3:5])
  223. else:
  224. tm.assert_index_equal(the_int.sort_values(), idx[3:5])
  225. # corner case, pass self
  226. the_int = idx.intersection(idx, sort=sort)
  227. tm.assert_index_equal(the_int, idx)
  228. # empty intersection: disjoint
  229. empty = idx[:2].intersection(idx[2:], sort=sort)
  230. expected = idx[:0]
  231. assert empty.equals(expected)
  232. tuples = idx.values
  233. result = idx.intersection(tuples)
  234. assert result.equals(idx)
  235. @pytest.mark.parametrize(
  236. "method", ["intersection", "union", "difference", "symmetric_difference"]
  237. )
  238. def test_setop_with_categorical(idx, sort, method):
  239. other = idx.to_flat_index().astype("category")
  240. res_names = [None] * idx.nlevels
  241. result = getattr(idx, method)(other, sort=sort)
  242. expected = getattr(idx, method)(idx, sort=sort).rename(res_names)
  243. tm.assert_index_equal(result, expected)
  244. result = getattr(idx, method)(other[:5], sort=sort)
  245. expected = getattr(idx, method)(idx[:5], sort=sort).rename(res_names)
  246. tm.assert_index_equal(result, expected)
  247. def test_intersection_non_object(idx, sort):
  248. other = Index(range(3), name="foo")
  249. result = idx.intersection(other, sort=sort)
  250. expected = MultiIndex(levels=idx.levels, codes=[[]] * idx.nlevels, names=None)
  251. tm.assert_index_equal(result, expected, exact=True)
  252. # if we pass a length-0 ndarray (i.e. no name, we retain our idx.name)
  253. result = idx.intersection(np.asarray(other)[:0], sort=sort)
  254. expected = MultiIndex(levels=idx.levels, codes=[[]] * idx.nlevels, names=idx.names)
  255. tm.assert_index_equal(result, expected, exact=True)
  256. msg = "other must be a MultiIndex or a list of tuples"
  257. with pytest.raises(TypeError, match=msg):
  258. # With non-zero length non-index, we try and fail to convert to tuples
  259. idx.intersection(np.asarray(other), sort=sort)
  260. def test_intersect_equal_sort():
  261. # GH-24959
  262. idx = MultiIndex.from_product([[1, 0], ["a", "b"]])
  263. tm.assert_index_equal(idx.intersection(idx, sort=False), idx)
  264. tm.assert_index_equal(idx.intersection(idx, sort=None), idx)
  265. def test_intersect_equal_sort_true():
  266. idx = MultiIndex.from_product([[1, 0], ["a", "b"]])
  267. expected = MultiIndex.from_product([[0, 1], ["a", "b"]])
  268. result = idx.intersection(idx, sort=True)
  269. tm.assert_index_equal(result, expected)
  270. @pytest.mark.parametrize("slice_", [slice(None), slice(0)])
  271. def test_union_sort_other_empty(slice_):
  272. # https://github.com/pandas-dev/pandas/issues/24959
  273. idx = MultiIndex.from_product([[1, 0], ["a", "b"]])
  274. # default, sort=None
  275. other = idx[slice_]
  276. tm.assert_index_equal(idx.union(other), idx)
  277. tm.assert_index_equal(other.union(idx), idx)
  278. # sort=False
  279. tm.assert_index_equal(idx.union(other, sort=False), idx)
  280. def test_union_sort_other_empty_sort():
  281. idx = MultiIndex.from_product([[1, 0], ["a", "b"]])
  282. other = idx[:0]
  283. result = idx.union(other, sort=True)
  284. expected = MultiIndex.from_product([[0, 1], ["a", "b"]])
  285. tm.assert_index_equal(result, expected)
  286. def test_union_sort_other_incomparable():
  287. # https://github.com/pandas-dev/pandas/issues/24959
  288. idx = MultiIndex.from_product([[1, pd.Timestamp("2000")], ["a", "b"]])
  289. # default, sort=None
  290. with tm.assert_produces_warning(RuntimeWarning):
  291. result = idx.union(idx[:1])
  292. tm.assert_index_equal(result, idx)
  293. # sort=False
  294. result = idx.union(idx[:1], sort=False)
  295. tm.assert_index_equal(result, idx)
  296. def test_union_sort_other_incomparable_sort():
  297. idx = MultiIndex.from_product([[1, pd.Timestamp("2000")], ["a", "b"]])
  298. msg = "'<' not supported between instances of 'Timestamp' and 'int'"
  299. with pytest.raises(TypeError, match=msg):
  300. idx.union(idx[:1], sort=True)
  301. def test_union_non_object_dtype_raises():
  302. # GH#32646 raise NotImplementedError instead of less-informative error
  303. mi = MultiIndex.from_product([["a", "b"], [1, 2]])
  304. idx = mi.levels[1]
  305. msg = "Can only union MultiIndex with MultiIndex or Index of tuples"
  306. with pytest.raises(NotImplementedError, match=msg):
  307. mi.union(idx)
  308. def test_union_empty_self_different_names():
  309. # GH#38423
  310. mi = MultiIndex.from_arrays([[]])
  311. mi2 = MultiIndex.from_arrays([[1, 2], [3, 4]], names=["a", "b"])
  312. result = mi.union(mi2)
  313. expected = MultiIndex.from_arrays([[1, 2], [3, 4]])
  314. tm.assert_index_equal(result, expected)
  315. def test_union_multiindex_empty_rangeindex():
  316. # GH#41234
  317. mi = MultiIndex.from_arrays([[1, 2], [3, 4]], names=["a", "b"])
  318. ri = pd.RangeIndex(0)
  319. result_left = mi.union(ri)
  320. tm.assert_index_equal(mi, result_left, check_names=False)
  321. result_right = ri.union(mi)
  322. tm.assert_index_equal(mi, result_right, check_names=False)
  323. @pytest.mark.parametrize(
  324. "method", ["union", "intersection", "difference", "symmetric_difference"]
  325. )
  326. def test_setops_sort_validation(method):
  327. idx1 = MultiIndex.from_product([["a", "b"], [1, 2]])
  328. idx2 = MultiIndex.from_product([["b", "c"], [1, 2]])
  329. with pytest.raises(ValueError, match="The 'sort' keyword only takes"):
  330. getattr(idx1, method)(idx2, sort=2)
  331. # sort=True is supported as of GH#?
  332. getattr(idx1, method)(idx2, sort=True)
  333. @pytest.mark.parametrize("val", [pd.NA, 100])
  334. def test_difference_keep_ea_dtypes(any_numeric_ea_dtype, val):
  335. # GH#48606
  336. midx = MultiIndex.from_arrays(
  337. [Series([1, 2], dtype=any_numeric_ea_dtype), [2, 1]], names=["a", None]
  338. )
  339. midx2 = MultiIndex.from_arrays(
  340. [Series([1, 2, val], dtype=any_numeric_ea_dtype), [1, 1, 3]]
  341. )
  342. result = midx.difference(midx2)
  343. expected = MultiIndex.from_arrays([Series([1], dtype=any_numeric_ea_dtype), [2]])
  344. tm.assert_index_equal(result, expected)
  345. result = midx.difference(midx.sort_values(ascending=False))
  346. expected = MultiIndex.from_arrays(
  347. [Series([], dtype=any_numeric_ea_dtype), Series([], dtype=np.int64)],
  348. names=["a", None],
  349. )
  350. tm.assert_index_equal(result, expected)
  351. @pytest.mark.parametrize("val", [pd.NA, 5])
  352. def test_symmetric_difference_keeping_ea_dtype(any_numeric_ea_dtype, val):
  353. # GH#48607
  354. midx = MultiIndex.from_arrays(
  355. [Series([1, 2], dtype=any_numeric_ea_dtype), [2, 1]], names=["a", None]
  356. )
  357. midx2 = MultiIndex.from_arrays(
  358. [Series([1, 2, val], dtype=any_numeric_ea_dtype), [1, 1, 3]]
  359. )
  360. result = midx.symmetric_difference(midx2)
  361. expected = MultiIndex.from_arrays(
  362. [Series([1, 1, val], dtype=any_numeric_ea_dtype), [1, 2, 3]]
  363. )
  364. tm.assert_index_equal(result, expected)
  365. @pytest.mark.parametrize(
  366. ("tuples", "exp_tuples"),
  367. [
  368. ([("val1", "test1")], [("val1", "test1")]),
  369. ([("val1", "test1"), ("val1", "test1")], [("val1", "test1")]),
  370. (
  371. [("val2", "test2"), ("val1", "test1")],
  372. [("val2", "test2"), ("val1", "test1")],
  373. ),
  374. ],
  375. )
  376. def test_intersect_with_duplicates(tuples, exp_tuples):
  377. # GH#36915
  378. left = MultiIndex.from_tuples(tuples, names=["first", "second"])
  379. right = MultiIndex.from_tuples(
  380. [("val1", "test1"), ("val1", "test1"), ("val2", "test2")],
  381. names=["first", "second"],
  382. )
  383. result = left.intersection(right)
  384. expected = MultiIndex.from_tuples(exp_tuples, names=["first", "second"])
  385. tm.assert_index_equal(result, expected)
  386. @pytest.mark.parametrize(
  387. "data, names, expected",
  388. [
  389. ((1,), None, [None, None]),
  390. ((1,), ["a"], [None, None]),
  391. ((1,), ["b"], [None, None]),
  392. ((1, 2), ["c", "d"], [None, None]),
  393. ((1, 2), ["b", "a"], [None, None]),
  394. ((1, 2, 3), ["a", "b", "c"], [None, None]),
  395. ((1, 2), ["a", "c"], ["a", None]),
  396. ((1, 2), ["c", "b"], [None, "b"]),
  397. ((1, 2), ["a", "b"], ["a", "b"]),
  398. ((1, 2), [None, "b"], [None, "b"]),
  399. ],
  400. )
  401. def test_maybe_match_names(data, names, expected):
  402. # GH#38323
  403. mi = MultiIndex.from_tuples([], names=["a", "b"])
  404. mi2 = MultiIndex.from_tuples([data], names=names)
  405. result = mi._maybe_match_names(mi2)
  406. assert result == expected
  407. def test_intersection_equal_different_names():
  408. # GH#30302
  409. mi1 = MultiIndex.from_arrays([[1, 2], [3, 4]], names=["c", "b"])
  410. mi2 = MultiIndex.from_arrays([[1, 2], [3, 4]], names=["a", "b"])
  411. result = mi1.intersection(mi2)
  412. expected = MultiIndex.from_arrays([[1, 2], [3, 4]], names=[None, "b"])
  413. tm.assert_index_equal(result, expected)
  414. def test_intersection_different_names():
  415. # GH#38323
  416. mi = MultiIndex.from_arrays([[1], [3]], names=["c", "b"])
  417. mi2 = MultiIndex.from_arrays([[1], [3]])
  418. result = mi.intersection(mi2)
  419. tm.assert_index_equal(result, mi2)
  420. def test_intersection_with_missing_values_on_both_sides(nulls_fixture):
  421. # GH#38623
  422. mi1 = MultiIndex.from_arrays([[3, nulls_fixture, 4, nulls_fixture], [1, 2, 4, 2]])
  423. mi2 = MultiIndex.from_arrays([[3, nulls_fixture, 3], [1, 2, 4]])
  424. result = mi1.intersection(mi2)
  425. expected = MultiIndex.from_arrays([[3, nulls_fixture], [1, 2]])
  426. tm.assert_index_equal(result, expected)
  427. def test_union_with_missing_values_on_both_sides(nulls_fixture):
  428. # GH#38623
  429. mi1 = MultiIndex.from_arrays([[1, nulls_fixture]])
  430. mi2 = MultiIndex.from_arrays([[1, nulls_fixture, 3]])
  431. result = mi1.union(mi2)
  432. expected = MultiIndex.from_arrays([[1, 3, nulls_fixture]])
  433. tm.assert_index_equal(result, expected)
  434. @pytest.mark.parametrize("dtype", ["float64", "Float64"])
  435. @pytest.mark.parametrize("sort", [None, False])
  436. def test_union_nan_got_duplicated(dtype, sort):
  437. # GH#38977, GH#49010
  438. mi1 = MultiIndex.from_arrays([pd.array([1.0, np.nan], dtype=dtype), [2, 3]])
  439. mi2 = MultiIndex.from_arrays([pd.array([1.0, np.nan, 3.0], dtype=dtype), [2, 3, 4]])
  440. result = mi1.union(mi2, sort=sort)
  441. if sort is None:
  442. expected = MultiIndex.from_arrays(
  443. [pd.array([1.0, 3.0, np.nan], dtype=dtype), [2, 4, 3]]
  444. )
  445. else:
  446. expected = mi2
  447. tm.assert_index_equal(result, expected)
  448. @pytest.mark.parametrize("val", [4, 1])
  449. def test_union_keep_ea_dtype(any_numeric_ea_dtype, val):
  450. # GH#48505
  451. arr1 = Series([val, 2], dtype=any_numeric_ea_dtype)
  452. arr2 = Series([2, 1], dtype=any_numeric_ea_dtype)
  453. midx = MultiIndex.from_arrays([arr1, [1, 2]], names=["a", None])
  454. midx2 = MultiIndex.from_arrays([arr2, [2, 1]])
  455. result = midx.union(midx2)
  456. if val == 4:
  457. expected = MultiIndex.from_arrays(
  458. [Series([1, 2, 4], dtype=any_numeric_ea_dtype), [1, 2, 1]]
  459. )
  460. else:
  461. expected = MultiIndex.from_arrays(
  462. [Series([1, 2], dtype=any_numeric_ea_dtype), [1, 2]]
  463. )
  464. tm.assert_index_equal(result, expected)
  465. @pytest.mark.parametrize("dupe_val", [3, pd.NA])
  466. def test_union_with_duplicates_keep_ea_dtype(dupe_val, any_numeric_ea_dtype):
  467. # GH48900
  468. mi1 = MultiIndex.from_arrays(
  469. [
  470. Series([1, dupe_val, 2], dtype=any_numeric_ea_dtype),
  471. Series([1, dupe_val, 2], dtype=any_numeric_ea_dtype),
  472. ]
  473. )
  474. mi2 = MultiIndex.from_arrays(
  475. [
  476. Series([2, dupe_val, dupe_val], dtype=any_numeric_ea_dtype),
  477. Series([2, dupe_val, dupe_val], dtype=any_numeric_ea_dtype),
  478. ]
  479. )
  480. result = mi1.union(mi2)
  481. expected = MultiIndex.from_arrays(
  482. [
  483. Series([1, 2, dupe_val, dupe_val], dtype=any_numeric_ea_dtype),
  484. Series([1, 2, dupe_val, dupe_val], dtype=any_numeric_ea_dtype),
  485. ]
  486. )
  487. tm.assert_index_equal(result, expected)
  488. @pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning")
  489. def test_union_duplicates(index, request):
  490. # GH#38977
  491. if index.empty or isinstance(index, (IntervalIndex, CategoricalIndex)):
  492. pytest.skip(f"No duplicates in an empty {type(index).__name__}")
  493. values = index.unique().values.tolist()
  494. mi1 = MultiIndex.from_arrays([values, [1] * len(values)])
  495. mi2 = MultiIndex.from_arrays([[values[0]] + values, [1] * (len(values) + 1)])
  496. result = mi2.union(mi1)
  497. expected = mi2.sort_values()
  498. tm.assert_index_equal(result, expected)
  499. if (
  500. is_unsigned_integer_dtype(mi2.levels[0])
  501. and (mi2.get_level_values(0) < 2**63).all()
  502. ):
  503. # GH#47294 - union uses lib.fast_zip, converting data to Python integers
  504. # and loses type information. Result is then unsigned only when values are
  505. # sufficiently large to require unsigned dtype. This happens only if other
  506. # has dups or one of both have missing values
  507. expected = expected.set_levels(
  508. [expected.levels[0].astype(np.int64), expected.levels[1]]
  509. )
  510. elif is_float_dtype(mi2.levels[0]):
  511. # mi2 has duplicates witch is a different path than above, Fix that path
  512. # to use correct float dtype?
  513. expected = expected.set_levels(
  514. [expected.levels[0].astype(float), expected.levels[1]]
  515. )
  516. result = mi1.union(mi2)
  517. tm.assert_index_equal(result, expected)
  518. def test_union_keep_dtype_precision(any_real_numeric_dtype):
  519. # GH#48498
  520. arr1 = Series([4, 1, 1], dtype=any_real_numeric_dtype)
  521. arr2 = Series([1, 4], dtype=any_real_numeric_dtype)
  522. midx = MultiIndex.from_arrays([arr1, [2, 1, 1]], names=["a", None])
  523. midx2 = MultiIndex.from_arrays([arr2, [1, 2]], names=["a", None])
  524. result = midx.union(midx2)
  525. expected = MultiIndex.from_arrays(
  526. ([Series([1, 1, 4], dtype=any_real_numeric_dtype), [1, 1, 2]]),
  527. names=["a", None],
  528. )
  529. tm.assert_index_equal(result, expected)
  530. def test_union_keep_ea_dtype_with_na(any_numeric_ea_dtype):
  531. # GH#48498
  532. arr1 = Series([4, pd.NA], dtype=any_numeric_ea_dtype)
  533. arr2 = Series([1, pd.NA], dtype=any_numeric_ea_dtype)
  534. midx = MultiIndex.from_arrays([arr1, [2, 1]], names=["a", None])
  535. midx2 = MultiIndex.from_arrays([arr2, [1, 2]])
  536. result = midx.union(midx2)
  537. expected = MultiIndex.from_arrays(
  538. [Series([1, 4, pd.NA, pd.NA], dtype=any_numeric_ea_dtype), [1, 2, 1, 2]]
  539. )
  540. tm.assert_index_equal(result, expected)
  541. @pytest.mark.parametrize(
  542. "levels1, levels2, codes1, codes2, names",
  543. [
  544. (
  545. [["a", "b", "c"], [0, ""]],
  546. [["c", "d", "b"], [""]],
  547. [[0, 1, 2], [1, 1, 1]],
  548. [[0, 1, 2], [0, 0, 0]],
  549. ["name1", "name2"],
  550. ),
  551. ],
  552. )
  553. def test_intersection_lexsort_depth(levels1, levels2, codes1, codes2, names):
  554. # GH#25169
  555. mi1 = MultiIndex(levels=levels1, codes=codes1, names=names)
  556. mi2 = MultiIndex(levels=levels2, codes=codes2, names=names)
  557. mi_int = mi1.intersection(mi2)
  558. assert mi_int._lexsort_depth == 2
  559. @pytest.mark.parametrize(
  560. "a",
  561. [pd.Categorical(["a", "b"], categories=["a", "b"]), ["a", "b"]],
  562. )
  563. @pytest.mark.parametrize(
  564. "b",
  565. [
  566. pd.Categorical(["a", "b"], categories=["b", "a"], ordered=True),
  567. pd.Categorical(["a", "b"], categories=["b", "a"]),
  568. ],
  569. )
  570. def test_intersection_with_non_lex_sorted_categories(a, b):
  571. # GH#49974
  572. other = ["1", "2"]
  573. df1 = DataFrame({"x": a, "y": other})
  574. df2 = DataFrame({"x": b, "y": other})
  575. expected = MultiIndex.from_arrays([a, other], names=["x", "y"])
  576. res1 = MultiIndex.from_frame(df1).intersection(
  577. MultiIndex.from_frame(df2.sort_values(["x", "y"]))
  578. )
  579. res2 = MultiIndex.from_frame(df1).intersection(MultiIndex.from_frame(df2))
  580. res3 = MultiIndex.from_frame(df1.sort_values(["x", "y"])).intersection(
  581. MultiIndex.from_frame(df2)
  582. )
  583. res4 = MultiIndex.from_frame(df1.sort_values(["x", "y"])).intersection(
  584. MultiIndex.from_frame(df2.sort_values(["x", "y"]))
  585. )
  586. tm.assert_index_equal(res1, expected)
  587. tm.assert_index_equal(res2, expected)
  588. tm.assert_index_equal(res3, expected)
  589. tm.assert_index_equal(res4, expected)
  590. @pytest.mark.parametrize("val", [pd.NA, 100])
  591. def test_intersection_keep_ea_dtypes(val, any_numeric_ea_dtype):
  592. # GH#48604
  593. midx = MultiIndex.from_arrays(
  594. [Series([1, 2], dtype=any_numeric_ea_dtype), [2, 1]], names=["a", None]
  595. )
  596. midx2 = MultiIndex.from_arrays(
  597. [Series([1, 2, val], dtype=any_numeric_ea_dtype), [1, 1, 3]]
  598. )
  599. result = midx.intersection(midx2)
  600. expected = MultiIndex.from_arrays([Series([2], dtype=any_numeric_ea_dtype), [1]])
  601. tm.assert_index_equal(result, expected)
  602. def test_union_with_na_when_constructing_dataframe():
  603. # GH43222
  604. series1 = Series(
  605. (1,),
  606. index=MultiIndex.from_arrays(
  607. [Series([None], dtype="str"), Series([None], dtype="str")]
  608. ),
  609. )
  610. series2 = Series((10, 20), index=MultiIndex.from_tuples(((None, None), ("a", "b"))))
  611. result = DataFrame([series1, series2])
  612. expected = DataFrame({(np.nan, np.nan): [1.0, 10.0], ("a", "b"): [np.nan, 20.0]})
  613. tm.assert_frame_equal(result, expected)