_cycle_spin.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. from itertools import product
  2. import numpy as np
  3. from .._shared import utils
  4. from .._shared.utils import warn, deprecate_parameter, DEPRECATED
  5. try:
  6. import dask
  7. dask_available = True
  8. except ImportError:
  9. dask_available = False
  10. def _generate_shifts(ndim, multichannel, max_shifts, shift_steps=1):
  11. """Returns all combinations of shifts in n dimensions over the specified
  12. max_shifts and step sizes.
  13. Examples
  14. --------
  15. >>> s = list(_generate_shifts(2, False, max_shifts=(1, 2), shift_steps=1))
  16. >>> print(s)
  17. [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
  18. """
  19. mc = int(multichannel)
  20. if np.isscalar(max_shifts):
  21. max_shifts = (max_shifts,) * (ndim - mc) + (0,) * mc
  22. elif multichannel and len(max_shifts) == ndim - 1:
  23. max_shifts = tuple(max_shifts) + (0,)
  24. elif len(max_shifts) != ndim:
  25. raise ValueError("max_shifts should have length ndim")
  26. if np.isscalar(shift_steps):
  27. shift_steps = (shift_steps,) * (ndim - mc) + (1,) * mc
  28. elif multichannel and len(shift_steps) == ndim - 1:
  29. shift_steps = tuple(shift_steps) + (1,)
  30. elif len(shift_steps) != ndim:
  31. raise ValueError("max_shifts should have length ndim")
  32. if any(s < 1 for s in shift_steps):
  33. raise ValueError("shift_steps must all be >= 1")
  34. if multichannel and max_shifts[-1] != 0:
  35. raise ValueError(
  36. "Multichannel cycle spinning should not have shifts along the " "last axis."
  37. )
  38. return product(*[range(0, s + 1, t) for s, t in zip(max_shifts, shift_steps)])
  39. @deprecate_parameter(
  40. deprecated_name="num_workers",
  41. new_name="workers",
  42. start_version="0.26",
  43. stop_version="0.28",
  44. )
  45. @utils.channel_as_last_axis()
  46. def cycle_spin(
  47. x,
  48. func,
  49. max_shifts,
  50. shift_steps=1,
  51. num_workers=DEPRECATED,
  52. func_kw=None,
  53. *,
  54. workers=None,
  55. channel_axis=None,
  56. ):
  57. """Cycle spinning (repeatedly apply func to shifted versions of x).
  58. Parameters
  59. ----------
  60. x : array-like
  61. Data for input to ``func``.
  62. func : function
  63. A function to apply to circularly shifted versions of ``x``. Should
  64. take ``x`` as its first argument. Any additional arguments can be
  65. supplied via ``func_kw``.
  66. max_shifts : int or tuple
  67. If an integer, shifts in ``range(0, max_shifts+1)`` will be used along
  68. each axis of ``x``. If a tuple, ``range(0, max_shifts[i]+1)`` will be
  69. along axis i.
  70. shift_steps : int or tuple, optional
  71. The step size for the shifts applied along axis, i, are::
  72. ``range((0, max_shifts[i]+1, shift_steps[i]))``. If an integer is
  73. provided, the same step size is used for all axes.
  74. workers : int or None, optional
  75. The number of parallel threads to use during cycle spinning. If set to
  76. ``None``, the full set of available cores are used.
  77. func_kw : dict, optional
  78. Additional keyword arguments to supply to ``func``.
  79. channel_axis : int or None, optional
  80. If None, the image is assumed to be a grayscale (single channel) image.
  81. Otherwise, this parameter indicates which axis of the array corresponds
  82. to channels.
  83. .. versionadded:: 0.19
  84. ``channel_axis`` was added in 0.19.
  85. Returns
  86. -------
  87. avg_y : np.ndarray
  88. The output of ``func(x, **func_kw)`` averaged over all combinations of
  89. the specified axis shifts.
  90. Notes
  91. -----
  92. Cycle spinning was proposed as a way to approach shift-invariance via
  93. performing several circular shifts of a shift-variant transform [1]_.
  94. For a n-level discrete wavelet transforms, one may wish to perform all
  95. shifts up to ``max_shifts = 2**n - 1``. In practice, much of the benefit
  96. can often be realized with only a small number of shifts per axis.
  97. For transforms such as the blockwise discrete cosine transform, one may
  98. wish to evaluate shifts up to the block size used by the transform.
  99. References
  100. ----------
  101. .. [1] R.R. Coifman and D.L. Donoho. "Translation-Invariant De-Noising".
  102. Wavelets and Statistics, Lecture Notes in Statistics, vol.103.
  103. Springer, New York, 1995, pp.125-150.
  104. :DOI:`10.1007/978-1-4612-2544-7_9`
  105. Examples
  106. --------
  107. >>> import skimage.data
  108. >>> from skimage import img_as_float
  109. >>> from skimage.restoration import denoise_tv_chambolle, cycle_spin
  110. >>> img = img_as_float(skimage.data.camera())
  111. >>> sigma = 0.1
  112. >>> img = img + sigma * np.random.standard_normal(img.shape)
  113. >>> denoised = cycle_spin(img, func=denoise_tv_chambolle,
  114. ... max_shifts=3) # doctest: +IGNORE_WARNINGS
  115. """
  116. if func_kw is None:
  117. func_kw = {}
  118. x = np.asanyarray(x)
  119. multichannel = channel_axis is not None
  120. all_shifts = _generate_shifts(x.ndim, multichannel, max_shifts, shift_steps)
  121. all_shifts = list(all_shifts)
  122. roll_axes = tuple(range(x.ndim))
  123. def _run_one_shift(shift):
  124. # shift, apply function, inverse shift
  125. xs = np.roll(x, shift, axis=roll_axes)
  126. tmp = func(xs, **func_kw)
  127. return np.roll(tmp, tuple(-s for s in shift), axis=roll_axes)
  128. if not dask_available and (workers is None or workers > 1):
  129. workers = 1
  130. warn(
  131. 'The optional dask dependency is not installed. '
  132. 'The number of workers is set to 1. To silence '
  133. 'this warning, install dask or explicitly set `workers=1` '
  134. 'when calling the `cycle_spin` function',
  135. stacklevel=4,
  136. )
  137. # compute a running average across the cycle shifts
  138. if workers == 1:
  139. # serial processing
  140. mean = _run_one_shift(all_shifts[0])
  141. for shift in all_shifts[1:]:
  142. mean += _run_one_shift(shift)
  143. mean /= len(all_shifts)
  144. else:
  145. # multithreaded via dask
  146. futures = [dask.delayed(_run_one_shift)(s) for s in all_shifts]
  147. mean = sum(futures) / len(futures)
  148. mean = mean.compute(workers=workers)
  149. return mean