misc.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481
  1. from __future__ import annotations
  2. import random
  3. from typing import TYPE_CHECKING
  4. from matplotlib import patches
  5. import matplotlib.lines as mlines
  6. import numpy as np
  7. from pandas.core.dtypes.missing import notna
  8. from pandas.io.formats.printing import pprint_thing
  9. from pandas.plotting._matplotlib.style import get_standard_colors
  10. from pandas.plotting._matplotlib.tools import (
  11. create_subplots,
  12. do_adjust_figure,
  13. maybe_adjust_figure,
  14. set_ticks_props,
  15. )
  16. if TYPE_CHECKING:
  17. from collections.abc import Hashable
  18. from matplotlib.axes import Axes
  19. from matplotlib.figure import Figure
  20. from pandas import (
  21. DataFrame,
  22. Index,
  23. Series,
  24. )
  25. def scatter_matrix(
  26. frame: DataFrame,
  27. alpha: float = 0.5,
  28. figsize: tuple[float, float] | None = None,
  29. ax=None,
  30. grid: bool = False,
  31. diagonal: str = "hist",
  32. marker: str = ".",
  33. density_kwds=None,
  34. hist_kwds=None,
  35. range_padding: float = 0.05,
  36. **kwds,
  37. ):
  38. df = frame._get_numeric_data()
  39. n = df.columns.size
  40. naxes = n * n
  41. fig, axes = create_subplots(naxes=naxes, figsize=figsize, ax=ax, squeeze=False)
  42. # no gaps between subplots
  43. maybe_adjust_figure(fig, wspace=0, hspace=0)
  44. mask = notna(df)
  45. marker = _get_marker_compat(marker)
  46. hist_kwds = hist_kwds or {}
  47. density_kwds = density_kwds or {}
  48. # GH 14855
  49. kwds.setdefault("edgecolors", "none")
  50. boundaries_list = []
  51. for a in df.columns:
  52. values = df[a].values[mask[a].values]
  53. rmin_, rmax_ = np.min(values), np.max(values)
  54. rdelta_ext = (rmax_ - rmin_) * range_padding / 2
  55. boundaries_list.append((rmin_ - rdelta_ext, rmax_ + rdelta_ext))
  56. for i, a in enumerate(df.columns):
  57. for j, b in enumerate(df.columns):
  58. ax = axes[i, j]
  59. if i == j:
  60. values = df[a].values[mask[a].values]
  61. # Deal with the diagonal by drawing a histogram there.
  62. if diagonal == "hist":
  63. ax.hist(values, **hist_kwds)
  64. elif diagonal in ("kde", "density"):
  65. from scipy.stats import gaussian_kde
  66. y = values
  67. gkde = gaussian_kde(y)
  68. ind = np.linspace(y.min(), y.max(), 1000)
  69. ax.plot(ind, gkde.evaluate(ind), **density_kwds)
  70. ax.set_xlim(boundaries_list[i])
  71. else:
  72. common = (mask[a] & mask[b]).values
  73. ax.scatter(
  74. df[b][common], df[a][common], marker=marker, alpha=alpha, **kwds
  75. )
  76. ax.set_xlim(boundaries_list[j])
  77. ax.set_ylim(boundaries_list[i])
  78. ax.set_xlabel(b)
  79. ax.set_ylabel(a)
  80. if j != 0:
  81. ax.yaxis.set_visible(False)
  82. if i != n - 1:
  83. ax.xaxis.set_visible(False)
  84. if len(df.columns) > 1:
  85. lim1 = boundaries_list[0]
  86. locs = axes[0][1].yaxis.get_majorticklocs()
  87. locs = locs[(lim1[0] <= locs) & (locs <= lim1[1])]
  88. adj = (locs - lim1[0]) / (lim1[1] - lim1[0])
  89. lim0 = axes[0][0].get_ylim()
  90. adj = adj * (lim0[1] - lim0[0]) + lim0[0]
  91. axes[0][0].yaxis.set_ticks(adj)
  92. if np.all(locs == locs.astype(int)):
  93. # if all ticks are int
  94. locs = locs.astype(int)
  95. axes[0][0].yaxis.set_ticklabels(locs)
  96. set_ticks_props(axes, xlabelsize=8, xrot=90, ylabelsize=8, yrot=0)
  97. return axes
  98. def _get_marker_compat(marker):
  99. if marker not in mlines.lineMarkers:
  100. return "o"
  101. return marker
  102. def radviz(
  103. frame: DataFrame,
  104. class_column,
  105. ax: Axes | None = None,
  106. color=None,
  107. colormap=None,
  108. **kwds,
  109. ) -> Axes:
  110. import matplotlib.pyplot as plt
  111. def normalize(series):
  112. a = min(series)
  113. b = max(series)
  114. return (series - a) / (b - a)
  115. n = len(frame)
  116. classes = frame[class_column].drop_duplicates()
  117. class_col = frame[class_column]
  118. df = frame.drop(class_column, axis=1).apply(normalize)
  119. if ax is None:
  120. ax = plt.gca()
  121. ax.set_xlim(-1, 1)
  122. ax.set_ylim(-1, 1)
  123. to_plot: dict[Hashable, list[list]] = {}
  124. colors = get_standard_colors(
  125. num_colors=len(classes), colormap=colormap, color_type="random", color=color
  126. )
  127. for kls in classes:
  128. to_plot[kls] = [[], []]
  129. m = len(frame.columns) - 1
  130. s = np.array(
  131. [(np.cos(t), np.sin(t)) for t in [2 * np.pi * (i / m) for i in range(m)]]
  132. )
  133. for i in range(n):
  134. row = df.iloc[i].values
  135. row_ = np.repeat(np.expand_dims(row, axis=1), 2, axis=1)
  136. y = (s * row_).sum(axis=0) / row.sum()
  137. kls = class_col.iat[i]
  138. to_plot[kls][0].append(y[0])
  139. to_plot[kls][1].append(y[1])
  140. for i, kls in enumerate(classes):
  141. ax.scatter(
  142. to_plot[kls][0],
  143. to_plot[kls][1],
  144. color=colors[i],
  145. label=pprint_thing(kls),
  146. **kwds,
  147. )
  148. ax.legend()
  149. ax.add_patch(patches.Circle((0.0, 0.0), radius=1.0, facecolor="none"))
  150. for xy, name in zip(s, df.columns):
  151. ax.add_patch(patches.Circle(xy, radius=0.025, facecolor="gray"))
  152. if xy[0] < 0.0 and xy[1] < 0.0:
  153. ax.text(
  154. xy[0] - 0.025, xy[1] - 0.025, name, ha="right", va="top", size="small"
  155. )
  156. elif xy[0] < 0.0 <= xy[1]:
  157. ax.text(
  158. xy[0] - 0.025,
  159. xy[1] + 0.025,
  160. name,
  161. ha="right",
  162. va="bottom",
  163. size="small",
  164. )
  165. elif xy[1] < 0.0 <= xy[0]:
  166. ax.text(
  167. xy[0] + 0.025, xy[1] - 0.025, name, ha="left", va="top", size="small"
  168. )
  169. elif xy[0] >= 0.0 and xy[1] >= 0.0:
  170. ax.text(
  171. xy[0] + 0.025, xy[1] + 0.025, name, ha="left", va="bottom", size="small"
  172. )
  173. ax.axis("equal")
  174. return ax
  175. def andrews_curves(
  176. frame: DataFrame,
  177. class_column,
  178. ax: Axes | None = None,
  179. samples: int = 200,
  180. color=None,
  181. colormap=None,
  182. **kwds,
  183. ) -> Axes:
  184. import matplotlib.pyplot as plt
  185. def function(amplitudes):
  186. def f(t):
  187. x1 = amplitudes[0]
  188. result = x1 / np.sqrt(2.0)
  189. # Take the rest of the coefficients and resize them
  190. # appropriately. Take a copy of amplitudes as otherwise numpy
  191. # deletes the element from amplitudes itself.
  192. coeffs = np.delete(np.copy(amplitudes), 0)
  193. coeffs = np.resize(coeffs, (int((coeffs.size + 1) / 2), 2))
  194. # Generate the harmonics and arguments for the sin and cos
  195. # functions.
  196. harmonics = np.arange(0, coeffs.shape[0]) + 1
  197. trig_args = np.outer(harmonics, t)
  198. result += np.sum(
  199. coeffs[:, 0, np.newaxis] * np.sin(trig_args)
  200. + coeffs[:, 1, np.newaxis] * np.cos(trig_args),
  201. axis=0,
  202. )
  203. return result
  204. return f
  205. n = len(frame)
  206. class_col = frame[class_column]
  207. classes = frame[class_column].drop_duplicates()
  208. df = frame.drop(class_column, axis=1)
  209. t = np.linspace(-np.pi, np.pi, samples)
  210. used_legends: set[str] = set()
  211. color_values = get_standard_colors(
  212. num_colors=len(classes), colormap=colormap, color_type="random", color=color
  213. )
  214. colors = dict(zip(classes, color_values))
  215. if ax is None:
  216. ax = plt.gca()
  217. ax.set_xlim(-np.pi, np.pi)
  218. for i in range(n):
  219. row = df.iloc[i].values
  220. f = function(row)
  221. y = f(t)
  222. kls = class_col.iat[i]
  223. label = pprint_thing(kls)
  224. if label not in used_legends:
  225. used_legends.add(label)
  226. ax.plot(t, y, color=colors[kls], label=label, **kwds)
  227. else:
  228. ax.plot(t, y, color=colors[kls], **kwds)
  229. ax.legend(loc="upper right")
  230. ax.grid()
  231. return ax
  232. def bootstrap_plot(
  233. series: Series,
  234. fig: Figure | None = None,
  235. size: int = 50,
  236. samples: int = 500,
  237. **kwds,
  238. ) -> Figure:
  239. import matplotlib.pyplot as plt
  240. # TODO: is the failure mentioned below still relevant?
  241. # random.sample(ndarray, int) fails on python 3.3, sigh
  242. data = list(series.values)
  243. samplings = [random.sample(data, size) for _ in range(samples)]
  244. means = np.array([np.mean(sampling) for sampling in samplings])
  245. medians = np.array([np.median(sampling) for sampling in samplings])
  246. midranges = np.array(
  247. [(min(sampling) + max(sampling)) * 0.5 for sampling in samplings]
  248. )
  249. if fig is None:
  250. fig = plt.figure()
  251. x = list(range(samples))
  252. axes = []
  253. ax1 = fig.add_subplot(2, 3, 1)
  254. ax1.set_xlabel("Sample")
  255. axes.append(ax1)
  256. ax1.plot(x, means, **kwds)
  257. ax2 = fig.add_subplot(2, 3, 2)
  258. ax2.set_xlabel("Sample")
  259. axes.append(ax2)
  260. ax2.plot(x, medians, **kwds)
  261. ax3 = fig.add_subplot(2, 3, 3)
  262. ax3.set_xlabel("Sample")
  263. axes.append(ax3)
  264. ax3.plot(x, midranges, **kwds)
  265. ax4 = fig.add_subplot(2, 3, 4)
  266. ax4.set_xlabel("Mean")
  267. axes.append(ax4)
  268. ax4.hist(means, **kwds)
  269. ax5 = fig.add_subplot(2, 3, 5)
  270. ax5.set_xlabel("Median")
  271. axes.append(ax5)
  272. ax5.hist(medians, **kwds)
  273. ax6 = fig.add_subplot(2, 3, 6)
  274. ax6.set_xlabel("Midrange")
  275. axes.append(ax6)
  276. ax6.hist(midranges, **kwds)
  277. for axis in axes:
  278. plt.setp(axis.get_xticklabels(), fontsize=8)
  279. plt.setp(axis.get_yticklabels(), fontsize=8)
  280. if do_adjust_figure(fig):
  281. plt.tight_layout()
  282. return fig
  283. def parallel_coordinates(
  284. frame: DataFrame,
  285. class_column,
  286. cols=None,
  287. ax: Axes | None = None,
  288. color=None,
  289. use_columns: bool = False,
  290. xticks=None,
  291. colormap=None,
  292. axvlines: bool = True,
  293. axvlines_kwds=None,
  294. sort_labels: bool = False,
  295. **kwds,
  296. ) -> Axes:
  297. import matplotlib.pyplot as plt
  298. if axvlines_kwds is None:
  299. axvlines_kwds = {"linewidth": 1, "color": "black"}
  300. n = len(frame)
  301. classes = frame[class_column].drop_duplicates()
  302. class_col = frame[class_column]
  303. if cols is None:
  304. df = frame.drop(class_column, axis=1)
  305. else:
  306. df = frame[cols]
  307. used_legends: set[str] = set()
  308. ncols = len(df.columns)
  309. # determine values to use for xticks
  310. x: list[int] | Index
  311. if use_columns is True:
  312. if not np.all(np.isreal(list(df.columns))):
  313. raise ValueError("Columns must be numeric to be used as xticks")
  314. x = df.columns
  315. elif xticks is not None:
  316. if not np.all(np.isreal(xticks)):
  317. raise ValueError("xticks specified must be numeric")
  318. if len(xticks) != ncols:
  319. raise ValueError("Length of xticks must match number of columns")
  320. x = xticks
  321. else:
  322. x = list(range(ncols))
  323. if ax is None:
  324. ax = plt.gca()
  325. color_values = get_standard_colors(
  326. num_colors=len(classes), colormap=colormap, color_type="random", color=color
  327. )
  328. if sort_labels:
  329. classes = sorted(classes)
  330. color_values = sorted(color_values)
  331. colors = dict(zip(classes, color_values))
  332. for i in range(n):
  333. y = df.iloc[i].values
  334. kls = class_col.iat[i]
  335. label = pprint_thing(kls)
  336. if label not in used_legends:
  337. used_legends.add(label)
  338. ax.plot(x, y, color=colors[kls], label=label, **kwds)
  339. else:
  340. ax.plot(x, y, color=colors[kls], **kwds)
  341. if axvlines:
  342. for i in x:
  343. ax.axvline(i, **axvlines_kwds)
  344. ax.set_xticks(x)
  345. ax.set_xticklabels(df.columns)
  346. ax.set_xlim(x[0], x[-1])
  347. ax.legend(loc="upper right")
  348. ax.grid()
  349. return ax
  350. def lag_plot(series: Series, lag: int = 1, ax: Axes | None = None, **kwds) -> Axes:
  351. # workaround because `c='b'` is hardcoded in matplotlib's scatter method
  352. import matplotlib.pyplot as plt
  353. kwds.setdefault("c", plt.rcParams["patch.facecolor"])
  354. data = series.values
  355. y1 = data[:-lag]
  356. y2 = data[lag:]
  357. if ax is None:
  358. ax = plt.gca()
  359. ax.set_xlabel("y(t)")
  360. ax.set_ylabel(f"y(t + {lag})")
  361. ax.scatter(y1, y2, **kwds)
  362. return ax
  363. def autocorrelation_plot(series: Series, ax: Axes | None = None, **kwds) -> Axes:
  364. import matplotlib.pyplot as plt
  365. n = len(series)
  366. data = np.asarray(series)
  367. if ax is None:
  368. ax = plt.gca()
  369. ax.set_xlim(1, n)
  370. ax.set_ylim(-1.0, 1.0)
  371. mean = np.mean(data)
  372. c0 = np.sum((data - mean) ** 2) / n
  373. def r(h):
  374. return ((data[: n - h] - mean) * (data[h:] - mean)).sum() / n / c0
  375. x = np.arange(n) + 1
  376. y = [r(loc) for loc in x]
  377. z95 = 1.959963984540054
  378. z99 = 2.5758293035489004
  379. ax.axhline(y=z99 / np.sqrt(n), linestyle="--", color="grey")
  380. ax.axhline(y=z95 / np.sqrt(n), color="grey")
  381. ax.axhline(y=0.0, color="black")
  382. ax.axhline(y=-z95 / np.sqrt(n), color="grey")
  383. ax.axhline(y=-z99 / np.sqrt(n), linestyle="--", color="grey")
  384. ax.set_xlabel("Lag")
  385. ax.set_ylabel("Autocorrelation")
  386. ax.plot(x, y, **kwds)
  387. if "label" in kwds:
  388. ax.legend()
  389. ax.grid()
  390. return ax
  391. def unpack_single_str_list(keys):
  392. # GH 42795
  393. if isinstance(keys, list) and len(keys) == 1:
  394. keys = keys[0]
  395. return keys