_doctools.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. from __future__ import annotations
  2. from typing import TYPE_CHECKING
  3. import numpy as np
  4. import pandas as pd
  5. if TYPE_CHECKING:
  6. from collections.abc import Iterable
  7. from matplotlib.figure import Figure
  8. class TablePlotter:
  9. """
  10. Layout some DataFrames in vertical/horizontal layout for explanation.
  11. Used in merging.rst
  12. """
  13. def __init__(
  14. self,
  15. cell_width: float = 0.37,
  16. cell_height: float = 0.25,
  17. font_size: float = 7.5,
  18. ) -> None:
  19. self.cell_width = cell_width
  20. self.cell_height = cell_height
  21. self.font_size = font_size
  22. def _shape(self, df: pd.DataFrame) -> tuple[int, int]:
  23. """
  24. Calculate table shape considering index levels.
  25. """
  26. row, col = df.shape
  27. return row + df.columns.nlevels, col + df.index.nlevels
  28. def _get_cells(self, left, right, vertical) -> tuple[int, int]:
  29. """
  30. Calculate appropriate figure size based on left and right data.
  31. """
  32. if vertical:
  33. # calculate required number of cells
  34. vcells = max(sum(self._shape(df)[0] for df in left), self._shape(right)[0])
  35. hcells = max(self._shape(df)[1] for df in left) + self._shape(right)[1]
  36. else:
  37. vcells = max([self._shape(df)[0] for df in left] + [self._shape(right)[0]])
  38. hcells = sum([self._shape(df)[1] for df in left] + [self._shape(right)[1]])
  39. return hcells, vcells
  40. def plot(
  41. self, left, right, labels: Iterable[str] = (), vertical: bool = True
  42. ) -> Figure:
  43. """
  44. Plot left / right DataFrames in specified layout.
  45. Parameters
  46. ----------
  47. left : list of DataFrames before operation is applied
  48. right : DataFrame of operation result
  49. labels : list of str to be drawn as titles of left DataFrames
  50. vertical : bool, default True
  51. If True, use vertical layout. If False, use horizontal layout.
  52. """
  53. from matplotlib import gridspec
  54. import matplotlib.pyplot as plt
  55. if not isinstance(left, list):
  56. left = [left]
  57. left = [self._conv(df) for df in left]
  58. right = self._conv(right)
  59. hcells, vcells = self._get_cells(left, right, vertical)
  60. if vertical:
  61. figsize = self.cell_width * hcells, self.cell_height * vcells
  62. else:
  63. # include margin for titles
  64. figsize = self.cell_width * hcells, self.cell_height * vcells
  65. fig = plt.figure(figsize=figsize)
  66. if vertical:
  67. gs = gridspec.GridSpec(len(left), hcells)
  68. # left
  69. max_left_cols = max(self._shape(df)[1] for df in left)
  70. max_left_rows = max(self._shape(df)[0] for df in left)
  71. for i, (_left, _label) in enumerate(zip(left, labels, strict=True)):
  72. ax = fig.add_subplot(gs[i, 0:max_left_cols])
  73. self._make_table(ax, _left, title=_label, height=1.0 / max_left_rows)
  74. # right
  75. ax = plt.subplot(gs[:, max_left_cols:])
  76. self._make_table(ax, right, title="Result", height=1.05 / vcells)
  77. fig.subplots_adjust(top=0.9, bottom=0.05, left=0.05, right=0.95)
  78. else:
  79. max_rows = max(self._shape(df)[0] for df in [*left, right])
  80. height = 1.0 / np.max(max_rows)
  81. gs = gridspec.GridSpec(1, hcells)
  82. # left
  83. i = 0
  84. for df, _label in zip(left, labels, strict=True):
  85. sp = self._shape(df)
  86. ax = fig.add_subplot(gs[0, i : i + sp[1]])
  87. self._make_table(ax, df, title=_label, height=height)
  88. i += sp[1]
  89. # right
  90. ax = plt.subplot(gs[0, i:])
  91. self._make_table(ax, right, title="Result", height=height)
  92. fig.subplots_adjust(top=0.85, bottom=0.05, left=0.05, right=0.95)
  93. return fig
  94. def _conv(self, data):
  95. """
  96. Convert each input to appropriate for table outplot.
  97. """
  98. if isinstance(data, pd.Series):
  99. if data.name is None:
  100. data = data.to_frame(name="")
  101. else:
  102. data = data.to_frame()
  103. data = data.fillna("NaN")
  104. return data
  105. def _insert_index(self, data):
  106. # insert is destructive
  107. data = data.copy()
  108. idx_nlevels = data.index.nlevels
  109. if idx_nlevels == 1:
  110. data.insert(0, "Index", data.index)
  111. else:
  112. for i in range(idx_nlevels):
  113. data.insert(i, f"Index{i}", data.index._get_level_values(i))
  114. col_nlevels = data.columns.nlevels
  115. if col_nlevels > 1:
  116. col = data.columns._get_level_values(0)
  117. values = [
  118. data.columns._get_level_values(i)._values for i in range(1, col_nlevels)
  119. ]
  120. col_df = pd.DataFrame(values)
  121. data.columns = col_df.columns
  122. data = pd.concat([col_df, data])
  123. data.columns = col
  124. return data
  125. def _make_table(self, ax, df, title: str, height: float | None = None) -> None:
  126. if df is None:
  127. ax.set_visible(False)
  128. return
  129. from pandas import plotting
  130. idx_nlevels = df.index.nlevels
  131. col_nlevels = df.columns.nlevels
  132. # must be convert here to get index levels for colorization
  133. df = self._insert_index(df)
  134. tb = plotting.table(ax, df, loc=9)
  135. tb.set_fontsize(self.font_size)
  136. if height is None:
  137. height = 1.0 / (len(df) + 1)
  138. props = tb.properties()
  139. for (r, c), cell in props["celld"].items():
  140. if c == -1:
  141. cell.set_visible(False)
  142. elif r < col_nlevels and c < idx_nlevels:
  143. cell.set_visible(False)
  144. elif r < col_nlevels or c < idx_nlevels:
  145. cell.set_facecolor("#AAAAAA")
  146. cell.set_height(height)
  147. ax.set_title(title, size=self.font_size)
  148. ax.axis("off")
  149. def main() -> None:
  150. import matplotlib.pyplot as plt
  151. p = TablePlotter()
  152. df1 = pd.DataFrame({"A": [10, 11, 12], "B": [20, 21, 22], "C": [30, 31, 32]})
  153. df2 = pd.DataFrame({"A": [10, 12], "C": [30, 32]})
  154. p.plot([df1, df2], pd.concat([df1, df2]), labels=["df1", "df2"], vertical=True)
  155. plt.show()
  156. df3 = pd.DataFrame({"X": [10, 12], "Z": [30, 32]})
  157. p.plot(
  158. [df1, df3], pd.concat([df1, df3], axis=1), labels=["df1", "df2"], vertical=False
  159. )
  160. plt.show()
  161. idx = pd.MultiIndex.from_tuples(
  162. [(1, "A"), (1, "B"), (1, "C"), (2, "A"), (2, "B"), (2, "C")]
  163. )
  164. column = pd.MultiIndex.from_tuples([(1, "A"), (1, "B")])
  165. df3 = pd.DataFrame({"v1": [1, 2, 3, 4, 5, 6], "v2": [5, 6, 7, 8, 9, 10]}, index=idx)
  166. df3.columns = column
  167. p.plot(df3, df3, labels=["df3"])
  168. plt.show()
if __name__ == "__main__":
    # Show the demo figures when this module is executed as a script.
    main()