matplotlib_plugin.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. from collections import namedtuple
  2. import numpy as np
  3. from ...util import dtype as dtypes
  4. from ...exposure import is_low_contrast
  5. from ..._shared.utils import warn
  6. from math import floor, ceil
  7. _default_colormap = 'gray'
  8. _nonstandard_colormap = 'viridis'
  9. _diverging_colormap = 'RdBu'
  10. ImageProperties = namedtuple(
  11. 'ImageProperties',
  12. ['signed', 'out_of_range_float', 'low_data_range', 'unsupported_dtype'],
  13. )
  14. def _get_image_properties(image):
  15. """Determine nonstandard properties of an input image.
  16. Parameters
  17. ----------
  18. image : array
  19. The input image.
  20. Returns
  21. -------
  22. ip : ImageProperties named tuple
  23. The properties of the image:
  24. - signed: whether the image has negative values.
  25. - out_of_range_float: if the image has floating point data
  26. outside of [-1, 1].
  27. - low_data_range: if the image is in the standard image
  28. range (e.g. [0, 1] for a floating point image) but its
  29. data range would be too small to display with standard
  30. image ranges.
  31. - unsupported_dtype: if the image data type is not a
  32. standard skimage type, e.g. ``numpy.uint64``.
  33. """
  34. immin, immax = np.min(image), np.max(image)
  35. imtype = image.dtype.type
  36. try:
  37. lo, hi = dtypes.dtype_range[imtype]
  38. except KeyError:
  39. lo, hi = immin, immax
  40. signed = immin < 0
  41. out_of_range_float = np.issubdtype(image.dtype, np.floating) and (
  42. immin < lo or immax > hi
  43. )
  44. low_data_range = immin != immax and is_low_contrast(image)
  45. unsupported_dtype = image.dtype not in dtypes._supported_types
  46. return ImageProperties(
  47. signed, out_of_range_float, low_data_range, unsupported_dtype
  48. )
  49. def _raise_warnings(image_properties):
  50. """Raise the appropriate warning for each nonstandard image type.
  51. Parameters
  52. ----------
  53. image_properties : ImageProperties named tuple
  54. The properties of the considered image.
  55. """
  56. ip = image_properties
  57. if ip.unsupported_dtype:
  58. warn(
  59. "Non-standard image type; displaying image with " "stretched contrast.",
  60. stacklevel=3,
  61. )
  62. if ip.low_data_range:
  63. warn(
  64. "Low image data range; displaying image with " "stretched contrast.",
  65. stacklevel=3,
  66. )
  67. if ip.out_of_range_float:
  68. warn(
  69. "Float image out of standard range; displaying "
  70. "image with stretched contrast.",
  71. stacklevel=3,
  72. )
  73. def _get_display_range(image):
  74. """Return the display range for a given set of image properties.
  75. Parameters
  76. ----------
  77. image : array
  78. The input image.
  79. Returns
  80. -------
  81. lo, hi : same type as immin, immax
  82. The display range to be used for the input image.
  83. cmap : string
  84. The name of the colormap to use.
  85. """
  86. ip = _get_image_properties(image)
  87. immin, immax = np.min(image), np.max(image)
  88. if ip.signed:
  89. magnitude = max(abs(immin), abs(immax))
  90. lo, hi = -magnitude, magnitude
  91. cmap = _diverging_colormap
  92. elif any(ip):
  93. _raise_warnings(ip)
  94. lo, hi = immin, immax
  95. cmap = _nonstandard_colormap
  96. else:
  97. lo = 0
  98. imtype = image.dtype.type
  99. hi = dtypes.dtype_range[imtype][1]
  100. cmap = _default_colormap
  101. return lo, hi, cmap
  102. def imshow(image, ax=None, show_cbar=None, **kwargs):
  103. """Show the input image and return the current axes.
  104. By default, the image is displayed in grayscale, rather than
  105. the matplotlib default colormap.
  106. Images are assumed to have standard range for their type. For
  107. example, if a floating point image has values in [0, 0.5], the
  108. most intense color will be gray50, not white.
  109. If the image exceeds the standard range, or if the range is too
  110. small to display, we fall back on displaying exactly the range of
  111. the input image, along with a colorbar to clearly indicate that
  112. this range transformation has occurred.
  113. For signed images, we use a diverging colormap centered at 0.
  114. Parameters
  115. ----------
  116. image : array, shape (M, N[, 3])
  117. The image to display.
  118. ax : `matplotlib.axes.Axes`, optional
  119. The axis to use for the image, defaults to plt.gca().
  120. show_cbar : bool, optional
  121. Whether to show the colorbar (used to override default behavior).
  122. **kwargs : Keyword arguments
  123. These are passed directly to `matplotlib.pyplot.imshow`.
  124. Returns
  125. -------
  126. ax_im : `matplotlib.pyplot.AxesImage`
  127. The `AxesImage` object returned by `plt.imshow`.
  128. """
  129. import matplotlib.pyplot as plt
  130. from mpl_toolkits.axes_grid1 import make_axes_locatable
  131. lo, hi, cmap = _get_display_range(image)
  132. kwargs.setdefault('interpolation', 'nearest')
  133. kwargs.setdefault('cmap', cmap)
  134. kwargs.setdefault('vmin', lo)
  135. kwargs.setdefault('vmax', hi)
  136. ax = ax or plt.gca()
  137. ax_im = ax.imshow(image, **kwargs)
  138. if (cmap != _default_colormap and show_cbar is not False) or show_cbar:
  139. divider = make_axes_locatable(ax)
  140. cax = divider.append_axes("right", size="5%", pad=0.05)
  141. plt.colorbar(ax_im, cax=cax)
  142. ax.get_figure().tight_layout()
  143. return ax_im
  144. def imshow_collection(ic, *args, **kwargs):
  145. """Display all images in the collection.
  146. Returns
  147. -------
  148. fig : `matplotlib.figure.Figure`
  149. The `Figure` object returned by `plt.subplots`.
  150. """
  151. import matplotlib.pyplot as plt
  152. if len(ic) < 1:
  153. raise ValueError('Number of images to plot must be greater than 0')
  154. # The target is to plot images on a grid with aspect ratio 4:3
  155. num_images = len(ic)
  156. # Two pairs of `nrows, ncols` are possible
  157. k = (num_images * 12) ** 0.5
  158. r1 = max(1, floor(k / 4))
  159. r2 = ceil(k / 4)
  160. c1 = ceil(num_images / r1)
  161. c2 = ceil(num_images / r2)
  162. # Select the one which is closer to 4:3
  163. if abs(r1 / c1 - 0.75) < abs(r2 / c2 - 0.75):
  164. nrows, ncols = r1, c1
  165. else:
  166. nrows, ncols = r2, c2
  167. fig, axes = plt.subplots(nrows=nrows, ncols=ncols)
  168. ax = np.asarray(axes).ravel()
  169. for n, image in enumerate(ic):
  170. ax[n].imshow(image, *args, **kwargs)
  171. kwargs['ax'] = axes
  172. return fig
  173. def imread(*args, **kwargs):
  174. import matplotlib.image
  175. return matplotlib.image.imread(*args, **kwargs)
  176. def _app_show():
  177. from matplotlib.pyplot import show
  178. show()