common.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579
  1. """
  2. Module consolidating common testing functions for checking plotting.
  3. """
  4. from __future__ import annotations
  5. from typing import TYPE_CHECKING
  6. import numpy as np
  7. from pandas.core.dtypes.api import is_list_like
  8. import pandas as pd
  9. from pandas import Series
  10. import pandas._testing as tm
  11. if TYPE_CHECKING:
  12. from collections.abc import Sequence
  13. from matplotlib.axes import Axes
  14. def _check_legend_labels(axes, labels=None, visible=True):
  15. """
  16. Check each axes has expected legend labels
  17. Parameters
  18. ----------
  19. axes : matplotlib Axes object, or its list-like
  20. labels : list-like
  21. expected legend labels
  22. visible : bool
  23. expected legend visibility. labels are checked only when visible is
  24. True
  25. """
  26. if visible and (labels is None):
  27. raise ValueError("labels must be specified when visible is True")
  28. axes = _flatten_visible(axes)
  29. for ax in axes:
  30. if visible:
  31. assert ax.get_legend() is not None
  32. _check_text_labels(ax.get_legend().get_texts(), labels)
  33. else:
  34. assert ax.get_legend() is None
  35. def _check_legend_marker(ax, expected_markers=None, visible=True):
  36. """
  37. Check ax has expected legend markers
  38. Parameters
  39. ----------
  40. ax : matplotlib Axes object
  41. expected_markers : list-like
  42. expected legend markers
  43. visible : bool
  44. expected legend visibility. labels are checked only when visible is
  45. True
  46. """
  47. if visible and (expected_markers is None):
  48. raise ValueError("Markers must be specified when visible is True")
  49. if visible:
  50. handles, _ = ax.get_legend_handles_labels()
  51. markers = [handle.get_marker() for handle in handles]
  52. assert markers == expected_markers
  53. else:
  54. assert ax.get_legend() is None
  55. def _check_data(xp, rs):
  56. """
  57. Check each axes has identical lines
  58. Parameters
  59. ----------
  60. xp : matplotlib Axes object
  61. rs : matplotlib Axes object
  62. """
  63. xp_lines = xp.get_lines()
  64. rs_lines = rs.get_lines()
  65. assert len(xp_lines) == len(rs_lines)
  66. for xpl, rsl in zip(xp_lines, rs_lines, strict=True):
  67. xpdata = xpl.get_xydata()
  68. rsdata = rsl.get_xydata()
  69. tm.assert_almost_equal(xpdata, rsdata)
  70. def _check_visible(collections, visible=True):
  71. """
  72. Check each artist is visible or not
  73. Parameters
  74. ----------
  75. collections : matplotlib Artist or its list-like
  76. target Artist or its list or collection
  77. visible : bool
  78. expected visibility
  79. """
  80. from matplotlib.collections import Collection
  81. if not isinstance(collections, Collection) and not is_list_like(collections):
  82. collections = [collections]
  83. for patch in collections:
  84. assert patch.get_visible() == visible
  85. def _check_patches_all_filled(axes: Axes | Sequence[Axes], filled: bool = True) -> None:
  86. """
  87. Check for each artist whether it is filled or not
  88. Parameters
  89. ----------
  90. axes : matplotlib Axes object, or its list-like
  91. filled : bool
  92. expected filling
  93. """
  94. axes = _flatten_visible(axes)
  95. for ax in axes:
  96. for patch in ax.patches:
  97. assert patch.fill == filled
  98. def _get_colors_mapped(series, colors):
  99. unique = series.unique()
  100. # unique and colors length can be differed
  101. # depending on slice value
  102. mapped = dict(zip(unique, colors))
  103. return [mapped[v] for v in series.values]
  104. def _check_colors(collections, linecolors=None, facecolors=None, mapping=None):
  105. """
  106. Check each artist has expected line colors and face colors
  107. Parameters
  108. ----------
  109. collections : list-like
  110. list or collection of target artist
  111. linecolors : list-like which has the same length as collections
  112. list of expected line colors
  113. facecolors : list-like which has the same length as collections
  114. list of expected face colors
  115. mapping : Series
  116. Series used for color grouping key
  117. used for andrew_curves, parallel_coordinates, radviz test
  118. """
  119. from matplotlib import colors
  120. from matplotlib.collections import (
  121. Collection,
  122. LineCollection,
  123. PolyCollection,
  124. )
  125. from matplotlib.lines import Line2D
  126. conv = colors.ColorConverter
  127. if linecolors is not None:
  128. if mapping is not None:
  129. linecolors = _get_colors_mapped(mapping, linecolors)
  130. linecolors = linecolors[: len(collections)]
  131. assert len(collections) == len(linecolors)
  132. for patch, color in zip(collections, linecolors, strict=True):
  133. if isinstance(patch, Line2D):
  134. result = patch.get_color()
  135. # Line2D may contains string color expression
  136. result = conv.to_rgba(result)
  137. elif isinstance(patch, (PolyCollection, LineCollection)):
  138. result = tuple(patch.get_edgecolor()[0])
  139. else:
  140. result = patch.get_edgecolor()
  141. expected = conv.to_rgba(color)
  142. assert result == expected
  143. if facecolors is not None:
  144. if mapping is not None:
  145. facecolors = _get_colors_mapped(mapping, facecolors)
  146. facecolors = facecolors[: len(collections)]
  147. assert len(collections) == len(facecolors)
  148. for patch, color in zip(collections, facecolors, strict=True):
  149. if isinstance(patch, Collection):
  150. # returned as list of np.array
  151. result = patch.get_facecolor()[0]
  152. else:
  153. result = patch.get_facecolor()
  154. if isinstance(result, np.ndarray):
  155. result = tuple(result)
  156. expected = conv.to_rgba(color)
  157. assert result == expected
  158. def _check_text_labels(texts, expected):
  159. """
  160. Check each text has expected labels
  161. Parameters
  162. ----------
  163. texts : matplotlib Text object, or its list-like
  164. target text, or its list
  165. expected : str or list-like which has the same length as texts
  166. expected text label, or its list
  167. """
  168. if not is_list_like(texts):
  169. assert texts.get_text() == expected
  170. else:
  171. labels = [t.get_text() for t in texts]
  172. assert len(labels) == len(expected)
  173. for label, e in zip(labels, expected, strict=True):
  174. assert label == e
  175. def _check_ticks_props(axes, xlabelsize=None, xrot=None, ylabelsize=None, yrot=None):
  176. """
  177. Check each axes has expected tick properties
  178. Parameters
  179. ----------
  180. axes : matplotlib Axes object, or its list-like
  181. xlabelsize : number
  182. expected xticks font size
  183. xrot : number
  184. expected xticks rotation
  185. ylabelsize : number
  186. expected yticks font size
  187. yrot : number
  188. expected yticks rotation
  189. """
  190. from matplotlib.ticker import NullFormatter
  191. axes = _flatten_visible(axes)
  192. for ax in axes:
  193. if xlabelsize is not None or xrot is not None:
  194. if isinstance(ax.xaxis.get_minor_formatter(), NullFormatter):
  195. # If minor ticks has NullFormatter, rot / fontsize are not
  196. # retained
  197. labels = ax.get_xticklabels()
  198. else:
  199. labels = ax.get_xticklabels() + ax.get_xticklabels(minor=True)
  200. for label in labels:
  201. if xlabelsize is not None:
  202. tm.assert_almost_equal(label.get_fontsize(), xlabelsize)
  203. if xrot is not None:
  204. tm.assert_almost_equal(label.get_rotation(), xrot)
  205. if ylabelsize is not None or yrot is not None:
  206. if isinstance(ax.yaxis.get_minor_formatter(), NullFormatter):
  207. labels = ax.get_yticklabels()
  208. else:
  209. labels = ax.get_yticklabels() + ax.get_yticklabels(minor=True)
  210. for label in labels:
  211. if ylabelsize is not None:
  212. tm.assert_almost_equal(label.get_fontsize(), ylabelsize)
  213. if yrot is not None:
  214. tm.assert_almost_equal(label.get_rotation(), yrot)
  215. def _check_ax_scales(axes, xaxis="linear", yaxis="linear"):
  216. """
  217. Check each axes has expected scales
  218. Parameters
  219. ----------
  220. axes : matplotlib Axes object, or its list-like
  221. xaxis : {'linear', 'log'}
  222. expected xaxis scale
  223. yaxis : {'linear', 'log'}
  224. expected yaxis scale
  225. """
  226. axes = _flatten_visible(axes)
  227. for ax in axes:
  228. assert ax.xaxis.get_scale() == xaxis
  229. assert ax.yaxis.get_scale() == yaxis
  230. def _check_axes_shape(axes, axes_num=None, layout=None, figsize=None):
  231. """
  232. Check expected number of axes is drawn in expected layout
  233. Parameters
  234. ----------
  235. axes : matplotlib Axes object, or its list-like
  236. axes_num : number
  237. expected number of axes. Unnecessary axes should be set to
  238. invisible.
  239. layout : tuple
  240. expected layout, (expected number of rows , columns)
  241. figsize : tuple
  242. expected figsize. default is matplotlib default
  243. """
  244. from pandas.plotting._matplotlib.tools import flatten_axes
  245. if figsize is None:
  246. figsize = (6.4, 4.8)
  247. visible_axes = _flatten_visible(axes)
  248. if axes_num is not None:
  249. assert len(visible_axes) == axes_num
  250. for ax in visible_axes:
  251. # check something drawn on visible axes
  252. assert len(ax.get_children()) > 0
  253. if layout is not None:
  254. x_set = set()
  255. y_set = set()
  256. for ax in flatten_axes(axes):
  257. # check axes coordinates to estimate layout
  258. points = ax.get_position().get_points()
  259. x_set.add(points[0][0])
  260. y_set.add(points[0][1])
  261. result = (len(y_set), len(x_set))
  262. assert result == layout
  263. tm.assert_numpy_array_equal(
  264. visible_axes[0].figure.get_size_inches(),
  265. np.array(figsize, dtype=np.float64),
  266. )
  267. def _flatten_visible(axes: Axes | Sequence[Axes]) -> Sequence[Axes]:
  268. """
  269. Flatten axes, and filter only visible
  270. Parameters
  271. ----------
  272. axes : matplotlib Axes object, or its list-like
  273. """
  274. from pandas.plotting._matplotlib.tools import flatten_axes
  275. axes_ndarray = flatten_axes(axes)
  276. axes = [ax for ax in axes_ndarray if ax.get_visible()]
  277. return axes
  278. def _check_has_errorbars(axes, xerr=0, yerr=0):
  279. """
  280. Check axes has expected number of errorbars
  281. Parameters
  282. ----------
  283. axes : matplotlib Axes object, or its list-like
  284. xerr : number
  285. expected number of x errorbar
  286. yerr : number
  287. expected number of y errorbar
  288. """
  289. axes = _flatten_visible(axes)
  290. for ax in axes:
  291. containers = ax.containers
  292. xerr_count = 0
  293. yerr_count = 0
  294. for c in containers:
  295. has_xerr = getattr(c, "has_xerr", False)
  296. has_yerr = getattr(c, "has_yerr", False)
  297. if has_xerr:
  298. xerr_count += 1
  299. if has_yerr:
  300. yerr_count += 1
  301. assert xerr == xerr_count
  302. assert yerr == yerr_count
  303. def _check_box_return_type(
  304. returned, return_type, expected_keys=None, check_ax_title=True
  305. ):
  306. """
  307. Check box returned type is correct
  308. Parameters
  309. ----------
  310. returned : object to be tested, returned from boxplot
  311. return_type : str
  312. return_type passed to boxplot
  313. expected_keys : list-like, optional
  314. group labels in subplot case. If not passed,
  315. the function checks assuming boxplot uses single ax
  316. check_ax_title : bool
  317. Whether to check the ax.title is the same as expected_key
  318. Intended to be checked by calling from ``boxplot``.
  319. Normal ``plot`` doesn't attach ``ax.title``, it must be disabled.
  320. """
  321. from matplotlib.axes import Axes
  322. types = {"dict": dict, "axes": Axes, "both": tuple}
  323. if expected_keys is None:
  324. # should be fixed when the returning default is changed
  325. if return_type is None:
  326. return_type = "dict"
  327. assert isinstance(returned, types[return_type])
  328. if return_type == "both":
  329. assert isinstance(returned.ax, Axes)
  330. assert isinstance(returned.lines, dict)
  331. else:
  332. # should be fixed when the returning default is changed
  333. if return_type is None:
  334. for r in _flatten_visible(returned):
  335. assert isinstance(r, Axes)
  336. return
  337. assert isinstance(returned, Series)
  338. assert sorted(returned.keys()) == sorted(expected_keys)
  339. for key, value in returned.items():
  340. assert isinstance(value, types[return_type])
  341. # check returned dict has correct mapping
  342. if return_type == "axes":
  343. if check_ax_title:
  344. assert value.get_title() == key
  345. elif return_type == "both":
  346. if check_ax_title:
  347. assert value.ax.get_title() == key
  348. assert isinstance(value.ax, Axes)
  349. assert isinstance(value.lines, dict)
  350. elif return_type == "dict":
  351. line = value["medians"][0]
  352. axes = line.axes
  353. if check_ax_title:
  354. assert axes.get_title() == key
  355. else:
  356. raise AssertionError
  357. def _check_grid_settings(obj, kinds, kws=None):
  358. # Make sure plot defaults to rcParams['axes.grid'] setting, GH 9792
  359. import matplotlib as mpl
  360. def is_grid_on():
  361. xticks = mpl.pyplot.gca().xaxis.get_major_ticks()
  362. yticks = mpl.pyplot.gca().yaxis.get_major_ticks()
  363. xoff = all(not g.gridline.get_visible() for g in xticks)
  364. yoff = all(not g.gridline.get_visible() for g in yticks)
  365. return not (xoff and yoff)
  366. if kws is None:
  367. kws = {}
  368. spndx = 1
  369. for kind in kinds:
  370. mpl.pyplot.subplot(1, 4 * len(kinds), spndx)
  371. spndx += 1
  372. mpl.rc("axes", grid=False)
  373. obj.plot(kind=kind, **kws)
  374. assert not is_grid_on()
  375. mpl.pyplot.clf()
  376. mpl.pyplot.subplot(1, 4 * len(kinds), spndx)
  377. spndx += 1
  378. mpl.rc("axes", grid=True)
  379. obj.plot(kind=kind, grid=False, **kws)
  380. assert not is_grid_on()
  381. mpl.pyplot.clf()
  382. if kind not in ["pie", "hexbin", "scatter"]:
  383. mpl.pyplot.subplot(1, 4 * len(kinds), spndx)
  384. spndx += 1
  385. mpl.rc("axes", grid=True)
  386. obj.plot(kind=kind, **kws)
  387. assert is_grid_on()
  388. mpl.pyplot.clf()
  389. mpl.pyplot.subplot(1, 4 * len(kinds), spndx)
  390. spndx += 1
  391. mpl.rc("axes", grid=False)
  392. obj.plot(kind=kind, grid=True, **kws)
  393. assert is_grid_on()
  394. mpl.pyplot.clf()
  395. def _unpack_cycler(rcParams, field="color"):
  396. """
  397. Auxiliary function for correctly unpacking cycler after MPL >= 1.5
  398. """
  399. return [v[field] for v in rcParams["axes.prop_cycle"]]
  400. def get_x_axis(ax):
  401. return ax._shared_axes["x"]
  402. def get_y_axis(ax):
  403. return ax._shared_axes["y"]
  404. def assert_is_valid_plot_return_object(objs) -> None:
  405. from matplotlib.artist import Artist
  406. from matplotlib.axes import Axes
  407. if isinstance(objs, (Series, np.ndarray)):
  408. if isinstance(objs, Series):
  409. objs = objs._values
  410. for el in objs.reshape(-1):
  411. msg = (
  412. "one of 'objs' is not a matplotlib Axes instance, "
  413. f"type encountered {type(el).__name__!r}"
  414. )
  415. assert isinstance(el, (Axes, dict)), msg
  416. else:
  417. msg = (
  418. "objs is neither an ndarray of Artist instances nor a single "
  419. "ArtistArtist instance, tuple, or dict, 'objs' is a "
  420. f"{type(objs).__name__!r}"
  421. )
  422. assert isinstance(objs, (Artist, tuple, dict)), msg
  423. def _check_plot_works(f, default_axes=False, **kwargs):
  424. """
  425. Create plot and ensure that plot return object is valid.
  426. Parameters
  427. ----------
  428. f : func
  429. Plotting function.
  430. default_axes : bool, optional
  431. If False (default):
  432. - If `ax` not in `kwargs`, then create subplot(211) and plot there
  433. - Create new subplot(212) and plot there as well
  434. - Mind special corner case for bootstrap_plot (see `_gen_two_subplots`)
  435. If True:
  436. - Simply run plotting function with kwargs provided
  437. - All required axes instances will be created automatically
  438. - It is recommended to use it when the plotting function
  439. creates multiple axes itself. It helps avoid warnings like
  440. 'UserWarning: To output multiple subplots,
  441. the figure containing the passed axes is being cleared'
  442. **kwargs
  443. Keyword arguments passed to the plotting function.
  444. Returns
  445. -------
  446. Plot object returned by the last plotting.
  447. """
  448. import matplotlib.pyplot as plt
  449. if default_axes:
  450. gen_plots = _gen_default_plot
  451. else:
  452. gen_plots = _gen_two_subplots
  453. ret = None
  454. fig = kwargs.get("figure", plt.gcf())
  455. fig.clf()
  456. for ret in gen_plots(f, fig, **kwargs):
  457. assert_is_valid_plot_return_object(ret)
  458. return ret
  459. def _gen_default_plot(f, fig, **kwargs):
  460. """
  461. Create plot in a default way.
  462. """
  463. yield f(**kwargs)
  464. def _gen_two_subplots(f, fig, **kwargs):
  465. """
  466. Create plot on two subplots forcefully created.
  467. """
  468. if "ax" not in kwargs:
  469. fig.add_subplot(211)
  470. yield f(**kwargs)
  471. if f is pd.plotting.bootstrap_plot:
  472. assert "ax" not in kwargs
  473. else:
  474. kwargs["ax"] = fig.add_subplot(212)
  475. yield f(**kwargs)