_structural_similarity.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. import functools
  2. import numpy as np
  3. from scipy.ndimage import uniform_filter
  4. from .._shared import utils
  5. from .._shared.filters import gaussian
  6. from .._shared.utils import _supported_float_type, check_shape_equality, warn
  7. from ..util.arraycrop import crop
  8. from ..util.dtype import dtype_range
  9. __all__ = ['structural_similarity']
  10. def structural_similarity(
  11. im1,
  12. im2,
  13. *,
  14. win_size=None,
  15. gradient=False,
  16. data_range=None,
  17. channel_axis=None,
  18. gaussian_weights=False,
  19. full=False,
  20. **kwargs,
  21. ):
  22. """
  23. Compute the mean structural similarity index between two images.
  24. Please pay attention to the `data_range` parameter with floating-point images.
  25. Parameters
  26. ----------
  27. im1, im2 : ndarray
  28. Images. Any dimensionality with same shape.
  29. win_size : int or None, optional
  30. The side-length of the sliding window used in comparison. Must be an
  31. odd value. If `gaussian_weights` is True, this is ignored and the
  32. window size will depend on `sigma`.
  33. gradient : bool, optional
  34. If True, also return the gradient with respect to im2.
  35. data_range : float, optional
  36. The data range of the input image (difference between maximum and
  37. minimum possible values). By default, this is estimated from the image
  38. data type. This estimate may be wrong for floating-point image data.
  39. Therefore it is recommended to always pass this scalar value explicitly
  40. (see note below).
  41. channel_axis : int or None, optional
  42. If None, the image is assumed to be a grayscale (single channel) image.
  43. Otherwise, this parameter indicates which axis of the array corresponds
  44. to channels.
  45. .. versionadded:: 0.19
  46. ``channel_axis`` was added in 0.19.
  47. gaussian_weights : bool, optional
  48. If True, each patch has its mean and variance spatially weighted by a
  49. normalized Gaussian kernel of width sigma=1.5.
  50. full : bool, optional
  51. If True, also return the full structural similarity image.
  52. Other Parameters
  53. ----------------
  54. use_sample_covariance : bool
  55. If True, normalize covariances by N-1 rather than, N where N is the
  56. number of pixels within the sliding window.
  57. K1 : float
  58. Algorithm parameter, K1 (small constant, see [1]_).
  59. K2 : float
  60. Algorithm parameter, K2 (small constant, see [1]_).
  61. sigma : float
  62. Standard deviation for the Gaussian when `gaussian_weights` is True.
  63. Returns
  64. -------
  65. mssim : float
  66. The mean structural similarity index over the image.
  67. grad : ndarray
  68. The gradient of the structural similarity between im1 and im2 [2]_.
  69. This is only returned if `gradient` is set to True.
  70. S : ndarray
  71. The full SSIM image. This is only returned if `full` is set to True.
  72. Notes
  73. -----
  74. If `data_range` is not specified, the range is automatically guessed
  75. based on the image data type. However for floating-point image data, this
  76. estimate yields a result double the value of the desired range, as the
  77. `dtype_range` in `skimage.util.dtype.py` has defined intervals from -1 to
  78. +1. This yields an estimate of 2, instead of 1, which is most often
  79. required when working with image data (as negative light intensities are
  80. nonsensical). In case of working with YCbCr-like color data, note that
  81. these ranges are different per channel (Cb and Cr have double the range
  82. of Y), so one cannot calculate a channel-averaged SSIM with a single call
  83. to this function, as identical ranges are assumed for each channel.
  84. To match the implementation of Wang et al. [1]_, set `gaussian_weights`
  85. to True, `sigma` to 1.5, `use_sample_covariance` to False, and
  86. specify the `data_range` argument.
  87. .. versionchanged:: 0.16
  88. This function was renamed from ``skimage.measure.compare_ssim`` to
  89. ``skimage.metrics.structural_similarity``.
  90. References
  91. ----------
  92. .. [1] Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P.
  93. (2004). Image quality assessment: From error visibility to
  94. structural similarity. IEEE Transactions on Image Processing,
  95. 13, 600-612.
  96. https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf,
  97. :DOI:`10.1109/TIP.2003.819861`
  98. .. [2] Avanaki, A. N. (2009). Exact global histogram specification
  99. optimized for structural similarity. Optical Review, 16, 613-621.
  100. :arxiv:`0901.0065`
  101. :DOI:`10.1007/s10043-009-0119-z`
  102. """
  103. check_shape_equality(im1, im2)
  104. float_type = _supported_float_type(im1.dtype)
  105. if channel_axis is not None:
  106. # loop over channels
  107. args = dict(
  108. win_size=win_size,
  109. gradient=gradient,
  110. data_range=data_range,
  111. channel_axis=None,
  112. gaussian_weights=gaussian_weights,
  113. full=full,
  114. )
  115. args.update(kwargs)
  116. nch = im1.shape[channel_axis]
  117. mssim = np.empty(nch, dtype=float_type)
  118. if gradient:
  119. G = np.empty(im1.shape, dtype=float_type)
  120. if full:
  121. S = np.empty(im1.shape, dtype=float_type)
  122. channel_axis = channel_axis % im1.ndim
  123. _at = functools.partial(utils.slice_at_axis, axis=channel_axis)
  124. for ch in range(nch):
  125. ch_result = structural_similarity(im1[_at(ch)], im2[_at(ch)], **args)
  126. if gradient and full:
  127. mssim[ch], G[_at(ch)], S[_at(ch)] = ch_result
  128. elif gradient:
  129. mssim[ch], G[_at(ch)] = ch_result
  130. elif full:
  131. mssim[ch], S[_at(ch)] = ch_result
  132. else:
  133. mssim[ch] = ch_result
  134. mssim = mssim.mean()
  135. if gradient and full:
  136. return mssim, G, S
  137. elif gradient:
  138. return mssim, G
  139. elif full:
  140. return mssim, S
  141. else:
  142. return mssim
  143. K1 = kwargs.pop('K1', 0.01)
  144. K2 = kwargs.pop('K2', 0.03)
  145. sigma = kwargs.pop('sigma', 1.5)
  146. if K1 < 0:
  147. raise ValueError("K1 must be positive")
  148. if K2 < 0:
  149. raise ValueError("K2 must be positive")
  150. if sigma < 0:
  151. raise ValueError("sigma must be positive")
  152. use_sample_covariance = kwargs.pop('use_sample_covariance', True)
  153. if gaussian_weights:
  154. # Set to give an 11-tap filter with the default sigma of 1.5 to match
  155. # Wang et. al. 2004.
  156. truncate = 3.5
  157. if win_size is None:
  158. if gaussian_weights:
  159. # set win_size used by crop to match the filter size
  160. r = int(truncate * sigma + 0.5) # radius as in ndimage
  161. win_size = 2 * r + 1
  162. else:
  163. win_size = 7 # backwards compatibility
  164. if np.any((np.asarray(im1.shape) - win_size) < 0):
  165. raise ValueError(
  166. 'win_size exceeds image extent. '
  167. 'Either ensure that your images are '
  168. 'at least 7x7; or pass win_size explicitly '
  169. 'in the function call, with an odd value '
  170. 'less than or equal to the smaller side of your '
  171. 'images. If your images are multichannel '
  172. '(with color channels), set channel_axis to '
  173. 'the axis number corresponding to the channels.'
  174. )
  175. if not (win_size % 2 == 1):
  176. raise ValueError('Window size must be odd.')
  177. if data_range is None:
  178. if np.issubdtype(im1.dtype, np.floating) or np.issubdtype(
  179. im2.dtype, np.floating
  180. ):
  181. raise ValueError(
  182. 'Since image dtype is floating point, you must specify '
  183. 'the data_range parameter. Please read the documentation '
  184. 'carefully (including the note). It is recommended that '
  185. 'you always specify the data_range anyway.'
  186. )
  187. if im1.dtype != im2.dtype:
  188. warn(
  189. "Inputs have mismatched dtypes. Setting data_range based on im1.dtype.",
  190. stacklevel=2,
  191. )
  192. dmin, dmax = dtype_range[im1.dtype.type]
  193. data_range = dmax - dmin
  194. if np.issubdtype(im1.dtype, np.integer) and (im1.dtype != np.uint8):
  195. warn(
  196. "Setting data_range based on im1.dtype. "
  197. + f"data_range = {data_range:.0f}. "
  198. + "Please specify data_range explicitly to avoid mistakes.",
  199. stacklevel=2,
  200. )
  201. ndim = im1.ndim
  202. if gaussian_weights:
  203. filter_func = gaussian
  204. filter_args = {'sigma': sigma, 'truncate': truncate, 'mode': 'reflect'}
  205. else:
  206. filter_func = uniform_filter
  207. filter_args = {'size': win_size}
  208. # ndimage filters need floating point data
  209. im1 = im1.astype(float_type, copy=False)
  210. im2 = im2.astype(float_type, copy=False)
  211. NP = win_size**ndim
  212. # filter has already normalized by NP
  213. if use_sample_covariance:
  214. cov_norm = NP / (NP - 1) # sample covariance
  215. else:
  216. cov_norm = 1.0 # population covariance to match Wang et. al. 2004
  217. # compute (weighted) means
  218. ux = filter_func(im1, **filter_args)
  219. uy = filter_func(im2, **filter_args)
  220. # compute (weighted) variances and covariances
  221. uxx = filter_func(im1 * im1, **filter_args)
  222. uyy = filter_func(im2 * im2, **filter_args)
  223. uxy = filter_func(im1 * im2, **filter_args)
  224. vx = cov_norm * (uxx - ux * ux)
  225. vy = cov_norm * (uyy - uy * uy)
  226. vxy = cov_norm * (uxy - ux * uy)
  227. R = data_range
  228. C1 = (K1 * R) ** 2
  229. C2 = (K2 * R) ** 2
  230. A1, A2, B1, B2 = (
  231. 2 * ux * uy + C1,
  232. 2 * vxy + C2,
  233. ux**2 + uy**2 + C1,
  234. vx + vy + C2,
  235. )
  236. D = B1 * B2
  237. S = (A1 * A2) / D
  238. # to avoid edge effects will ignore filter radius strip around edges
  239. pad = (win_size - 1) // 2
  240. # compute (weighted) mean of ssim. Use float64 for accuracy.
  241. mssim = crop(S, pad).mean(dtype=np.float64)
  242. if gradient:
  243. # The following is Eqs. 7-8 of Avanaki 2009.
  244. grad = filter_func(A1 / D, **filter_args) * im1
  245. grad += filter_func(-S / B2, **filter_args) * im2
  246. grad += filter_func((ux * (A2 - A1) - uy * (B2 - B1) * S) / D, **filter_args)
  247. grad *= 2 / im1.size
  248. if full:
  249. return mssim, grad, S
  250. else:
  251. return mssim, grad
  252. else:
  253. if full:
  254. return mssim, S
  255. else:
  256. return mssim