| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581 |
- from __future__ import annotations
- from typing import (
- TYPE_CHECKING,
- Any,
- Literal,
- final,
- )
- import numpy as np
- from pandas.core.dtypes.common import (
- is_integer,
- is_list_like,
- )
- from pandas.core.dtypes.generic import (
- ABCDataFrame,
- ABCIndex,
- )
- from pandas.core.dtypes.missing import (
- isna,
- remove_na_arraylike,
- )
- from pandas.io.formats.printing import pprint_thing
- from pandas.plotting._matplotlib.core import (
- LinePlot,
- MPLPlot,
- )
- from pandas.plotting._matplotlib.groupby import (
- create_iter_data_given_by,
- reformat_hist_y_given_by,
- )
- from pandas.plotting._matplotlib.misc import unpack_single_str_list
- from pandas.plotting._matplotlib.tools import (
- create_subplots,
- flatten_axes,
- maybe_adjust_figure,
- set_ticks_props,
- )
- if TYPE_CHECKING:
- from matplotlib.axes import Axes
- from matplotlib.figure import Figure
- from pandas._typing import PlottingOrientation
- from pandas import (
- DataFrame,
- Series,
- )
class HistPlot(LinePlot):
    """
    Histogram plot.

    Initialisation goes through ``MPLPlot.__init__`` directly (bypassing
    ``LinePlot.__init__``) and integer ``bins`` are converted to explicit
    bin edges computed from the data, so every column drawn on the same
    axes is binned identically.
    """

    @property
    def _kind(self) -> Literal["hist", "kde"]:
        return "hist"

    def __init__(
        self,
        data,
        bins: int | np.ndarray | list[np.ndarray] = 10,
        bottom: int | np.ndarray = 0,
        *,
        range=None,
        weights=None,
        **kwargs,
    ) -> None:
        """
        Parameters
        ----------
        data : Series or DataFrame
            Data to plot; forwarded to ``MPLPlot.__init__``.
        bins : int, ndarray or list of ndarray, default 10
            Number of bins, or explicit bin edges. An integer is expanded
            to edges by ``_adjust_bins``.
        bottom : int or array-like, default 0
            Baseline of the bars; list-likes are converted to ``ndarray``.
        range : optional
            Passed as ``range`` to ``np.histogram`` when computing edges.
        weights : optional
            Per-observation weights, see ``_get_column_weights``.
        **kwargs
            Forwarded to ``MPLPlot.__init__``; ``xlabel`` and ``ylabel``
            are additionally remembered on the instance.
        """
        if is_list_like(bottom):
            bottom = np.array(bottom)
        self.bottom = bottom

        self._bin_range = range
        self.weights = weights

        self.xlabel = kwargs.get("xlabel")
        self.ylabel = kwargs.get("ylabel")
        # Do not call LinePlot.__init__ which may fill nan
        MPLPlot.__init__(self, data, **kwargs)  # pylint: disable=non-parent-init-called

        # Needs self.data / self.by / self.columns, hence after MPLPlot.__init__
        self.bins = self._adjust_bins(bins)

    def _adjust_bins(self, bins: int | np.ndarray | list[np.ndarray]):
        """
        Turn an integer bin count into explicit bin edges.

        With ``by`` set, one array of edges is computed per group (returned
        as a list); otherwise a single array is computed from all the data.
        Non-integer ``bins`` are returned unchanged.
        """
        if is_integer(bins):
            if self.by is not None:
                by_modified = unpack_single_str_list(self.by)
                grouped = self.data.groupby(by_modified)[self.columns]
                bins = [self._calculate_bins(group, bins) for _, group in grouped]
            else:
                bins = self._calculate_bins(self.data, bins)
        return bins

    def _calculate_bins(self, data: Series | DataFrame, bins) -> np.ndarray:
        """Calculate bin edges for the numeric, non-missing values of ``data``."""
        nd_values = data.infer_objects(copy=False)._get_numeric_data()
        values = np.ravel(nd_values)
        values = values[~isna(values)]

        # Only the edges are needed; the counts are discarded.
        _, bin_edges = np.histogram(values, bins=bins, range=self._bin_range)
        return bin_edges

    # error: Signature of "_plot" incompatible with supertype "LinePlot"
    @classmethod
    def _plot(  # type: ignore[override]
        cls,
        ax: Axes,
        y: np.ndarray,
        style=None,
        bottom: int | np.ndarray = 0,
        column_num: int = 0,
        stacking_id=None,
        *,
        bins,
        **kwds,
    ):
        """
        Draw one histogram on ``ax`` and return the patches from ``ax.hist``.

        The stacker is initialised on the first column (``column_num == 0``),
        the values stacked so far are added to ``bottom``, and the stacker is
        updated with the new counts afterwards. ``style`` is accepted for
        signature compatibility but is not forwarded to ``ax.hist``.
        ``kwds`` must contain ``"label"``.
        """
        if column_num == 0:
            cls._initialize_stacker(ax, stacking_id, len(bins) - 1)

        base = np.zeros(len(bins) - 1)
        bottom = bottom + cls._get_stacked_values(ax, stacking_id, base, kwds["label"])
        # ignore style
        n, bins, patches = ax.hist(y, bins=bins, bottom=bottom, **kwds)
        cls._update_stacker(ax, stacking_id, n)
        return patches

    def _make_plot(self, fig: Figure) -> None:
        """Draw every column (or every group when ``by`` is set) on its axes."""
        colors = self._get_colors()
        stacking_id = self._get_stacking_id()

        # Re-create iterated data if `by` is assigned by users
        data = (
            create_iter_data_given_by(self.data, self._kind)
            if self.by is not None
            else self.data
        )

        # error: Argument "data" to "_iter_data" of "MPLPlot" has incompatible
        # type "object"; expected "DataFrame | dict[Hashable, Series | DataFrame]"
        for i, (label, y) in enumerate(self._iter_data(data=data)):  # type: ignore[arg-type]
            ax = self._get_ax(i)

            kwds = self.kwds.copy()
            if self.color is not None:
                kwds["color"] = self.color

            label = pprint_thing(label)
            label = self._mark_right_label(label, index=i)
            kwds["label"] = label

            style, kwds = self._apply_style_colors(colors, kwds, i, label)
            if style is not None:
                kwds["style"] = style

            self._make_plot_keywords(kwds, y)

            # the bins is multi-dimension array now and each plot need only 1-d and
            # when by is applied, label should be columns that are grouped
            if self.by is not None:
                kwds["bins"] = kwds["bins"][i]
                kwds["label"] = self.columns
                # Drop any colour so it is not applied to every grouped column.
                # Use a default: nothing visible here guarantees the key was
                # set, and a bare pop() would raise KeyError in that case.
                kwds.pop("color", None)

            if self.weights is not None:
                kwds["weights"] = type(self)._get_column_weights(self.weights, i, y)

            y = reformat_hist_y_given_by(y, self.by)

            artists = self._plot(ax, y, column_num=i, stacking_id=stacking_id, **kwds)

            # when by is applied, show title for subplots to know which group it is
            if self.by is not None:
                ax.set_title(pprint_thing(label))

            self._append_legend_handles_labels(artists[0], label)

    def _make_plot_keywords(self, kwds: dict[str, Any], y: np.ndarray) -> None:
        """
        Merge the histogram-specific settings (``bottom``, ``bins``) into ``kwds``.

        ``kwds`` is modified in place. ``y`` is unused here; it is part of the
        signature because the ``KdePlot`` override needs it.
        """
        kwds["bottom"] = self.bottom
        kwds["bins"] = self.bins

    @final
    @staticmethod
    def _get_column_weights(weights, i: int, y):
        """
        Select the weights for the ``i``-th plotted column.

        Raises
        ------
        ValueError
            If ``weights`` is multi-column but has no column ``i``.
        """
        # We allow weights to be a multi-dimensional array, e.g. a (10, 2) array,
        # and each sub-array (10,) will be called in each iteration. If users only
        # provide 1D array, we assume the same weights is used for all iterations
        if weights is not None:
            if np.ndim(weights) != 1 and np.shape(weights)[-1] != 1:
                try:
                    weights = weights[:, i]
                except IndexError as err:
                    raise ValueError(
                        "weights must have the same shape as data, "
                        "or be a single column"
                    ) from err
            # Keep only the weights whose observation in y is not missing
            weights = weights[~isna(y)]
        return weights

    def _post_plot_logic(self, ax: Axes, data) -> None:
        """
        Label the axes: the counting axis defaults to "Frequency".

        For horizontal histograms the counting axis is x, otherwise y; a
        user-supplied ``xlabel``/``ylabel`` takes precedence.
        """
        if self.orientation == "horizontal":
            # error: Argument 1 to "set_xlabel" of "_AxesBase" has incompatible
            # type "Hashable"; expected "str"
            ax.set_xlabel(
                "Frequency"
                if self.xlabel is None
                else self.xlabel  # type: ignore[arg-type]
            )
            ax.set_ylabel(self.ylabel)  # type: ignore[arg-type]
        else:
            ax.set_xlabel(self.xlabel)  # type: ignore[arg-type]
            ax.set_ylabel(
                "Frequency"
                if self.ylabel is None
                else self.ylabel  # type: ignore[arg-type]
            )

    @property
    def orientation(self) -> PlottingOrientation:
        """"horizontal" if requested via the ``orientation`` kwd, else "vertical"."""
        if self.kwds.get("orientation", None) == "horizontal":
            return "horizontal"
        else:
            return "vertical"
class KdePlot(HistPlot):
    """
    Kernel density estimate plot.

    Reuses the ``HistPlot`` drawing loop but evaluates a Gaussian KDE for
    each column and draws it as a line; the orientation is always vertical.
    """

    @property
    def _kind(self) -> Literal["kde"]:
        return "kde"

    @property
    def orientation(self) -> Literal["vertical"]:
        return "vertical"

    def __init__(
        self, data, bw_method=None, ind=None, *, weights=None, **kwargs
    ) -> None:
        """
        Parameters
        ----------
        data : Series or DataFrame
            Data to plot; forwarded to ``MPLPlot.__init__``.
        bw_method : optional
            Passed to ``scipy.stats.gaussian_kde`` as ``bw_method``.
        ind : int, array-like or None
            Evaluation points, or the number of them, see ``_get_ind``.
        weights : optional
            Stored on the instance.
        """
        # Do not call LinePlot.__init__ which may fill nan
        MPLPlot.__init__(self, data, **kwargs)  # pylint: disable=non-parent-init-called
        self.bw_method = bw_method
        self.ind = ind
        self.weights = weights

    @staticmethod
    def _get_ind(y: np.ndarray, ind):
        """
        Resolve the points at which the density is evaluated.

        ``None`` yields 1000 evenly spaced points and an integer yields that
        many; in both cases the grid extends half the sample range beyond
        the smallest and largest non-missing value of ``y``. Anything else
        is returned unchanged.
        """
        if ind is not None and not is_integer(ind):
            return ind

        num_points = 1000 if ind is None else ind
        # np.nanmax() and np.nanmin() ignores the missing values
        upper = np.nanmax(y)
        lower = np.nanmin(y)
        half_span = 0.5 * (upper - lower)
        return np.linspace(lower - half_span, upper + half_span, num_points)

    @classmethod
    # error: Signature of "_plot" incompatible with supertype "MPLPlot"
    def _plot(  # type: ignore[override]
        cls,
        ax: Axes,
        y: np.ndarray,
        style=None,
        bw_method=None,
        ind=None,
        column_num=None,
        stacking_id: int | None = None,
        **kwds,
    ):
        """
        Fit a Gaussian KDE to the non-missing values of ``y``, evaluate it at
        ``ind`` and draw the resulting curve via ``MPLPlot._plot``.
        """
        from scipy.stats import gaussian_kde

        sample = remove_na_arraylike(y)
        estimator = gaussian_kde(sample, bw_method=bw_method)
        density = estimator.evaluate(ind)
        return MPLPlot._plot(ax, ind, density, style=style, **kwds)

    def _make_plot_keywords(self, kwds: dict[str, Any], y: np.ndarray) -> None:
        """Add the KDE settings (``bw_method``, resolved ``ind``) to ``kwds`` in place."""
        kwds["bw_method"] = self.bw_method
        kwds["ind"] = type(self)._get_ind(y, ind=self.ind)

    def _post_plot_logic(self, ax: Axes, data) -> None:
        """Label the y axis "Density"."""
        ax.set_ylabel("Density")
def _grouped_plot(
    plotf,
    data: Series | DataFrame,
    column=None,
    by=None,
    numeric_only: bool = True,
    figsize: tuple[float, float] | None = None,
    sharex: bool = True,
    sharey: bool = True,
    layout=None,
    rot: float = 0,
    ax=None,
    **kwargs,
):
    """
    Draw one subplot per group of ``data.groupby(by)``.

    Parameters
    ----------
    plotf : callable
        Called as ``plotf(group, ax, **kwargs)`` for every group.
    data : Series or DataFrame
    column : object, optional
        If given, only this selection of each group is plotted.
    by : object, optional
        Grouping key passed to ``groupby``.
    numeric_only : bool, default True
        Reduce DataFrame groups to their numeric data before plotting.
    figsize, sharex, sharey, layout, ax
        Forwarded to ``create_subplots``.
    rot : float, default 0
        Not used by this function.

    Returns
    -------
    tuple
        ``(fig, axes)`` as returned by ``create_subplots``.

    Raises
    ------
    ValueError
        If ``figsize`` is the string ``"default"``.
    """
    # error: Non-overlapping equality check (left operand type: "Optional[Tuple[float,
    # float]]", right operand type: "Literal['default']")
    if figsize == "default":  # type: ignore[comparison-overlap]
        # allowed to specify mpl default with 'default'
        raise ValueError(
            "figsize='default' is no longer supported. "
            "Specify figure size by tuple instead"
        )

    groups = data.groupby(by)
    if column is not None:
        groups = groups[column]

    fig, axes = create_subplots(
        naxes=len(groups),
        figsize=figsize,
        sharex=sharex,
        sharey=sharey,
        ax=ax,
        layout=layout,
    )
    flat_axes = flatten_axes(axes)

    for position, (key, group) in enumerate(groups):
        group_ax = flat_axes[position]
        if numeric_only and isinstance(group, ABCDataFrame):
            group = group._get_numeric_data()
        plotf(group, group_ax, **kwargs)
        group_ax.set_title(pprint_thing(key))

    return fig, axes
def _grouped_hist(
    data: Series | DataFrame,
    column=None,
    by=None,
    ax=None,
    bins: int = 50,
    figsize: tuple[float, float] | None = None,
    layout=None,
    sharex: bool = False,
    sharey: bool = False,
    rot: float = 90,
    grid: bool = True,
    xlabelsize: int | None = None,
    xrot=None,
    ylabelsize: int | None = None,
    yrot=None,
    legend: bool = False,
    **kwargs,
):
    """
    Grouped histogram

    Parameters
    ----------
    data : Series/DataFrame
    column : object, optional
    by : object, optional
    ax : axes, optional
    bins : int, default 50
    figsize : tuple, optional
    layout : optional
    sharex : bool, default False
    sharey : bool, default False
    rot : float, default 90
    grid : bool, default True
        Whether to show grid lines on each group's axes.
    xlabelsize : int, optional
    xrot : optional
        x tick rotation; falls back to ``rot`` when None.
    ylabelsize : int, optional
    yrot : optional
    legend : bool, default False
        If True, label the histograms (Series name, all columns, or
        ``column``) and draw a legend on each axes. ``label`` must then
        not be present in ``kwargs``.
    kwargs : dict, keyword arguments passed to matplotlib.Axes.hist

    Returns
    -------
    collection of Matplotlib Axes
    """
    if legend:
        assert "label" not in kwargs
        if data.ndim == 1:
            kwargs["label"] = data.name
        elif column is None:
            kwargs["label"] = data.columns
        else:
            kwargs["label"] = column

    def plot_group(group, ax) -> None:
        ax.hist(group.dropna().values, bins=bins, **kwargs)
        # Apply the requested grid setting to this group's axes; without
        # this call the ``grid`` argument would have no effect here.
        ax.grid(grid)
        if legend:
            ax.legend()

    if xrot is None:
        xrot = rot

    fig, axes = _grouped_plot(
        plot_group,
        data,
        column=column,
        by=by,
        sharex=sharex,
        sharey=sharey,
        ax=ax,
        figsize=figsize,
        layout=layout,
        rot=rot,
    )

    set_ticks_props(
        axes, xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot
    )

    maybe_adjust_figure(
        fig, bottom=0.15, top=0.9, left=0.1, right=0.9, hspace=0.5, wspace=0.3
    )
    return axes
def hist_series(
    self: Series,
    by=None,
    ax=None,
    grid: bool = True,
    xlabelsize: int | None = None,
    xrot=None,
    ylabelsize: int | None = None,
    yrot=None,
    figsize: tuple[float, float] | None = None,
    bins: int = 10,
    legend: bool = False,
    **kwds,
):
    """
    Draw a histogram of a Series.

    Parameters
    ----------
    self : Series
        Values to plot; missing values are dropped.
    by : object, optional
        If given, one histogram per group is drawn via ``_grouped_hist``.
    ax : matplotlib Axes, optional
        Axes to draw on. Without ``by`` it must belong to the figure in use.
    grid : bool, default True
        Whether to show grid lines.
    xlabelsize, xrot, ylabelsize, yrot
        Tick label size / rotation, forwarded to ``set_ticks_props``.
    figsize : tuple, optional
        Figure size in inches.
    bins : int, default 10
        Passed to ``Axes.hist``.
    legend : bool, default False
        Label the histogram with the Series name and show a legend.
    **kwds
        Passed to ``Axes.hist``. Without ``by``, a ``figure`` entry selects
        the figure to draw on; with ``by``, ``figure`` is rejected.

    Returns
    -------
    Axes or array of Axes
        A single-element 1-D array is unwrapped to the Axes itself.

    Raises
    ------
    ValueError
        If both ``legend`` and ``label`` are given, if ``layout`` is given
        without ``by``, or if ``figure`` is given together with ``by``.
    AssertionError
        If ``ax`` does not belong to the figure in use.
    """
    import matplotlib.pyplot as plt

    if legend and "label" in kwds:
        raise ValueError("Cannot use both legend and label")

    if by is None:
        if kwds.get("layout", None) is not None:
            raise ValueError("The 'layout' keyword is not supported when 'by' is None")
        # hack until the plotting interface is a bit more unified
        # Resolve the fallback lazily: evaluating it as the default of
        # ``kwds.pop`` would call ``plt.figure()`` (creating a stray empty
        # figure) even when the caller supplied ``figure``.
        if "figure" in kwds:
            fig = kwds.pop("figure")
        elif plt.get_fignums():
            fig = plt.gcf()
        else:
            fig = plt.figure(figsize=figsize)

        if figsize is not None and tuple(figsize) != tuple(fig.get_size_inches()):
            fig.set_size_inches(*figsize, forward=True)
        if ax is None:
            ax = fig.gca()
        elif ax.get_figure() != fig:
            raise AssertionError("passed axis not bound to passed figure")
        values = self.dropna().values
        if legend:
            kwds["label"] = self.name
        ax.hist(values, bins=bins, **kwds)
        if legend:
            ax.legend()
        ax.grid(grid)
        axes = np.array([ax])

        # error: Argument 1 to "set_ticks_props" has incompatible type "ndarray[Any,
        # dtype[Any]]"; expected "Axes | Sequence[Axes]"
        set_ticks_props(
            axes,  # type: ignore[arg-type]
            xlabelsize=xlabelsize,
            xrot=xrot,
            ylabelsize=ylabelsize,
            yrot=yrot,
        )

    else:
        if "figure" in kwds:
            raise ValueError(
                "Cannot pass 'figure' when using the "
                "'by' argument, since a new 'Figure' instance will be created"
            )
        axes = _grouped_hist(
            self,
            by=by,
            ax=ax,
            grid=grid,
            figsize=figsize,
            bins=bins,
            xlabelsize=xlabelsize,
            xrot=xrot,
            ylabelsize=ylabelsize,
            yrot=yrot,
            legend=legend,
            **kwds,
        )

    if hasattr(axes, "ndim"):
        if axes.ndim == 1 and len(axes) == 1:
            return axes[0]
    return axes
def hist_frame(
    data: DataFrame,
    column=None,
    by=None,
    grid: bool = True,
    xlabelsize: int | None = None,
    xrot=None,
    ylabelsize: int | None = None,
    yrot=None,
    ax=None,
    sharex: bool = False,
    sharey: bool = False,
    figsize: tuple[float, float] | None = None,
    layout=None,
    bins: int = 10,
    legend: bool = False,
    **kwds,
):
    """
    Draw one histogram per column of a DataFrame.

    With ``by`` the work is delegated to ``_grouped_hist``. Otherwise the
    frame is optionally restricted to ``column``, reduced to its numeric
    and datetime columns, and each remaining column is drawn on its own
    subplot, titled with the column name.

    Returns
    -------
    Axes collection as produced by ``_grouped_hist`` or ``create_subplots``.

    Raises
    ------
    ValueError
        If both ``legend`` and ``label`` are given, or if no numeric or
        datetime column is left to plot.
    """
    if legend and "label" in kwds:
        raise ValueError("Cannot use both legend and label")

    if by is not None:
        return _grouped_hist(
            data,
            column=column,
            by=by,
            ax=ax,
            grid=grid,
            figsize=figsize,
            sharex=sharex,
            sharey=sharey,
            layout=layout,
            bins=bins,
            xlabelsize=xlabelsize,
            xrot=xrot,
            ylabelsize=ylabelsize,
            yrot=yrot,
            legend=legend,
            **kwds,
        )

    if column is not None:
        # A single key is wrapped so the selection stays a DataFrame
        if not isinstance(column, (list, np.ndarray, ABCIndex)):
            column = [column]
        data = data[column]

    # GH32590
    data = data.select_dtypes(
        include=(np.number, "datetime64", "datetimetz"), exclude="timedelta"
    )
    plot_columns = data.columns
    n_plots = len(plot_columns)

    if n_plots == 0:
        raise ValueError(
            "hist method requires numerical or datetime columns, nothing to plot."
        )

    fig, axes = create_subplots(
        naxes=n_plots,
        ax=ax,
        squeeze=False,
        sharex=sharex,
        sharey=sharey,
        figsize=figsize,
        layout=layout,
    )
    flat_axes = flatten_axes(axes)

    # Decided once, before the loop starts writing "label" into kwds
    label_each_column = legend and "label" not in kwds

    for position, col in enumerate(plot_columns):
        col_ax = flat_axes[position]
        if label_each_column:
            kwds["label"] = col
        col_ax.hist(data[col].dropna().values, bins=bins, **kwds)
        col_ax.set_title(col)
        col_ax.grid(grid)
        if legend:
            col_ax.legend()

    set_ticks_props(
        axes, xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot
    )
    maybe_adjust_figure(fig, wspace=0.3, hspace=0.3)

    return axes
|