groupby.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. from __future__ import annotations
  2. from typing import TYPE_CHECKING
  3. import numpy as np
  4. from pandas.core.dtypes.missing import remove_na_arraylike
  5. from pandas import (
  6. MultiIndex,
  7. concat,
  8. )
  9. from pandas.plotting._matplotlib.misc import unpack_single_str_list
  10. if TYPE_CHECKING:
  11. from collections.abc import Hashable
  12. from pandas._typing import IndexLabel
  13. from pandas import (
  14. DataFrame,
  15. Series,
  16. )
  17. def create_iter_data_given_by(
  18. data: DataFrame, kind: str = "hist"
  19. ) -> dict[Hashable, DataFrame | Series]:
  20. """
  21. Create data for iteration given `by` is assigned or not, and it is only
  22. used in both hist and boxplot.
  23. If `by` is assigned, return a dictionary of DataFrames in which the key of
  24. dictionary is the values in groups.
  25. If `by` is not assigned, return input as is, and this preserves current
  26. status of iter_data.
  27. Parameters
  28. ----------
  29. data : reformatted grouped data from `_compute_plot_data` method.
  30. kind : str, plot kind. This function is only used for `hist` and `box` plots.
  31. Returns
  32. -------
  33. iter_data : DataFrame or Dictionary of DataFrames
  34. Examples
  35. --------
  36. If `by` is assigned:
  37. >>> import numpy as np
  38. >>> tuples = [('h1', 'a'), ('h1', 'b'), ('h2', 'a'), ('h2', 'b')]
  39. >>> mi = pd.MultiIndex.from_tuples(tuples)
  40. >>> value = [[1, 3, np.nan, np.nan],
  41. ... [3, 4, np.nan, np.nan], [np.nan, np.nan, 5, 6]]
  42. >>> data = pd.DataFrame(value, columns=mi)
  43. >>> create_iter_data_given_by(data)
  44. {'h1': h1
  45. a b
  46. 0 1.0 3.0
  47. 1 3.0 4.0
  48. 2 NaN NaN, 'h2': h2
  49. a b
  50. 0 NaN NaN
  51. 1 NaN NaN
  52. 2 5.0 6.0}
  53. """
  54. # For `hist` plot, before transformation, the values in level 0 are values
  55. # in groups and subplot titles, and later used for column subselection and
  56. # iteration; For `box` plot, values in level 1 are column names to show,
  57. # and are used for iteration and as subplots titles.
  58. if kind == "hist":
  59. level = 0
  60. else:
  61. level = 1
  62. # Select sub-columns based on the value of level of MI, and if `by` is
  63. # assigned, data must be a MI DataFrame
  64. assert isinstance(data.columns, MultiIndex)
  65. return {
  66. col: data.loc[:, data.columns.get_level_values(level) == col]
  67. for col in data.columns.levels[level]
  68. }
  69. def reconstruct_data_with_by(
  70. data: DataFrame, by: IndexLabel, cols: IndexLabel
  71. ) -> DataFrame:
  72. """
  73. Internal function to group data, and reassign multiindex column names onto the
  74. result in order to let grouped data be used in _compute_plot_data method.
  75. Parameters
  76. ----------
  77. data : Original DataFrame to plot
  78. by : grouped `by` parameter selected by users
  79. cols : columns of data set (excluding columns used in `by`)
  80. Returns
  81. -------
  82. Output is the reconstructed DataFrame with MultiIndex columns. The first level
  83. of MI is unique values of groups, and second level of MI is the columns
  84. selected by users.
  85. Examples
  86. --------
  87. >>> d = {'h': ['h1', 'h1', 'h2'], 'a': [1, 3, 5], 'b': [3, 4, 6]}
  88. >>> df = pd.DataFrame(d)
  89. >>> reconstruct_data_with_by(df, by='h', cols=['a', 'b'])
  90. h1 h2
  91. a b a b
  92. 0 1.0 3.0 NaN NaN
  93. 1 3.0 4.0 NaN NaN
  94. 2 NaN NaN 5.0 6.0
  95. """
  96. by_modified = unpack_single_str_list(by)
  97. grouped = data.groupby(by_modified)
  98. data_list = []
  99. for key, group in grouped:
  100. # error: List item 1 has incompatible type "Union[Hashable,
  101. # Sequence[Hashable]]"; expected "Iterable[Hashable]"
  102. columns = MultiIndex.from_product([[key], cols]) # type: ignore[list-item]
  103. sub_group = group[cols]
  104. sub_group.columns = columns
  105. data_list.append(sub_group)
  106. data = concat(data_list, axis=1)
  107. return data
  108. def reformat_hist_y_given_by(y: np.ndarray, by: IndexLabel | None) -> np.ndarray:
  109. """Internal function to reformat y given `by` is applied or not for hist plot.
  110. If by is None, input y is 1-d with NaN removed; and if by is not None, groupby
  111. will take place and input y is multi-dimensional array.
  112. """
  113. if by is not None and len(y.shape) > 1:
  114. return np.array([remove_na_arraylike(col) for col in y.T]).T
  115. return remove_na_arraylike(y)