histogram_matching.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. import numpy as np
  2. from .._shared import utils
  3. def _match_cumulative_cdf(source, template):
  4. """
  5. Return modified source array so that the cumulative density function of
  6. its values matches the cumulative density function of the template.
  7. """
  8. if source.dtype.kind == 'u':
  9. src_lookup = source.reshape(-1)
  10. src_counts = np.bincount(src_lookup)
  11. tmpl_counts = np.bincount(template.reshape(-1))
  12. # omit values where the count was 0
  13. tmpl_values = np.nonzero(tmpl_counts)[0]
  14. tmpl_counts = tmpl_counts[tmpl_values]
  15. else:
  16. src_values, src_lookup, src_counts = np.unique(
  17. source.reshape(-1), return_inverse=True, return_counts=True
  18. )
  19. tmpl_values, tmpl_counts = np.unique(template.reshape(-1), return_counts=True)
  20. # calculate normalized quantiles for each array
  21. src_quantiles = np.cumsum(src_counts) / source.size
  22. tmpl_quantiles = np.cumsum(tmpl_counts) / template.size
  23. interp_a_values = np.interp(src_quantiles, tmpl_quantiles, tmpl_values)
  24. return interp_a_values[src_lookup].reshape(source.shape)
  25. @utils.channel_as_last_axis(channel_arg_positions=(0, 1))
  26. def match_histograms(image, reference, *, channel_axis=None):
  27. """Adjust an image so that its cumulative histogram matches that of another.
  28. The adjustment is applied separately for each channel.
  29. Parameters
  30. ----------
  31. image : ndarray
  32. Input image. Can be gray-scale or in color.
  33. reference : ndarray
  34. Image to match histogram of. Must have the same number of channels as
  35. image.
  36. channel_axis : int or None, optional
  37. If None, the image is assumed to be a grayscale (single channel) image.
  38. Otherwise, this parameter indicates which axis of the array corresponds
  39. to channels.
  40. Returns
  41. -------
  42. matched : ndarray
  43. Transformed input image.
  44. Raises
  45. ------
  46. ValueError
  47. Thrown when the number of channels in the input image and the reference
  48. differ.
  49. References
  50. ----------
  51. .. [1] http://paulbourke.net/miscellaneous/equalisation/
  52. """
  53. if image.ndim != reference.ndim:
  54. raise ValueError(
  55. 'Image and reference must have the same number ' 'of channels.'
  56. )
  57. if channel_axis is not None:
  58. if image.shape[-1] != reference.shape[-1]:
  59. raise ValueError(
  60. 'Number of channels in the input image and '
  61. 'reference image must match!'
  62. )
  63. matched = np.empty(image.shape, dtype=image.dtype)
  64. for channel in range(image.shape[-1]):
  65. matched_channel = _match_cumulative_cdf(
  66. image[..., channel], reference[..., channel]
  67. )
  68. matched[..., channel] = matched_channel
  69. else:
  70. # _match_cumulative_cdf will always return float64 due to np.interp
  71. matched = _match_cumulative_cdf(image, reference)
  72. if matched.dtype.kind == 'f':
  73. # output a float32 result when the input is float16 or float32
  74. out_dtype = utils._supported_float_type(image.dtype)
  75. matched = matched.astype(out_dtype, copy=False)
  76. return matched