# test_join.py
from datetime import datetime

import numpy as np
import pytest

from pandas.errors import MergeError

import pandas as pd
from pandas import (
    DataFrame,
    Index,
    MultiIndex,
    date_range,
    period_range,
)
import pandas._testing as tm
from pandas.core.reshape.concat import concat
  15. @pytest.fixture
  16. def frame_with_period_index():
  17. return DataFrame(
  18. data=np.arange(20).reshape(4, 5),
  19. columns=list("abcde"),
  20. index=period_range(start="2000", freq="Y", periods=4),
  21. )
  22. @pytest.fixture
  23. def left():
  24. return DataFrame({"a": [20, 10, 0]}, index=[2, 1, 0])
  25. @pytest.fixture
  26. def right():
  27. return DataFrame({"b": [300, 100, 200]}, index=[3, 1, 2])
  28. @pytest.fixture
  29. def left_no_dup():
  30. return DataFrame(
  31. {"a": ["a", "b", "c", "d"], "b": ["cat", "dog", "weasel", "horse"]},
  32. index=range(4),
  33. )
  34. @pytest.fixture
  35. def right_no_dup():
  36. return DataFrame(
  37. {
  38. "a": ["a", "b", "c", "d", "e"],
  39. "c": ["meow", "bark", "um... weasel noise?", "nay", "chirp"],
  40. },
  41. index=range(5),
  42. ).set_index("a")
  43. @pytest.fixture
  44. def left_w_dups(left_no_dup):
  45. return concat(
  46. [left_no_dup, DataFrame({"a": ["a"], "b": ["cow"]}, index=[3])], sort=True
  47. )
  48. @pytest.fixture
  49. def right_w_dups(right_no_dup):
  50. return concat(
  51. [right_no_dup, DataFrame({"a": ["e"], "c": ["moo"]}, index=[3])]
  52. ).set_index("a")
  53. @pytest.mark.parametrize(
  54. "how, sort, expected",
  55. [
  56. ("inner", False, DataFrame({"a": [20, 10], "b": [200, 100]}, index=[2, 1])),
  57. ("inner", True, DataFrame({"a": [10, 20], "b": [100, 200]}, index=[1, 2])),
  58. (
  59. "left",
  60. False,
  61. DataFrame({"a": [20, 10, 0], "b": [200, 100, np.nan]}, index=[2, 1, 0]),
  62. ),
  63. (
  64. "left",
  65. True,
  66. DataFrame({"a": [0, 10, 20], "b": [np.nan, 100, 200]}, index=[0, 1, 2]),
  67. ),
  68. (
  69. "right",
  70. False,
  71. DataFrame({"a": [np.nan, 10, 20], "b": [300, 100, 200]}, index=[3, 1, 2]),
  72. ),
  73. (
  74. "right",
  75. True,
  76. DataFrame({"a": [10, 20, np.nan], "b": [100, 200, 300]}, index=[1, 2, 3]),
  77. ),
  78. (
  79. "outer",
  80. False,
  81. DataFrame(
  82. {"a": [0, 10, 20, np.nan], "b": [np.nan, 100, 200, 300]},
  83. index=[0, 1, 2, 3],
  84. ),
  85. ),
  86. (
  87. "outer",
  88. True,
  89. DataFrame(
  90. {"a": [0, 10, 20, np.nan], "b": [np.nan, 100, 200, 300]},
  91. index=[0, 1, 2, 3],
  92. ),
  93. ),
  94. ],
  95. )
  96. def test_join(left, right, how, sort, expected):
  97. result = left.join(right, how=how, sort=sort, validate="1:1")
  98. tm.assert_frame_equal(result, expected)
  99. def test_suffix_on_list_join():
  100. first = DataFrame({"key": [1, 2, 3, 4, 5]})
  101. second = DataFrame({"key": [1, 8, 3, 2, 5], "v1": [1, 2, 3, 4, 5]})
  102. third = DataFrame({"keys": [5, 2, 3, 4, 1], "v2": [1, 2, 3, 4, 5]})
  103. # check proper errors are raised
  104. msg = "Suffixes not supported when joining multiple DataFrames"
  105. with pytest.raises(ValueError, match=msg):
  106. first.join([second], lsuffix="y")
  107. with pytest.raises(ValueError, match=msg):
  108. first.join([second, third], rsuffix="x")
  109. with pytest.raises(ValueError, match=msg):
  110. first.join([second, third], lsuffix="y", rsuffix="x")
  111. with pytest.raises(ValueError, match="Indexes have overlapping values"):
  112. first.join([second, third])
  113. # no errors should be raised
  114. arr_joined = first.join([third])
  115. norm_joined = first.join(third)
  116. tm.assert_frame_equal(arr_joined, norm_joined)
  117. def test_join_invalid_validate(left_no_dup, right_no_dup):
  118. # GH 46622
  119. # Check invalid arguments
  120. msg = (
  121. '"invalid" is not a valid argument. '
  122. "Valid arguments are:\n"
  123. '- "1:1"\n'
  124. '- "1:m"\n'
  125. '- "m:1"\n'
  126. '- "m:m"\n'
  127. '- "one_to_one"\n'
  128. '- "one_to_many"\n'
  129. '- "many_to_one"\n'
  130. '- "many_to_many"'
  131. )
  132. with pytest.raises(ValueError, match=msg):
  133. left_no_dup.merge(right_no_dup, on="a", validate="invalid")
  134. @pytest.mark.parametrize("dtype", ["object", "string[pyarrow]"])
  135. def test_join_on_single_col_dup_on_right(left_no_dup, right_w_dups, dtype):
  136. # GH 46622
  137. # Dups on right allowed by one_to_many constraint
  138. if dtype == "string[pyarrow]":
  139. pytest.importorskip("pyarrow")
  140. left_no_dup = left_no_dup.astype(dtype)
  141. right_w_dups.index = right_w_dups.index.astype(dtype)
  142. left_no_dup.join(
  143. right_w_dups,
  144. on="a",
  145. validate="one_to_many",
  146. )
  147. # Dups on right not allowed by one_to_one constraint
  148. msg = "Merge keys are not unique in right dataset; not a one-to-one merge"
  149. with pytest.raises(MergeError, match=msg):
  150. left_no_dup.join(
  151. right_w_dups,
  152. on="a",
  153. validate="one_to_one",
  154. )
  155. def test_join_on_single_col_dup_on_left(left_w_dups, right_no_dup):
  156. # GH 46622
  157. # Dups on left allowed by many_to_one constraint
  158. left_w_dups.join(
  159. right_no_dup,
  160. on="a",
  161. validate="many_to_one",
  162. )
  163. # Dups on left not allowed by one_to_one constraint
  164. msg = "Merge keys are not unique in left dataset; not a one-to-one merge"
  165. with pytest.raises(MergeError, match=msg):
  166. left_w_dups.join(
  167. right_no_dup,
  168. on="a",
  169. validate="one_to_one",
  170. )
  171. def test_join_on_single_col_dup_on_both(left_w_dups, right_w_dups):
  172. # GH 46622
  173. # Dups on both allowed by many_to_many constraint
  174. left_w_dups.join(right_w_dups, on="a", validate="many_to_many")
  175. # Dups on both not allowed by many_to_one constraint
  176. msg = "Merge keys are not unique in right dataset; not a many-to-one merge"
  177. with pytest.raises(MergeError, match=msg):
  178. left_w_dups.join(
  179. right_w_dups,
  180. on="a",
  181. validate="many_to_one",
  182. )
  183. # Dups on both not allowed by one_to_many constraint
  184. msg = "Merge keys are not unique in left dataset; not a one-to-many merge"
  185. with pytest.raises(MergeError, match=msg):
  186. left_w_dups.join(
  187. right_w_dups,
  188. on="a",
  189. validate="one_to_many",
  190. )
  191. def test_join_on_multi_col_check_dup():
  192. # GH 46622
  193. # Two column join, dups in both, but jointly no dups
  194. left = DataFrame(
  195. {
  196. "a": ["a", "a", "b", "b"],
  197. "b": [0, 1, 0, 1],
  198. "c": ["cat", "dog", "weasel", "horse"],
  199. },
  200. index=range(4),
  201. ).set_index(["a", "b"])
  202. right = DataFrame(
  203. {
  204. "a": ["a", "a", "b"],
  205. "b": [0, 1, 0],
  206. "d": ["meow", "bark", "um... weasel noise?"],
  207. },
  208. index=range(3),
  209. ).set_index(["a", "b"])
  210. expected_multi = DataFrame(
  211. {
  212. "a": ["a", "a", "b"],
  213. "b": [0, 1, 0],
  214. "c": ["cat", "dog", "weasel"],
  215. "d": ["meow", "bark", "um... weasel noise?"],
  216. },
  217. index=range(3),
  218. ).set_index(["a", "b"])
  219. # Jointly no dups allowed by one_to_one constraint
  220. result = left.join(right, how="inner", validate="1:1")
  221. tm.assert_frame_equal(result, expected_multi)
  222. def test_join_index(float_frame):
  223. # left / right
  224. f = float_frame.loc[float_frame.index[:10], ["A", "B"]]
  225. f2 = float_frame.loc[float_frame.index[5:], ["C", "D"]].iloc[::-1]
  226. joined = f.join(f2)
  227. tm.assert_index_equal(f.index, joined.index)
  228. expected_columns = Index(["A", "B", "C", "D"])
  229. tm.assert_index_equal(joined.columns, expected_columns)
  230. joined = f.join(f2, how="left")
  231. tm.assert_index_equal(joined.index, f.index)
  232. tm.assert_index_equal(joined.columns, expected_columns)
  233. joined = f.join(f2, how="right")
  234. tm.assert_index_equal(joined.index, f2.index)
  235. tm.assert_index_equal(joined.columns, expected_columns)
  236. # inner
  237. joined = f.join(f2, how="inner")
  238. tm.assert_index_equal(joined.index, f.index[5:10])
  239. tm.assert_index_equal(joined.columns, expected_columns)
  240. # outer
  241. joined = f.join(f2, how="outer")
  242. tm.assert_index_equal(joined.index, float_frame.index.sort_values())
  243. tm.assert_index_equal(joined.columns, expected_columns)
  244. with pytest.raises(ValueError, match="join method"):
  245. f.join(f2, how="foo")
  246. # corner case - overlapping columns
  247. msg = "columns overlap but no suffix"
  248. for how in ("outer", "left", "inner"):
  249. with pytest.raises(ValueError, match=msg):
  250. float_frame.join(float_frame, how=how)
  251. def test_join_index_more(float_frame):
  252. af = float_frame.loc[:, ["A", "B"]]
  253. bf = float_frame.loc[::2, ["C", "D"]]
  254. expected = af.copy()
  255. expected["C"] = float_frame["C"][::2]
  256. expected["D"] = float_frame["D"][::2]
  257. result = af.join(bf)
  258. tm.assert_frame_equal(result, expected)
  259. result = af.join(bf, how="right")
  260. tm.assert_frame_equal(result, expected[::2])
  261. result = bf.join(af, how="right")
  262. tm.assert_frame_equal(result, expected.loc[:, result.columns])
  263. def test_join_index_series(float_frame):
  264. df = float_frame.copy()
  265. ser = df.pop(float_frame.columns[-1])
  266. joined = df.join(ser)
  267. tm.assert_frame_equal(joined, float_frame)
  268. ser.name = None
  269. with pytest.raises(ValueError, match="must have a name"):
  270. df.join(ser)
  271. def test_join_overlap(float_frame):
  272. df1 = float_frame.loc[:, ["A", "B", "C"]]
  273. df2 = float_frame.loc[:, ["B", "C", "D"]]
  274. joined = df1.join(df2, lsuffix="_df1", rsuffix="_df2")
  275. df1_suf = df1.loc[:, ["B", "C"]].add_suffix("_df1")
  276. df2_suf = df2.loc[:, ["B", "C"]].add_suffix("_df2")
  277. no_overlap = float_frame.loc[:, ["A", "D"]]
  278. expected = df1_suf.join(df2_suf).join(no_overlap)
  279. # column order not necessarily sorted
  280. tm.assert_frame_equal(joined, expected.loc[:, joined.columns])
  281. def test_join_period_index(frame_with_period_index):
  282. other = frame_with_period_index.rename(columns=lambda key: f"{key}{key}")
  283. joined_values = np.concatenate([frame_with_period_index.values] * 2, axis=1)
  284. joined_cols = frame_with_period_index.columns.append(other.columns)
  285. joined = frame_with_period_index.join(other)
  286. expected = DataFrame(
  287. data=joined_values, columns=joined_cols, index=frame_with_period_index.index
  288. )
  289. tm.assert_frame_equal(joined, expected)
  290. def test_join_left_sequence_non_unique_index():
  291. # https://github.com/pandas-dev/pandas/issues/19607
  292. df1 = DataFrame({"a": [0, 10, 20]}, index=[1, 2, 3])
  293. df2 = DataFrame({"b": [100, 200, 300]}, index=[4, 3, 2])
  294. df3 = DataFrame({"c": [400, 500, 600]}, index=[2, 2, 4])
  295. joined = df1.join([df2, df3], how="left")
  296. expected = DataFrame(
  297. {
  298. "a": [0, 10, 10, 20],
  299. "b": [np.nan, 300, 300, 200],
  300. "c": [np.nan, 400, 500, np.nan],
  301. },
  302. index=[1, 2, 2, 3],
  303. )
  304. tm.assert_frame_equal(joined, expected)
  305. def test_join_list_series(float_frame):
  306. # GH#46850
  307. # Join a DataFrame with a list containing both a Series and a DataFrame
  308. left = float_frame.A.to_frame()
  309. right = [float_frame.B, float_frame[["C", "D"]]]
  310. result = left.join(right)
  311. tm.assert_frame_equal(result, float_frame)
  312. @pytest.mark.parametrize("sort_kw", [True, False])
  313. def test_suppress_future_warning_with_sort_kw(sort_kw):
  314. a = DataFrame({"col1": [1, 2]}, index=["c", "a"])
  315. b = DataFrame({"col2": [4, 5]}, index=["b", "a"])
  316. c = DataFrame({"col3": [7, 8]}, index=["a", "b"])
  317. expected = DataFrame(
  318. {
  319. "col1": {"a": 2.0, "b": float("nan"), "c": 1.0},
  320. "col2": {"a": 5.0, "b": 4.0, "c": float("nan")},
  321. "col3": {"a": 7.0, "b": 8.0, "c": float("nan")},
  322. }
  323. )
  324. if sort_kw is False:
  325. expected = expected.reindex(index=["c", "a", "b"])
  326. with tm.assert_produces_warning(None):
  327. result = a.join([b, c], how="outer", sort=sort_kw)
  328. tm.assert_frame_equal(result, expected)
  329. class TestDataFrameJoin:
  330. def test_join(self, multiindex_dataframe_random_data):
  331. frame = multiindex_dataframe_random_data
  332. a = frame.loc[frame.index[:5], ["A"]]
  333. b = frame.loc[frame.index[2:], ["B", "C"]]
  334. joined = a.join(b, how="outer").reindex(frame.index)
  335. expected = frame.copy().values.copy()
  336. expected[np.isnan(joined.values)] = np.nan
  337. expected = DataFrame(expected, index=frame.index, columns=frame.columns)
  338. assert not np.isnan(joined.values).all()
  339. tm.assert_frame_equal(joined, expected)
  340. def test_join_segfault(self):
  341. # GH#1532
  342. df1 = DataFrame({"a": [1, 1], "b": [1, 2], "x": [1, 2]})
  343. df2 = DataFrame({"a": [2, 2], "b": [1, 2], "y": [1, 2]})
  344. df1 = df1.set_index(["a", "b"])
  345. df2 = df2.set_index(["a", "b"])
  346. # it works!
  347. for how in ["left", "right", "outer"]:
  348. df1.join(df2, how=how)
  349. def test_join_str_datetime(self):
  350. str_dates = ["20120209", "20120222"]
  351. dt_dates = [datetime(2012, 2, 9), datetime(2012, 2, 22)]
  352. A = DataFrame(str_dates, index=range(2), columns=["aa"])
  353. C = DataFrame([[1, 2], [3, 4]], index=str_dates, columns=dt_dates)
  354. tst = A.join(C, on="aa")
  355. assert len(tst.columns) == 3
  356. def test_join_multiindex_leftright(self):
  357. # GH 10741
  358. df1 = DataFrame(
  359. [
  360. ["a", "x", 0.471780],
  361. ["a", "y", 0.774908],
  362. ["a", "z", 0.563634],
  363. ["b", "x", -0.353756],
  364. ["b", "y", 0.368062],
  365. ["b", "z", -1.721840],
  366. ["c", "x", 1],
  367. ["c", "y", 2],
  368. ["c", "z", 3],
  369. ],
  370. columns=["first", "second", "value1"],
  371. ).set_index(["first", "second"])
  372. df2 = DataFrame([["a", 10], ["b", 20]], columns=["first", "value2"]).set_index(
  373. ["first"]
  374. )
  375. exp = DataFrame(
  376. [
  377. [0.471780, 10],
  378. [0.774908, 10],
  379. [0.563634, 10],
  380. [-0.353756, 20],
  381. [0.368062, 20],
  382. [-1.721840, 20],
  383. [1.000000, np.nan],
  384. [2.000000, np.nan],
  385. [3.000000, np.nan],
  386. ],
  387. index=df1.index,
  388. columns=["value1", "value2"],
  389. )
  390. # these must be the same results (but columns are flipped)
  391. tm.assert_frame_equal(df1.join(df2, how="left"), exp)
  392. tm.assert_frame_equal(df2.join(df1, how="right"), exp[["value2", "value1"]])
  393. exp_idx = MultiIndex.from_product(
  394. [["a", "b"], ["x", "y", "z"]], names=["first", "second"]
  395. )
  396. exp = DataFrame(
  397. [
  398. [0.471780, 10],
  399. [0.774908, 10],
  400. [0.563634, 10],
  401. [-0.353756, 20],
  402. [0.368062, 20],
  403. [-1.721840, 20],
  404. ],
  405. index=exp_idx,
  406. columns=["value1", "value2"],
  407. )
  408. tm.assert_frame_equal(df1.join(df2, how="right"), exp)
  409. tm.assert_frame_equal(df2.join(df1, how="left"), exp[["value2", "value1"]])
  410. def test_join_multiindex_dates(self):
  411. # GH 33692
  412. date = pd.Timestamp(2000, 1, 1).date()
  413. df1_index = MultiIndex.from_tuples([(0, date)], names=["index_0", "date"])
  414. df1 = DataFrame({"col1": [0]}, index=df1_index)
  415. df2_index = MultiIndex.from_tuples([(0, date)], names=["index_0", "date"])
  416. df2 = DataFrame({"col2": [0]}, index=df2_index)
  417. df3_index = MultiIndex.from_tuples([(0, date)], names=["index_0", "date"])
  418. df3 = DataFrame({"col3": [0]}, index=df3_index)
  419. result = df1.join([df2, df3])
  420. expected_index = MultiIndex.from_tuples([(0, date)], names=["index_0", "date"])
  421. expected = DataFrame(
  422. {"col1": [0], "col2": [0], "col3": [0]}, index=expected_index
  423. )
  424. tm.assert_equal(result, expected)
  425. def test_merge_join_different_levels_raises(self):
  426. # GH#9455
  427. # GH 40993: For raising, enforced in 2.0
  428. # first dataframe
  429. df1 = DataFrame(columns=["a", "b"], data=[[1, 11], [0, 22]])
  430. # second dataframe
  431. columns = MultiIndex.from_tuples([("a", ""), ("c", "c1")])
  432. df2 = DataFrame(columns=columns, data=[[1, 33], [0, 44]])
  433. # merge
  434. with pytest.raises(
  435. MergeError, match="Not allowed to merge between different levels"
  436. ):
  437. pd.merge(df1, df2, on="a")
  438. # join, see discussion in GH#12219
  439. with pytest.raises(
  440. MergeError, match="Not allowed to merge between different levels"
  441. ):
  442. df1.join(df2, on="a")
  443. def test_frame_join_tzaware(self):
  444. test1 = DataFrame(
  445. np.zeros((6, 3)),
  446. index=date_range(
  447. "2012-11-15 00:00:00", periods=6, freq="100ms", tz="US/Central"
  448. ),
  449. )
  450. test2 = DataFrame(
  451. np.zeros((3, 3)),
  452. index=date_range(
  453. "2012-11-15 00:00:00", periods=3, freq="250ms", tz="US/Central"
  454. ),
  455. columns=range(3, 6),
  456. )
  457. result = test1.join(test2, how="outer")
  458. expected = test1.index.union(test2.index)
  459. tm.assert_index_equal(result.index, expected)
  460. assert result.index.tz.zone == "US/Central"