hist.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581
  1. from __future__ import annotations
  2. from typing import (
  3. TYPE_CHECKING,
  4. Any,
  5. Literal,
  6. final,
  7. )
  8. import numpy as np
  9. from pandas.core.dtypes.common import (
  10. is_integer,
  11. is_list_like,
  12. )
  13. from pandas.core.dtypes.generic import (
  14. ABCDataFrame,
  15. ABCIndex,
  16. )
  17. from pandas.core.dtypes.missing import (
  18. isna,
  19. remove_na_arraylike,
  20. )
  21. from pandas.io.formats.printing import pprint_thing
  22. from pandas.plotting._matplotlib.core import (
  23. LinePlot,
  24. MPLPlot,
  25. )
  26. from pandas.plotting._matplotlib.groupby import (
  27. create_iter_data_given_by,
  28. reformat_hist_y_given_by,
  29. )
  30. from pandas.plotting._matplotlib.misc import unpack_single_str_list
  31. from pandas.plotting._matplotlib.tools import (
  32. create_subplots,
  33. flatten_axes,
  34. maybe_adjust_figure,
  35. set_ticks_props,
  36. )
  37. if TYPE_CHECKING:
  38. from matplotlib.axes import Axes
  39. from matplotlib.figure import Figure
  40. from pandas._typing import PlottingOrientation
  41. from pandas import (
  42. DataFrame,
  43. Series,
  44. )
  45. class HistPlot(LinePlot):
  46. @property
  47. def _kind(self) -> Literal["hist", "kde"]:
  48. return "hist"
  49. def __init__(
  50. self,
  51. data,
  52. bins: int | np.ndarray | list[np.ndarray] = 10,
  53. bottom: int | np.ndarray = 0,
  54. *,
  55. range=None,
  56. weights=None,
  57. **kwargs,
  58. ) -> None:
  59. if is_list_like(bottom):
  60. bottom = np.array(bottom)
  61. self.bottom = bottom
  62. self._bin_range = range
  63. self.weights = weights
  64. self.xlabel = kwargs.get("xlabel")
  65. self.ylabel = kwargs.get("ylabel")
  66. # Do not call LinePlot.__init__ which may fill nan
  67. MPLPlot.__init__(self, data, **kwargs) # pylint: disable=non-parent-init-called
  68. self.bins = self._adjust_bins(bins)
  69. def _adjust_bins(self, bins: int | np.ndarray | list[np.ndarray]):
  70. if is_integer(bins):
  71. if self.by is not None:
  72. by_modified = unpack_single_str_list(self.by)
  73. grouped = self.data.groupby(by_modified)[self.columns]
  74. bins = [self._calculate_bins(group, bins) for key, group in grouped]
  75. else:
  76. bins = self._calculate_bins(self.data, bins)
  77. return bins
  78. def _calculate_bins(self, data: Series | DataFrame, bins) -> np.ndarray:
  79. """Calculate bins given data"""
  80. nd_values = data.infer_objects(copy=False)._get_numeric_data()
  81. values = np.ravel(nd_values)
  82. values = values[~isna(values)]
  83. hist, bins = np.histogram(values, bins=bins, range=self._bin_range)
  84. return bins
  85. # error: Signature of "_plot" incompatible with supertype "LinePlot"
  86. @classmethod
  87. def _plot( # type: ignore[override]
  88. cls,
  89. ax: Axes,
  90. y: np.ndarray,
  91. style=None,
  92. bottom: int | np.ndarray = 0,
  93. column_num: int = 0,
  94. stacking_id=None,
  95. *,
  96. bins,
  97. **kwds,
  98. ):
  99. if column_num == 0:
  100. cls._initialize_stacker(ax, stacking_id, len(bins) - 1)
  101. base = np.zeros(len(bins) - 1)
  102. bottom = bottom + cls._get_stacked_values(ax, stacking_id, base, kwds["label"])
  103. # ignore style
  104. n, bins, patches = ax.hist(y, bins=bins, bottom=bottom, **kwds)
  105. cls._update_stacker(ax, stacking_id, n)
  106. return patches
  107. def _make_plot(self, fig: Figure) -> None:
  108. colors = self._get_colors()
  109. stacking_id = self._get_stacking_id()
  110. # Re-create iterated data if `by` is assigned by users
  111. data = (
  112. create_iter_data_given_by(self.data, self._kind)
  113. if self.by is not None
  114. else self.data
  115. )
  116. # error: Argument "data" to "_iter_data" of "MPLPlot" has incompatible
  117. # type "object"; expected "DataFrame | dict[Hashable, Series | DataFrame]"
  118. for i, (label, y) in enumerate(self._iter_data(data=data)): # type: ignore[arg-type]
  119. ax = self._get_ax(i)
  120. kwds = self.kwds.copy()
  121. if self.color is not None:
  122. kwds["color"] = self.color
  123. label = pprint_thing(label)
  124. label = self._mark_right_label(label, index=i)
  125. kwds["label"] = label
  126. style, kwds = self._apply_style_colors(colors, kwds, i, label)
  127. if style is not None:
  128. kwds["style"] = style
  129. self._make_plot_keywords(kwds, y)
  130. # the bins is multi-dimension array now and each plot need only 1-d and
  131. # when by is applied, label should be columns that are grouped
  132. if self.by is not None:
  133. kwds["bins"] = kwds["bins"][i]
  134. kwds["label"] = self.columns
  135. kwds.pop("color")
  136. if self.weights is not None:
  137. kwds["weights"] = type(self)._get_column_weights(self.weights, i, y)
  138. y = reformat_hist_y_given_by(y, self.by)
  139. artists = self._plot(ax, y, column_num=i, stacking_id=stacking_id, **kwds)
  140. # when by is applied, show title for subplots to know which group it is
  141. if self.by is not None:
  142. ax.set_title(pprint_thing(label))
  143. self._append_legend_handles_labels(artists[0], label)
  144. def _make_plot_keywords(self, kwds: dict[str, Any], y: np.ndarray) -> None:
  145. """merge BoxPlot/KdePlot properties to passed kwds"""
  146. # y is required for KdePlot
  147. kwds["bottom"] = self.bottom
  148. kwds["bins"] = self.bins
  149. @final
  150. @staticmethod
  151. def _get_column_weights(weights, i: int, y):
  152. # We allow weights to be a multi-dimensional array, e.g. a (10, 2) array,
  153. # and each sub-array (10,) will be called in each iteration. If users only
  154. # provide 1D array, we assume the same weights is used for all iterations
  155. if weights is not None:
  156. if np.ndim(weights) != 1 and np.shape(weights)[-1] != 1:
  157. try:
  158. weights = weights[:, i]
  159. except IndexError as err:
  160. raise ValueError(
  161. "weights must have the same shape as data, "
  162. "or be a single column"
  163. ) from err
  164. weights = weights[~isna(y)]
  165. return weights
  166. def _post_plot_logic(self, ax: Axes, data) -> None:
  167. if self.orientation == "horizontal":
  168. # error: Argument 1 to "set_xlabel" of "_AxesBase" has incompatible
  169. # type "Hashable"; expected "str"
  170. ax.set_xlabel(
  171. "Frequency"
  172. if self.xlabel is None
  173. else self.xlabel # type: ignore[arg-type]
  174. )
  175. ax.set_ylabel(self.ylabel) # type: ignore[arg-type]
  176. else:
  177. ax.set_xlabel(self.xlabel) # type: ignore[arg-type]
  178. ax.set_ylabel(
  179. "Frequency"
  180. if self.ylabel is None
  181. else self.ylabel # type: ignore[arg-type]
  182. )
  183. @property
  184. def orientation(self) -> PlottingOrientation:
  185. if self.kwds.get("orientation", None) == "horizontal":
  186. return "horizontal"
  187. else:
  188. return "vertical"
  189. class KdePlot(HistPlot):
  190. @property
  191. def _kind(self) -> Literal["kde"]:
  192. return "kde"
  193. @property
  194. def orientation(self) -> Literal["vertical"]:
  195. return "vertical"
  196. def __init__(
  197. self, data, bw_method=None, ind=None, *, weights=None, **kwargs
  198. ) -> None:
  199. # Do not call LinePlot.__init__ which may fill nan
  200. MPLPlot.__init__(self, data, **kwargs) # pylint: disable=non-parent-init-called
  201. self.bw_method = bw_method
  202. self.ind = ind
  203. self.weights = weights
  204. @staticmethod
  205. def _get_ind(y: np.ndarray, ind):
  206. if ind is None:
  207. # np.nanmax() and np.nanmin() ignores the missing values
  208. sample_range = np.nanmax(y) - np.nanmin(y)
  209. ind = np.linspace(
  210. np.nanmin(y) - 0.5 * sample_range,
  211. np.nanmax(y) + 0.5 * sample_range,
  212. 1000,
  213. )
  214. elif is_integer(ind):
  215. sample_range = np.nanmax(y) - np.nanmin(y)
  216. ind = np.linspace(
  217. np.nanmin(y) - 0.5 * sample_range,
  218. np.nanmax(y) + 0.5 * sample_range,
  219. ind,
  220. )
  221. return ind
  222. @classmethod
  223. # error: Signature of "_plot" incompatible with supertype "MPLPlot"
  224. def _plot( # type: ignore[override]
  225. cls,
  226. ax: Axes,
  227. y: np.ndarray,
  228. style=None,
  229. bw_method=None,
  230. ind=None,
  231. column_num=None,
  232. stacking_id: int | None = None,
  233. **kwds,
  234. ):
  235. from scipy.stats import gaussian_kde
  236. y = remove_na_arraylike(y)
  237. gkde = gaussian_kde(y, bw_method=bw_method)
  238. y = gkde.evaluate(ind)
  239. lines = MPLPlot._plot(ax, ind, y, style=style, **kwds)
  240. return lines
  241. def _make_plot_keywords(self, kwds: dict[str, Any], y: np.ndarray) -> None:
  242. kwds["bw_method"] = self.bw_method
  243. kwds["ind"] = type(self)._get_ind(y, ind=self.ind)
  244. def _post_plot_logic(self, ax: Axes, data) -> None:
  245. ax.set_ylabel("Density")
  246. def _grouped_plot(
  247. plotf,
  248. data: Series | DataFrame,
  249. column=None,
  250. by=None,
  251. numeric_only: bool = True,
  252. figsize: tuple[float, float] | None = None,
  253. sharex: bool = True,
  254. sharey: bool = True,
  255. layout=None,
  256. rot: float = 0,
  257. ax=None,
  258. **kwargs,
  259. ):
  260. # error: Non-overlapping equality check (left operand type: "Optional[Tuple[float,
  261. # float]]", right operand type: "Literal['default']")
  262. if figsize == "default": # type: ignore[comparison-overlap]
  263. # allowed to specify mpl default with 'default'
  264. raise ValueError(
  265. "figsize='default' is no longer supported. "
  266. "Specify figure size by tuple instead"
  267. )
  268. grouped = data.groupby(by)
  269. if column is not None:
  270. grouped = grouped[column]
  271. naxes = len(grouped)
  272. fig, axes = create_subplots(
  273. naxes=naxes, figsize=figsize, sharex=sharex, sharey=sharey, ax=ax, layout=layout
  274. )
  275. _axes = flatten_axes(axes)
  276. for i, (key, group) in enumerate(grouped):
  277. ax = _axes[i]
  278. if numeric_only and isinstance(group, ABCDataFrame):
  279. group = group._get_numeric_data()
  280. plotf(group, ax, **kwargs)
  281. ax.set_title(pprint_thing(key))
  282. return fig, axes
  283. def _grouped_hist(
  284. data: Series | DataFrame,
  285. column=None,
  286. by=None,
  287. ax=None,
  288. bins: int = 50,
  289. figsize: tuple[float, float] | None = None,
  290. layout=None,
  291. sharex: bool = False,
  292. sharey: bool = False,
  293. rot: float = 90,
  294. grid: bool = True,
  295. xlabelsize: int | None = None,
  296. xrot=None,
  297. ylabelsize: int | None = None,
  298. yrot=None,
  299. legend: bool = False,
  300. **kwargs,
  301. ):
  302. """
  303. Grouped histogram
  304. Parameters
  305. ----------
  306. data : Series/DataFrame
  307. column : object, optional
  308. by : object, optional
  309. ax : axes, optional
  310. bins : int, default 50
  311. figsize : tuple, optional
  312. layout : optional
  313. sharex : bool, default False
  314. sharey : bool, default False
  315. rot : float, default 90
  316. grid : bool, default True
  317. legend: : bool, default False
  318. kwargs : dict, keyword arguments passed to matplotlib.Axes.hist
  319. Returns
  320. -------
  321. collection of Matplotlib Axes
  322. """
  323. if legend:
  324. assert "label" not in kwargs
  325. if data.ndim == 1:
  326. kwargs["label"] = data.name
  327. elif column is None:
  328. kwargs["label"] = data.columns
  329. else:
  330. kwargs["label"] = column
  331. def plot_group(group, ax) -> None:
  332. ax.hist(group.dropna().values, bins=bins, **kwargs)
  333. if legend:
  334. ax.legend()
  335. if xrot is None:
  336. xrot = rot
  337. fig, axes = _grouped_plot(
  338. plot_group,
  339. data,
  340. column=column,
  341. by=by,
  342. sharex=sharex,
  343. sharey=sharey,
  344. ax=ax,
  345. figsize=figsize,
  346. layout=layout,
  347. rot=rot,
  348. )
  349. set_ticks_props(
  350. axes, xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot
  351. )
  352. maybe_adjust_figure(
  353. fig, bottom=0.15, top=0.9, left=0.1, right=0.9, hspace=0.5, wspace=0.3
  354. )
  355. return axes
  356. def hist_series(
  357. self: Series,
  358. by=None,
  359. ax=None,
  360. grid: bool = True,
  361. xlabelsize: int | None = None,
  362. xrot=None,
  363. ylabelsize: int | None = None,
  364. yrot=None,
  365. figsize: tuple[float, float] | None = None,
  366. bins: int = 10,
  367. legend: bool = False,
  368. **kwds,
  369. ):
  370. import matplotlib.pyplot as plt
  371. if legend and "label" in kwds:
  372. raise ValueError("Cannot use both legend and label")
  373. if by is None:
  374. if kwds.get("layout", None) is not None:
  375. raise ValueError("The 'layout' keyword is not supported when 'by' is None")
  376. # hack until the plotting interface is a bit more unified
  377. fig = kwds.pop(
  378. "figure", plt.gcf() if plt.get_fignums() else plt.figure(figsize=figsize)
  379. )
  380. if figsize is not None and tuple(figsize) != tuple(fig.get_size_inches()):
  381. fig.set_size_inches(*figsize, forward=True)
  382. if ax is None:
  383. ax = fig.gca()
  384. elif ax.get_figure() != fig:
  385. raise AssertionError("passed axis not bound to passed figure")
  386. values = self.dropna().values
  387. if legend:
  388. kwds["label"] = self.name
  389. ax.hist(values, bins=bins, **kwds)
  390. if legend:
  391. ax.legend()
  392. ax.grid(grid)
  393. axes = np.array([ax])
  394. # error: Argument 1 to "set_ticks_props" has incompatible type "ndarray[Any,
  395. # dtype[Any]]"; expected "Axes | Sequence[Axes]"
  396. set_ticks_props(
  397. axes, # type: ignore[arg-type]
  398. xlabelsize=xlabelsize,
  399. xrot=xrot,
  400. ylabelsize=ylabelsize,
  401. yrot=yrot,
  402. )
  403. else:
  404. if "figure" in kwds:
  405. raise ValueError(
  406. "Cannot pass 'figure' when using the "
  407. "'by' argument, since a new 'Figure' instance will be created"
  408. )
  409. axes = _grouped_hist(
  410. self,
  411. by=by,
  412. ax=ax,
  413. grid=grid,
  414. figsize=figsize,
  415. bins=bins,
  416. xlabelsize=xlabelsize,
  417. xrot=xrot,
  418. ylabelsize=ylabelsize,
  419. yrot=yrot,
  420. legend=legend,
  421. **kwds,
  422. )
  423. if hasattr(axes, "ndim"):
  424. if axes.ndim == 1 and len(axes) == 1:
  425. return axes[0]
  426. return axes
  427. def hist_frame(
  428. data: DataFrame,
  429. column=None,
  430. by=None,
  431. grid: bool = True,
  432. xlabelsize: int | None = None,
  433. xrot=None,
  434. ylabelsize: int | None = None,
  435. yrot=None,
  436. ax=None,
  437. sharex: bool = False,
  438. sharey: bool = False,
  439. figsize: tuple[float, float] | None = None,
  440. layout=None,
  441. bins: int = 10,
  442. legend: bool = False,
  443. **kwds,
  444. ):
  445. if legend and "label" in kwds:
  446. raise ValueError("Cannot use both legend and label")
  447. if by is not None:
  448. axes = _grouped_hist(
  449. data,
  450. column=column,
  451. by=by,
  452. ax=ax,
  453. grid=grid,
  454. figsize=figsize,
  455. sharex=sharex,
  456. sharey=sharey,
  457. layout=layout,
  458. bins=bins,
  459. xlabelsize=xlabelsize,
  460. xrot=xrot,
  461. ylabelsize=ylabelsize,
  462. yrot=yrot,
  463. legend=legend,
  464. **kwds,
  465. )
  466. return axes
  467. if column is not None:
  468. if not isinstance(column, (list, np.ndarray, ABCIndex)):
  469. column = [column]
  470. data = data[column]
  471. # GH32590
  472. data = data.select_dtypes(
  473. include=(np.number, "datetime64", "datetimetz"), exclude="timedelta"
  474. )
  475. naxes = len(data.columns)
  476. if naxes == 0:
  477. raise ValueError(
  478. "hist method requires numerical or datetime columns, nothing to plot."
  479. )
  480. fig, axes = create_subplots(
  481. naxes=naxes,
  482. ax=ax,
  483. squeeze=False,
  484. sharex=sharex,
  485. sharey=sharey,
  486. figsize=figsize,
  487. layout=layout,
  488. )
  489. _axes = flatten_axes(axes)
  490. can_set_label = "label" not in kwds
  491. for i, col in enumerate(data.columns):
  492. ax = _axes[i]
  493. if legend and can_set_label:
  494. kwds["label"] = col
  495. ax.hist(data[col].dropna().values, bins=bins, **kwds)
  496. ax.set_title(col)
  497. ax.grid(grid)
  498. if legend:
  499. ax.legend()
  500. set_ticks_props(
  501. axes, xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot
  502. )
  503. maybe_adjust_figure(fig, wspace=0.3, hspace=0.3)
  504. return axes