test_misc.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720
  1. """ Test cases for misc plot functions """
  2. import os
  3. import numpy as np
  4. import pytest
  5. import pandas.util._test_decorators as td
  6. from pandas import (
  7. DataFrame,
  8. Index,
  9. Series,
  10. Timestamp,
  11. date_range,
  12. interval_range,
  13. period_range,
  14. plotting,
  15. read_csv,
  16. )
  17. import pandas._testing as tm
  18. from pandas.tests.plotting.common import (
  19. _check_colors,
  20. _check_legend_labels,
  21. _check_plot_works,
  22. _check_text_labels,
  23. _check_ticks_props,
  24. )
  25. mpl = pytest.importorskip("matplotlib")
  26. plt = pytest.importorskip("matplotlib.pyplot")
  27. cm = pytest.importorskip("matplotlib.cm")
  28. @pytest.fixture
  29. def iris(datapath) -> DataFrame:
  30. """
  31. The iris dataset as a DataFrame.
  32. """
  33. return read_csv(datapath("io", "data", "csv", "iris.csv"))
  34. @td.skip_if_installed("matplotlib")
  35. def test_import_error_message():
  36. # GH-19810
  37. df = DataFrame({"A": [1, 2]})
  38. with pytest.raises(ImportError, match="matplotlib is required for plotting"):
  39. df.plot()
  40. def test_get_accessor_args():
  41. func = plotting._core.PlotAccessor._get_call_args
  42. msg = "Called plot accessor for type list, expected Series or DataFrame"
  43. with pytest.raises(TypeError, match=msg):
  44. func(backend_name="", data=[], args=[], kwargs={})
  45. msg = "should not be called with positional arguments"
  46. with pytest.raises(TypeError, match=msg):
  47. func(backend_name="", data=Series(dtype=object), args=["line", None], kwargs={})
  48. x, y, kind, kwargs = func(
  49. backend_name="",
  50. data=DataFrame(),
  51. args=["x"],
  52. kwargs={"y": "y", "kind": "bar", "grid": False},
  53. )
  54. assert x == "x"
  55. assert y == "y"
  56. assert kind == "bar"
  57. assert kwargs == {"grid": False}
  58. x, y, kind, kwargs = func(
  59. backend_name="pandas.plotting._matplotlib",
  60. data=Series(dtype=object),
  61. args=[],
  62. kwargs={},
  63. )
  64. assert x is None
  65. assert y is None
  66. assert kind == "line"
  67. assert len(kwargs) == 24
  68. @pytest.mark.parametrize("kind", plotting.PlotAccessor._all_kinds)
  69. @pytest.mark.parametrize(
  70. "data", [DataFrame(np.arange(15).reshape(5, 3)), Series(range(5))]
  71. )
  72. @pytest.mark.parametrize(
  73. "index",
  74. [
  75. Index(range(5)),
  76. date_range("2020-01-01", periods=5),
  77. period_range("2020-01-01", periods=5),
  78. ],
  79. )
  80. def test_savefig(kind, data, index):
  81. fig, ax = plt.subplots()
  82. data.index = index
  83. kwargs = {}
  84. if kind in ["hexbin", "scatter", "pie"]:
  85. if isinstance(data, Series):
  86. pytest.skip(f"{kind} not supported with Series")
  87. kwargs = {"x": 0, "y": 1}
  88. data.plot(kind=kind, ax=ax, **kwargs)
  89. fig.savefig(os.devnull)
  90. class TestSeriesPlots:
  91. def test_autocorrelation_plot(self):
  92. from pandas.plotting import autocorrelation_plot
  93. ser = Series(
  94. np.arange(10, dtype=np.float64),
  95. index=date_range("2020-01-01", periods=10),
  96. name="ts",
  97. )
  98. # Ensure no UserWarning when making plot
  99. with tm.assert_produces_warning(None):
  100. _check_plot_works(autocorrelation_plot, series=ser)
  101. _check_plot_works(autocorrelation_plot, series=ser.values)
  102. ax = autocorrelation_plot(ser, label="Test")
  103. _check_legend_labels(ax, labels=["Test"])
  104. @pytest.mark.parametrize("kwargs", [{}, {"lag": 5}])
  105. def test_lag_plot(self, kwargs):
  106. from pandas.plotting import lag_plot
  107. ser = Series(
  108. np.arange(10, dtype=np.float64),
  109. index=date_range("2020-01-01", periods=10),
  110. name="ts",
  111. )
  112. _check_plot_works(lag_plot, series=ser, **kwargs)
  113. def test_bootstrap_plot(self):
  114. from pandas.plotting import bootstrap_plot
  115. ser = Series(
  116. np.arange(10, dtype=np.float64),
  117. index=date_range("2020-01-01", periods=10),
  118. name="ts",
  119. )
  120. _check_plot_works(bootstrap_plot, series=ser, size=10)
  121. class TestDataFramePlots:
  122. @pytest.mark.parametrize("pass_axis", [False, True])
  123. def test_scatter_matrix_axis(self, pass_axis):
  124. pytest.importorskip("scipy")
  125. scatter_matrix = plotting.scatter_matrix
  126. ax = None
  127. if pass_axis:
  128. _, ax = mpl.pyplot.subplots(3, 3)
  129. df = DataFrame(np.random.default_rng(2).standard_normal((100, 3)))
  130. # we are plotting multiples on a sub-plot
  131. with tm.assert_produces_warning(UserWarning, check_stacklevel=False):
  132. axes = _check_plot_works(
  133. scatter_matrix,
  134. frame=df,
  135. range_padding=0.1,
  136. ax=ax,
  137. )
  138. axes0_labels = axes[0][0].yaxis.get_majorticklabels()
  139. # GH 5662
  140. expected = ["-2", "0", "2"]
  141. _check_text_labels(axes0_labels, expected)
  142. _check_ticks_props(axes, xlabelsize=8, xrot=90, ylabelsize=8, yrot=0)
  143. @pytest.mark.parametrize("pass_axis", [False, True])
  144. def test_scatter_matrix_axis_smaller(self, pass_axis):
  145. pytest.importorskip("scipy")
  146. scatter_matrix = plotting.scatter_matrix
  147. ax = None
  148. if pass_axis:
  149. _, ax = mpl.pyplot.subplots(3, 3)
  150. df = DataFrame(np.random.default_rng(11).standard_normal((100, 3)))
  151. df[0] = (df[0] - 2) / 3
  152. # we are plotting multiples on a sub-plot
  153. with tm.assert_produces_warning(UserWarning, check_stacklevel=False):
  154. axes = _check_plot_works(
  155. scatter_matrix,
  156. frame=df,
  157. range_padding=0.1,
  158. ax=ax,
  159. )
  160. axes0_labels = axes[0][0].yaxis.get_majorticklabels()
  161. expected = ["-1.0", "-0.5", "0.0"]
  162. _check_text_labels(axes0_labels, expected)
  163. _check_ticks_props(axes, xlabelsize=8, xrot=90, ylabelsize=8, yrot=0)
  164. @pytest.mark.slow
  165. def test_andrews_curves_no_warning(self, iris):
  166. from pandas.plotting import andrews_curves
  167. df = iris
  168. # Ensure no UserWarning when making plot
  169. with tm.assert_produces_warning(None):
  170. _check_plot_works(andrews_curves, frame=df, class_column="Name")
  171. @pytest.mark.slow
  172. @pytest.mark.parametrize(
  173. "linecolors",
  174. [
  175. ("#556270", "#4ECDC4", "#C7F464"),
  176. ["dodgerblue", "aquamarine", "seagreen"],
  177. ],
  178. )
  179. @pytest.mark.parametrize(
  180. "df",
  181. [
  182. "iris",
  183. DataFrame(
  184. {
  185. "A": np.random.default_rng(2).standard_normal(10),
  186. "B": np.random.default_rng(2).standard_normal(10),
  187. "C": np.random.default_rng(2).standard_normal(10),
  188. "Name": ["A"] * 10,
  189. }
  190. ),
  191. ],
  192. )
  193. def test_andrews_curves_linecolors(self, request, df, linecolors):
  194. from pandas.plotting import andrews_curves
  195. if isinstance(df, str):
  196. df = request.getfixturevalue(df)
  197. ax = _check_plot_works(
  198. andrews_curves, frame=df, class_column="Name", color=linecolors
  199. )
  200. _check_colors(
  201. ax.get_lines()[:10], linecolors=linecolors, mapping=df["Name"][:10]
  202. )
  203. @pytest.mark.slow
  204. @pytest.mark.parametrize(
  205. "df",
  206. [
  207. "iris",
  208. DataFrame(
  209. {
  210. "A": np.random.default_rng(2).standard_normal(10),
  211. "B": np.random.default_rng(2).standard_normal(10),
  212. "C": np.random.default_rng(2).standard_normal(10),
  213. "Name": ["A"] * 10,
  214. }
  215. ),
  216. ],
  217. )
  218. def test_andrews_curves_cmap(self, request, df):
  219. from pandas.plotting import andrews_curves
  220. if isinstance(df, str):
  221. df = request.getfixturevalue(df)
  222. cmaps = [cm.jet(n) for n in np.linspace(0, 1, df["Name"].nunique())]
  223. ax = _check_plot_works(
  224. andrews_curves, frame=df, class_column="Name", color=cmaps
  225. )
  226. _check_colors(ax.get_lines()[:10], linecolors=cmaps, mapping=df["Name"][:10])
  227. @pytest.mark.slow
  228. def test_andrews_curves_handle(self):
  229. from pandas.plotting import andrews_curves
  230. colors = ["b", "g", "r"]
  231. df = DataFrame({"A": [1, 2, 3], "B": [1, 2, 3], "C": [1, 2, 3], "Name": colors})
  232. ax = andrews_curves(df, "Name", color=colors)
  233. handles, _ = ax.get_legend_handles_labels()
  234. _check_colors(handles, linecolors=colors)
  235. @pytest.mark.slow
  236. @pytest.mark.parametrize(
  237. "color",
  238. [("#556270", "#4ECDC4", "#C7F464"), ["dodgerblue", "aquamarine", "seagreen"]],
  239. )
  240. def test_parallel_coordinates_colors(self, iris, color):
  241. from pandas.plotting import parallel_coordinates
  242. df = iris
  243. ax = _check_plot_works(
  244. parallel_coordinates, frame=df, class_column="Name", color=color
  245. )
  246. _check_colors(ax.get_lines()[:10], linecolors=color, mapping=df["Name"][:10])
  247. @pytest.mark.slow
  248. def test_parallel_coordinates_cmap(self, iris):
  249. from matplotlib import cm
  250. from pandas.plotting import parallel_coordinates
  251. df = iris
  252. ax = _check_plot_works(
  253. parallel_coordinates, frame=df, class_column="Name", colormap=cm.jet
  254. )
  255. cmaps = [cm.jet(n) for n in np.linspace(0, 1, df["Name"].nunique())]
  256. _check_colors(ax.get_lines()[:10], linecolors=cmaps, mapping=df["Name"][:10])
  257. @pytest.mark.slow
  258. def test_parallel_coordinates_line_diff(self, iris):
  259. from pandas.plotting import parallel_coordinates
  260. df = iris
  261. ax = _check_plot_works(parallel_coordinates, frame=df, class_column="Name")
  262. nlines = len(ax.get_lines())
  263. nxticks = len(ax.xaxis.get_ticklabels())
  264. ax = _check_plot_works(
  265. parallel_coordinates, frame=df, class_column="Name", axvlines=False
  266. )
  267. assert len(ax.get_lines()) == (nlines - nxticks)
  268. @pytest.mark.slow
  269. def test_parallel_coordinates_handles(self, iris):
  270. from pandas.plotting import parallel_coordinates
  271. df = iris
  272. colors = ["b", "g", "r"]
  273. df = DataFrame({"A": [1, 2, 3], "B": [1, 2, 3], "C": [1, 2, 3], "Name": colors})
  274. ax = parallel_coordinates(df, "Name", color=colors)
  275. handles, _ = ax.get_legend_handles_labels()
  276. _check_colors(handles, linecolors=colors)
  277. # not sure if this is indicative of a problem
  278. @pytest.mark.filterwarnings("ignore:Attempting to set:UserWarning")
  279. def test_parallel_coordinates_with_sorted_labels(self):
  280. """For #15908"""
  281. from pandas.plotting import parallel_coordinates
  282. df = DataFrame(
  283. {
  284. "feat": list(range(30)),
  285. "class": [2 for _ in range(10)]
  286. + [3 for _ in range(10)]
  287. + [1 for _ in range(10)],
  288. }
  289. )
  290. ax = parallel_coordinates(df, "class", sort_labels=True)
  291. polylines, labels = ax.get_legend_handles_labels()
  292. color_label_tuples = zip(
  293. [polyline.get_color() for polyline in polylines], labels
  294. )
  295. ordered_color_label_tuples = sorted(color_label_tuples, key=lambda x: x[1])
  296. prev_next_tupels = zip(
  297. list(ordered_color_label_tuples[0:-1]), list(ordered_color_label_tuples[1:])
  298. )
  299. for prev, nxt in prev_next_tupels:
  300. # labels and colors are ordered strictly increasing
  301. assert prev[1] < nxt[1] and prev[0] < nxt[0]
  302. def test_radviz_no_warning(self, iris):
  303. from pandas.plotting import radviz
  304. df = iris
  305. # Ensure no UserWarning when making plot
  306. with tm.assert_produces_warning(None):
  307. _check_plot_works(radviz, frame=df, class_column="Name")
  308. @pytest.mark.parametrize(
  309. "color",
  310. [("#556270", "#4ECDC4", "#C7F464"), ["dodgerblue", "aquamarine", "seagreen"]],
  311. )
  312. def test_radviz_color(self, iris, color):
  313. from pandas.plotting import radviz
  314. df = iris
  315. ax = _check_plot_works(radviz, frame=df, class_column="Name", color=color)
  316. # skip Circle drawn as ticks
  317. patches = [p for p in ax.patches[:20] if p.get_label() != ""]
  318. _check_colors(patches[:10], facecolors=color, mapping=df["Name"][:10])
  319. def test_radviz_color_cmap(self, iris):
  320. from matplotlib import cm
  321. from pandas.plotting import radviz
  322. df = iris
  323. ax = _check_plot_works(radviz, frame=df, class_column="Name", colormap=cm.jet)
  324. cmaps = [cm.jet(n) for n in np.linspace(0, 1, df["Name"].nunique())]
  325. patches = [p for p in ax.patches[:20] if p.get_label() != ""]
  326. _check_colors(patches, facecolors=cmaps, mapping=df["Name"][:10])
  327. def test_radviz_colors_handles(self):
  328. from pandas.plotting import radviz
  329. colors = [[0.0, 0.0, 1.0, 1.0], [0.0, 0.5, 1.0, 1.0], [1.0, 0.0, 0.0, 1.0]]
  330. df = DataFrame(
  331. {"A": [1, 2, 3], "B": [2, 1, 3], "C": [3, 2, 1], "Name": ["b", "g", "r"]}
  332. )
  333. ax = radviz(df, "Name", color=colors)
  334. handles, _ = ax.get_legend_handles_labels()
  335. _check_colors(handles, facecolors=colors)
  336. def test_subplot_titles(self, iris):
  337. df = iris.drop("Name", axis=1).head()
  338. # Use the column names as the subplot titles
  339. title = list(df.columns)
  340. # Case len(title) == len(df)
  341. plot = df.plot(subplots=True, title=title)
  342. assert [p.get_title() for p in plot] == title
  343. def test_subplot_titles_too_much(self, iris):
  344. df = iris.drop("Name", axis=1).head()
  345. # Use the column names as the subplot titles
  346. title = list(df.columns)
  347. # Case len(title) > len(df)
  348. msg = (
  349. "The length of `title` must equal the number of columns if "
  350. "using `title` of type `list` and `subplots=True`"
  351. )
  352. with pytest.raises(ValueError, match=msg):
  353. df.plot(subplots=True, title=title + ["kittens > puppies"])
  354. def test_subplot_titles_too_little(self, iris):
  355. df = iris.drop("Name", axis=1).head()
  356. # Use the column names as the subplot titles
  357. title = list(df.columns)
  358. msg = (
  359. "The length of `title` must equal the number of columns if "
  360. "using `title` of type `list` and `subplots=True`"
  361. )
  362. # Case len(title) < len(df)
  363. with pytest.raises(ValueError, match=msg):
  364. df.plot(subplots=True, title=title[:2])
  365. def test_subplot_titles_subplots_false(self, iris):
  366. df = iris.drop("Name", axis=1).head()
  367. # Use the column names as the subplot titles
  368. title = list(df.columns)
  369. # Case subplots=False and title is of type list
  370. msg = (
  371. "Using `title` of type `list` is not supported unless "
  372. "`subplots=True` is passed"
  373. )
  374. with pytest.raises(ValueError, match=msg):
  375. df.plot(subplots=False, title=title)
  376. def test_subplot_titles_numeric_square_layout(self, iris):
  377. df = iris.drop("Name", axis=1).head()
  378. # Use the column names as the subplot titles
  379. title = list(df.columns)
  380. # Case df with 3 numeric columns but layout of (2,2)
  381. plot = df.drop("SepalWidth", axis=1).plot(
  382. subplots=True, layout=(2, 2), title=title[:-1]
  383. )
  384. title_list = [ax.get_title() for sublist in plot for ax in sublist]
  385. assert title_list == title[:3] + [""]
  386. def test_get_standard_colors_random_seed(self):
  387. # GH17525
  388. df = DataFrame(np.zeros((10, 10)))
  389. # Make sure that the random seed isn't reset by get_standard_colors
  390. plotting.parallel_coordinates(df, 0)
  391. rand1 = np.random.default_rng(None).random()
  392. plotting.parallel_coordinates(df, 0)
  393. rand2 = np.random.default_rng(None).random()
  394. assert rand1 != rand2
  395. def test_get_standard_colors_consistency(self):
  396. # GH17525
  397. # Make sure it produces the same colors every time it's called
  398. from pandas.plotting._matplotlib.style import get_standard_colors
  399. color1 = get_standard_colors(1, color_type="random")
  400. color2 = get_standard_colors(1, color_type="random")
  401. assert color1 == color2
  402. def test_get_standard_colors_default_num_colors(self):
  403. from pandas.plotting._matplotlib.style import get_standard_colors
  404. # Make sure the default color_types returns the specified amount
  405. color1 = get_standard_colors(1, color_type="default")
  406. color2 = get_standard_colors(9, color_type="default")
  407. color3 = get_standard_colors(20, color_type="default")
  408. assert len(color1) == 1
  409. assert len(color2) == 9
  410. assert len(color3) == 20
  411. def test_plot_single_color(self):
  412. # Example from #20585. All 3 bars should have the same color
  413. df = DataFrame(
  414. {
  415. "account-start": ["2017-02-03", "2017-03-03", "2017-01-01"],
  416. "client": ["Alice Anders", "Bob Baker", "Charlie Chaplin"],
  417. "balance": [-1432.32, 10.43, 30000.00],
  418. "db-id": [1234, 2424, 251],
  419. "proxy-id": [525, 1525, 2542],
  420. "rank": [52, 525, 32],
  421. }
  422. )
  423. ax = df.client.value_counts().plot.bar()
  424. colors = [rect.get_facecolor() for rect in ax.get_children()[0:3]]
  425. assert all(color == colors[0] for color in colors)
  426. def test_get_standard_colors_no_appending(self):
  427. # GH20726
  428. # Make sure not to add more colors so that matplotlib can cycle
  429. # correctly.
  430. from matplotlib import cm
  431. from pandas.plotting._matplotlib.style import get_standard_colors
  432. color_before = cm.gnuplot(range(5))
  433. color_after = get_standard_colors(1, color=color_before)
  434. assert len(color_after) == len(color_before)
  435. df = DataFrame(
  436. np.random.default_rng(2).standard_normal((48, 4)), columns=list("ABCD")
  437. )
  438. color_list = cm.gnuplot(np.linspace(0, 1, 16))
  439. p = df.A.plot.bar(figsize=(16, 7), color=color_list)
  440. assert p.patches[1].get_facecolor() == p.patches[17].get_facecolor()
  441. @pytest.mark.parametrize("kind", ["bar", "line"])
  442. def test_dictionary_color(self, kind):
  443. # issue-8193
  444. # Test plot color dictionary format
  445. data_files = ["a", "b"]
  446. expected = [(0.5, 0.24, 0.6), (0.3, 0.7, 0.7)]
  447. df1 = DataFrame(np.random.default_rng(2).random((2, 2)), columns=data_files)
  448. dic_color = {"b": (0.3, 0.7, 0.7), "a": (0.5, 0.24, 0.6)}
  449. ax = df1.plot(kind=kind, color=dic_color)
  450. if kind == "bar":
  451. colors = [rect.get_facecolor()[0:-1] for rect in ax.get_children()[0:3:2]]
  452. else:
  453. colors = [rect.get_color() for rect in ax.get_lines()[0:2]]
  454. assert all(color == expected[index] for index, color in enumerate(colors))
  455. def test_bar_plot(self):
  456. # GH38947
  457. # Test bar plot with string and int index
  458. from matplotlib.text import Text
  459. expected = [Text(0, 0, "0"), Text(1, 0, "Total")]
  460. df = DataFrame(
  461. {
  462. "a": [1, 2],
  463. },
  464. index=Index([0, "Total"]),
  465. )
  466. plot_bar = df.plot.bar()
  467. assert all(
  468. (a.get_text() == b.get_text())
  469. for a, b in zip(plot_bar.get_xticklabels(), expected)
  470. )
  471. def test_barh_plot_labels_mixed_integer_string(self):
  472. # GH39126
  473. # Test barh plot with string and integer at the same column
  474. from matplotlib.text import Text
  475. df = DataFrame([{"word": 1, "value": 0}, {"word": "knowledge", "value": 2}])
  476. plot_barh = df.plot.barh(x="word", legend=None)
  477. expected_yticklabels = [Text(0, 0, "1"), Text(0, 1, "knowledge")]
  478. assert all(
  479. actual.get_text() == expected.get_text()
  480. for actual, expected in zip(
  481. plot_barh.get_yticklabels(), expected_yticklabels
  482. )
  483. )
  484. def test_has_externally_shared_axis_x_axis(self):
  485. # GH33819
  486. # Test _has_externally_shared_axis() works for x-axis
  487. func = plotting._matplotlib.tools._has_externally_shared_axis
  488. fig = mpl.pyplot.figure()
  489. plots = fig.subplots(2, 4)
  490. # Create *externally* shared axes for first and third columns
  491. plots[0][0] = fig.add_subplot(231, sharex=plots[1][0])
  492. plots[0][2] = fig.add_subplot(233, sharex=plots[1][2])
  493. # Create *internally* shared axes for second and third columns
  494. plots[0][1].twinx()
  495. plots[0][2].twinx()
  496. # First column is only externally shared
  497. # Second column is only internally shared
  498. # Third column is both
  499. # Fourth column is neither
  500. assert func(plots[0][0], "x")
  501. assert not func(plots[0][1], "x")
  502. assert func(plots[0][2], "x")
  503. assert not func(plots[0][3], "x")
  504. def test_has_externally_shared_axis_y_axis(self):
  505. # GH33819
  506. # Test _has_externally_shared_axis() works for y-axis
  507. func = plotting._matplotlib.tools._has_externally_shared_axis
  508. fig = mpl.pyplot.figure()
  509. plots = fig.subplots(4, 2)
  510. # Create *externally* shared axes for first and third rows
  511. plots[0][0] = fig.add_subplot(321, sharey=plots[0][1])
  512. plots[2][0] = fig.add_subplot(325, sharey=plots[2][1])
  513. # Create *internally* shared axes for second and third rows
  514. plots[1][0].twiny()
  515. plots[2][0].twiny()
  516. # First row is only externally shared
  517. # Second row is only internally shared
  518. # Third row is both
  519. # Fourth row is neither
  520. assert func(plots[0][0], "y")
  521. assert not func(plots[1][0], "y")
  522. assert func(plots[2][0], "y")
  523. assert not func(plots[3][0], "y")
  524. def test_has_externally_shared_axis_invalid_compare_axis(self):
  525. # GH33819
  526. # Test _has_externally_shared_axis() raises an exception when
  527. # passed an invalid value as compare_axis parameter
  528. func = plotting._matplotlib.tools._has_externally_shared_axis
  529. fig = mpl.pyplot.figure()
  530. plots = fig.subplots(4, 2)
  531. # Create arbitrary axes
  532. plots[0][0] = fig.add_subplot(321, sharey=plots[0][1])
  533. # Check that an invalid compare_axis value triggers the expected exception
  534. msg = "needs 'x' or 'y' as a second parameter"
  535. with pytest.raises(ValueError, match=msg):
  536. func(plots[0][0], "z")
  537. def test_externally_shared_axes(self):
  538. # Example from GH33819
  539. # Create data
  540. df = DataFrame(
  541. {
  542. "a": np.random.default_rng(2).standard_normal(1000),
  543. "b": np.random.default_rng(2).standard_normal(1000),
  544. }
  545. )
  546. # Create figure
  547. fig = mpl.pyplot.figure()
  548. plots = fig.subplots(2, 3)
  549. # Create *externally* shared axes
  550. plots[0][0] = fig.add_subplot(231, sharex=plots[1][0])
  551. # note: no plots[0][1] that's the twin only case
  552. plots[0][2] = fig.add_subplot(233, sharex=plots[1][2])
  553. # Create *internally* shared axes
  554. # note: no plots[0][0] that's the external only case
  555. twin_ax1 = plots[0][1].twinx()
  556. twin_ax2 = plots[0][2].twinx()
  557. # Plot data to primary axes
  558. df["a"].plot(ax=plots[0][0], title="External share only").set_xlabel(
  559. "this label should never be visible"
  560. )
  561. df["a"].plot(ax=plots[1][0])
  562. df["a"].plot(ax=plots[0][1], title="Internal share (twin) only").set_xlabel(
  563. "this label should always be visible"
  564. )
  565. df["a"].plot(ax=plots[1][1])
  566. df["a"].plot(ax=plots[0][2], title="Both").set_xlabel(
  567. "this label should never be visible"
  568. )
  569. df["a"].plot(ax=plots[1][2])
  570. # Plot data to twinned axes
  571. df["b"].plot(ax=twin_ax1, color="green")
  572. df["b"].plot(ax=twin_ax2, color="yellow")
  573. assert not plots[0][0].xaxis.get_label().get_visible()
  574. assert plots[0][1].xaxis.get_label().get_visible()
  575. assert not plots[0][2].xaxis.get_label().get_visible()
  576. def test_plot_bar_axis_units_timestamp_conversion(self):
  577. # GH 38736
  578. # Ensure string x-axis from the second plot will not be converted to datetime
  579. # due to axis data from first plot
  580. df = DataFrame(
  581. [1.0],
  582. index=[Timestamp("2022-02-22 22:22:22")],
  583. )
  584. _check_plot_works(df.plot)
  585. s = Series({"A": 1.0})
  586. _check_plot_works(s.plot.bar)
  587. def test_bar_plt_xaxis_intervalrange(self):
  588. # GH 38969
  589. # Ensure IntervalIndex x-axis produces a bar plot as expected
  590. from matplotlib.text import Text
  591. expected = [Text(0, 0, "([0, 1],)"), Text(1, 0, "([1, 2],)")]
  592. s = Series(
  593. [1, 2],
  594. index=[interval_range(0, 2, closed="both")],
  595. )
  596. _check_plot_works(s.plot.bar)
  597. assert all(
  598. (a.get_text() == b.get_text())
  599. for a, b in zip(s.plot.bar().get_xticklabels(), expected)
  600. )