# test_groupby_subclass.py (4.5 KB)
  1. from datetime import datetime
  2. import numpy as np
  3. import pytest
  4. from pandas import (
  5. DataFrame,
  6. Index,
  7. Series,
  8. )
  9. import pandas._testing as tm
  10. from pandas.tests.groupby import get_groupby_method_args
  11. pytestmark = pytest.mark.filterwarnings(
  12. "ignore:Passing a BlockManager|Passing a SingleBlockManager:DeprecationWarning"
  13. )
  14. @pytest.mark.parametrize(
  15. "obj",
  16. [
  17. tm.SubclassedDataFrame({"A": np.arange(0, 10)}),
  18. tm.SubclassedSeries(np.arange(0, 10), name="A"),
  19. ],
  20. )
  21. def test_groupby_preserves_subclass(obj, groupby_func):
  22. # GH28330 -- preserve subclass through groupby operations
  23. if isinstance(obj, Series) and groupby_func in {"corrwith"}:
  24. pytest.skip(f"Not applicable for Series and {groupby_func}")
  25. grouped = obj.groupby(np.arange(0, 10))
  26. # Groups should preserve subclass type
  27. assert isinstance(grouped.get_group(0), type(obj))
  28. args = get_groupby_method_args(groupby_func, obj)
  29. warn = FutureWarning if groupby_func == "fillna" else None
  30. msg = f"{type(grouped).__name__}.fillna is deprecated"
  31. with tm.assert_produces_warning(warn, match=msg, raise_on_extra_warnings=False):
  32. result1 = getattr(grouped, groupby_func)(*args)
  33. with tm.assert_produces_warning(warn, match=msg, raise_on_extra_warnings=False):
  34. result2 = grouped.agg(groupby_func, *args)
  35. # Reduction or transformation kernels should preserve type
  36. slices = {"ngroup", "cumcount", "size"}
  37. if isinstance(obj, DataFrame) and groupby_func in slices:
  38. assert isinstance(result1, tm.SubclassedSeries)
  39. else:
  40. assert isinstance(result1, type(obj))
  41. # Confirm .agg() groupby operations return same results
  42. if isinstance(result1, DataFrame):
  43. tm.assert_frame_equal(result1, result2)
  44. else:
  45. tm.assert_series_equal(result1, result2)
  46. def test_groupby_preserves_metadata():
  47. # GH-37343
  48. custom_df = tm.SubclassedDataFrame({"a": [1, 2, 3], "b": [1, 1, 2], "c": [7, 8, 9]})
  49. assert "testattr" in custom_df._metadata
  50. custom_df.testattr = "hello"
  51. for _, group_df in custom_df.groupby("c"):
  52. assert group_df.testattr == "hello"
  53. # GH-45314
  54. def func(group):
  55. assert isinstance(group, tm.SubclassedDataFrame)
  56. assert hasattr(group, "testattr")
  57. assert group.testattr == "hello"
  58. return group.testattr
  59. msg = "DataFrameGroupBy.apply operated on the grouping columns"
  60. with tm.assert_produces_warning(
  61. FutureWarning,
  62. match=msg,
  63. raise_on_extra_warnings=False,
  64. check_stacklevel=False,
  65. ):
  66. result = custom_df.groupby("c").apply(func)
  67. expected = tm.SubclassedSeries(["hello"] * 3, index=Index([7, 8, 9], name="c"))
  68. tm.assert_series_equal(result, expected)
  69. result = custom_df.groupby("c").apply(func, include_groups=False)
  70. tm.assert_series_equal(result, expected)
  71. # https://github.com/pandas-dev/pandas/pull/56761
  72. result = custom_df.groupby("c")[["a", "b"]].apply(func)
  73. tm.assert_series_equal(result, expected)
  74. def func2(group):
  75. assert isinstance(group, tm.SubclassedSeries)
  76. assert hasattr(group, "testattr")
  77. return group.testattr
  78. custom_series = tm.SubclassedSeries([1, 2, 3])
  79. custom_series.testattr = "hello"
  80. result = custom_series.groupby(custom_df["c"]).apply(func2)
  81. tm.assert_series_equal(result, expected)
  82. result = custom_series.groupby(custom_df["c"]).agg(func2)
  83. tm.assert_series_equal(result, expected)
  84. @pytest.mark.parametrize("obj", [DataFrame, tm.SubclassedDataFrame])
  85. def test_groupby_resample_preserves_subclass(obj):
  86. # GH28330 -- preserve subclass through groupby.resample()
  87. df = obj(
  88. {
  89. "Buyer": Series("Carl Carl Carl Carl Joe Carl".split(), dtype=object),
  90. "Quantity": [18, 3, 5, 1, 9, 3],
  91. "Date": [
  92. datetime(2013, 9, 1, 13, 0),
  93. datetime(2013, 9, 1, 13, 5),
  94. datetime(2013, 10, 1, 20, 0),
  95. datetime(2013, 10, 3, 10, 0),
  96. datetime(2013, 12, 2, 12, 0),
  97. datetime(2013, 9, 2, 14, 0),
  98. ],
  99. }
  100. )
  101. df = df.set_index("Date")
  102. # Confirm groupby.resample() preserves dataframe type
  103. msg = "DataFrameGroupBy.resample operated on the grouping columns"
  104. with tm.assert_produces_warning(
  105. FutureWarning,
  106. match=msg,
  107. raise_on_extra_warnings=False,
  108. check_stacklevel=False,
  109. ):
  110. result = df.groupby("Buyer").resample("5D").sum()
  111. assert isinstance(result, obj)