apply_parallel.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. import numpy
  2. __all__ = ['apply_parallel']
  3. def _get_chunks(shape, ncpu):
  4. """Split the array into equal sized chunks based on the number of
  5. available processors. The last chunk in each dimension absorbs the
  6. remainder array elements if the number of CPUs does not divide evenly into
  7. the number of array elements.
  8. Examples
  9. --------
  10. >>> _get_chunks((4, 4), 4)
  11. ((2, 2), (2, 2))
  12. >>> _get_chunks((4, 4), 2)
  13. ((2, 2), (4,))
  14. >>> _get_chunks((5, 5), 2)
  15. ((2, 3), (5,))
  16. >>> _get_chunks((2, 4), 2)
  17. ((1, 1), (4,))
  18. """
  19. # since apply_parallel is in the critical import path, we lazy import
  20. # math just when we need it.
  21. from math import ceil
  22. chunks = []
  23. nchunks_per_dim = int(ceil(ncpu ** (1.0 / len(shape))))
  24. used_chunks = 1
  25. for i in shape:
  26. if used_chunks < ncpu:
  27. regular_chunk = i // nchunks_per_dim
  28. remainder_chunk = regular_chunk + (i % nchunks_per_dim)
  29. if regular_chunk == 0:
  30. chunk_lens = (remainder_chunk,)
  31. else:
  32. chunk_lens = (regular_chunk,) * (nchunks_per_dim - 1) + (
  33. remainder_chunk,
  34. )
  35. else:
  36. chunk_lens = (i,)
  37. chunks.append(chunk_lens)
  38. used_chunks *= nchunks_per_dim
  39. return tuple(chunks)
  40. def _ensure_dask_array(array, chunks=None):
  41. import dask.array as da
  42. if isinstance(array, da.Array):
  43. return array
  44. return da.from_array(array, chunks=chunks)
  45. def apply_parallel(
  46. function,
  47. array,
  48. chunks=None,
  49. depth=0,
  50. mode=None,
  51. extra_arguments=(),
  52. extra_keywords=None,
  53. *,
  54. dtype=None,
  55. compute=None,
  56. channel_axis=None,
  57. ):
  58. """Map a function in parallel across an array.
  59. Split an array into possibly overlapping chunks of a given depth and
  60. boundary type, call the given function in parallel on the chunks, combine
  61. the chunks and return the resulting array.
  62. Parameters
  63. ----------
  64. function : function
  65. Function to be mapped which takes an array as an argument.
  66. array : numpy array or dask array
  67. Array which the function will be applied to.
  68. chunks : int, tuple, or tuple of tuples, optional
  69. A single integer is interpreted as the length of one side of a square
  70. chunk that should be tiled across the array. One tuple of length
  71. ``array.ndim`` represents the shape of a chunk, and it is tiled across
  72. the array. A list of tuples of length ``ndim``, where each sub-tuple
  73. is a sequence of chunk sizes along the corresponding dimension. If
  74. None, the array is broken up into chunks based on the number of
  75. available cpus. More information about chunks is in the documentation
  76. `here <https://dask.pydata.org/en/latest/array-design.html>`_. When
  77. `channel_axis` is not None, the tuples can be length ``ndim - 1`` and
  78. a single chunk will be used along the channel axis.
  79. depth : int or sequence of int, optional
  80. The depth of the added boundary cells. A tuple can be used to specify a
  81. different depth per array axis. Defaults to zero. When `channel_axis`
  82. is not None, and a tuple of length ``ndim - 1`` is provided, a depth of
  83. 0 will be used along the channel axis.
  84. mode : {'reflect', 'symmetric', 'periodic', 'wrap', 'nearest', 'edge'}, optional
  85. Type of external boundary padding.
  86. extra_arguments : tuple, optional
  87. Tuple of arguments to be passed to the function.
  88. extra_keywords : dictionary, optional
  89. Dictionary of keyword arguments to be passed to the function.
  90. dtype : data-type or None, optional
  91. The data-type of the `function` output. If None, Dask will attempt to
  92. infer this by calling the function on data of shape ``(1,) * ndim``.
  93. For functions expecting RGB or multichannel data this may be
  94. problematic. In such cases, the user should manually specify this dtype
  95. argument instead.
  96. .. versionadded:: 0.18
  97. ``dtype`` was added in 0.18.
  98. compute : bool, optional
  99. If ``True``, compute eagerly returning a NumPy Array.
  100. If ``False``, compute lazily returning a Dask Array.
  101. If ``None`` (default), compute based on array type provided
  102. (eagerly for NumPy Arrays and lazily for Dask Arrays).
  103. channel_axis : int or None, optional
  104. If None, the image is assumed to be a grayscale (single channel) image.
  105. Otherwise, this parameter indicates which axis of the array corresponds
  106. to channels.
  107. Returns
  108. -------
  109. out : ndarray or dask Array
  110. Returns the result of the applying the operation.
  111. Type is dependent on the ``compute`` argument.
  112. Notes
  113. -----
  114. Numpy edge modes 'symmetric', 'wrap', and 'edge' are converted to the
  115. equivalent ``dask`` boundary modes 'reflect', 'periodic' and 'nearest',
  116. respectively.
  117. Setting ``compute=False`` can be useful for chaining later operations.
  118. For example region selection to preview a result or storing large data
  119. to disk instead of loading in memory.
  120. """
  121. try:
  122. # Importing dask takes time. since apply_parallel is on the
  123. # minimum import path of skimage, we lazy attempt to import dask
  124. import dask.array as da
  125. except ImportError:
  126. raise RuntimeError(
  127. "Could not import 'dask'. Please install " "using 'pip install dask'"
  128. )
  129. if extra_keywords is None:
  130. extra_keywords = {}
  131. if compute is None:
  132. compute = not isinstance(array, da.Array)
  133. if channel_axis is not None:
  134. channel_axis = channel_axis % array.ndim
  135. if chunks is None:
  136. shape = array.shape
  137. try:
  138. # since apply_parallel is in the critical import path, we lazy
  139. # import multiprocessing just when we need it.
  140. from multiprocessing import cpu_count
  141. ncpu = cpu_count()
  142. except NotImplementedError:
  143. ncpu = 4
  144. if channel_axis is not None:
  145. # use a single chunk along the channel axis
  146. spatial_shape = shape[:channel_axis] + shape[channel_axis + 1 :]
  147. chunks = list(_get_chunks(spatial_shape, ncpu))
  148. chunks.insert(channel_axis, shape[channel_axis])
  149. chunks = tuple(chunks)
  150. else:
  151. chunks = _get_chunks(shape, ncpu)
  152. elif channel_axis is not None and len(chunks) == array.ndim - 1:
  153. # insert a single chunk along the channel axis
  154. chunks = list(chunks)
  155. chunks.insert(channel_axis, array.shape[channel_axis])
  156. chunks = tuple(chunks)
  157. if mode == 'wrap':
  158. mode = 'periodic'
  159. elif mode == 'symmetric':
  160. mode = 'reflect'
  161. elif mode == 'edge':
  162. mode = 'nearest'
  163. elif mode is None:
  164. # default value for Dask.
  165. # Note: that for dask >= 2022.03 it will change to 'none' so we set it
  166. # here for consistent behavior across Dask versions.
  167. mode = 'reflect'
  168. if channel_axis is not None:
  169. if numpy.isscalar(depth):
  170. # depth is zero along channel_axis
  171. depth = [depth] * (array.ndim - 1)
  172. depth = list(depth)
  173. if len(depth) == array.ndim - 1:
  174. depth.insert(channel_axis, 0)
  175. depth = tuple(depth)
  176. def wrapped_func(arr):
  177. return function(arr, *extra_arguments, **extra_keywords)
  178. darr = _ensure_dask_array(array, chunks=chunks)
  179. res = darr.map_overlap(wrapped_func, depth, boundary=mode, dtype=dtype)
  180. if compute:
  181. res = res.compute()
  182. return res