test_all_methods.py 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. """
  2. Tests that apply to all groupby operation methods.
  3. The only tests that should appear here are those that use the `groupby_func` fixture.
  4. Even if it does use that fixture, prefer a more specific test file if one is available
  5. such as:
  6. - test_categorical
  7. - test_groupby_dropna
  8. - test_groupby_subclass
  9. - test_raises
  10. """
  11. import pytest
  12. import pandas as pd
  13. from pandas import DataFrame
  14. import pandas._testing as tm
  15. from pandas.tests.groupby import get_groupby_method_args
  16. def test_multiindex_group_all_columns_when_empty(groupby_func):
  17. # GH 32464
  18. df = DataFrame({"a": [], "b": [], "c": []}).set_index(["a", "b", "c"])
  19. gb = df.groupby(["a", "b", "c"], group_keys=False)
  20. method = getattr(gb, groupby_func)
  21. args = get_groupby_method_args(groupby_func, df)
  22. warn = FutureWarning if groupby_func == "fillna" else None
  23. warn_msg = "DataFrameGroupBy.fillna is deprecated"
  24. with tm.assert_produces_warning(warn, match=warn_msg):
  25. result = method(*args).index
  26. expected = df.index
  27. tm.assert_index_equal(result, expected)
  28. def test_duplicate_columns(request, groupby_func, as_index):
  29. # GH#50806
  30. if groupby_func == "corrwith":
  31. msg = "GH#50845 - corrwith fails when there are duplicate columns"
  32. request.applymarker(pytest.mark.xfail(reason=msg))
  33. df = DataFrame([[1, 3, 6], [1, 4, 7], [2, 5, 8]], columns=list("abb"))
  34. args = get_groupby_method_args(groupby_func, df)
  35. gb = df.groupby("a", as_index=as_index)
  36. warn = FutureWarning if groupby_func == "fillna" else None
  37. warn_msg = "DataFrameGroupBy.fillna is deprecated"
  38. with tm.assert_produces_warning(warn, match=warn_msg):
  39. result = getattr(gb, groupby_func)(*args)
  40. expected_df = df.set_axis(["a", "b", "c"], axis=1)
  41. expected_args = get_groupby_method_args(groupby_func, expected_df)
  42. expected_gb = expected_df.groupby("a", as_index=as_index)
  43. warn = FutureWarning if groupby_func == "fillna" else None
  44. warn_msg = "DataFrameGroupBy.fillna is deprecated"
  45. with tm.assert_produces_warning(warn, match=warn_msg):
  46. expected = getattr(expected_gb, groupby_func)(*expected_args)
  47. if groupby_func not in ("size", "ngroup", "cumcount"):
  48. expected = expected.rename(columns={"c": "b"})
  49. tm.assert_equal(result, expected)
  50. @pytest.mark.parametrize(
  51. "idx",
  52. [
  53. pd.Index(["a", "a"], name="foo"),
  54. pd.MultiIndex.from_tuples((("a", "a"), ("a", "a")), names=["foo", "bar"]),
  55. ],
  56. )
  57. def test_dup_labels_output_shape(groupby_func, idx):
  58. if groupby_func in {"size", "ngroup", "cumcount"}:
  59. pytest.skip(f"Not applicable for {groupby_func}")
  60. df = DataFrame([[1, 1]], columns=idx)
  61. grp_by = df.groupby([0])
  62. args = get_groupby_method_args(groupby_func, df)
  63. warn = FutureWarning if groupby_func == "fillna" else None
  64. warn_msg = "DataFrameGroupBy.fillna is deprecated"
  65. with tm.assert_produces_warning(warn, match=warn_msg):
  66. result = getattr(grp_by, groupby_func)(*args)
  67. assert result.shape == (1, 2)
  68. tm.assert_index_equal(result.columns, idx)