slic_superpixels.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449
  1. import math
  2. from collections.abc import Iterable
  3. from warnings import warn
  4. import numpy as np
  5. from numpy import random
  6. from scipy.cluster.vq import kmeans2
  7. from scipy.spatial.distance import pdist, squareform
  8. from .._shared import utils
  9. from .._shared.filters import gaussian
  10. from ..color import rgb2lab
  11. from ..util import img_as_float, regular_grid
  12. from ._slic import _enforce_label_connectivity_cython, _slic_cython
  13. def _get_mask_centroids(mask, n_centroids, multichannel):
  14. """Find regularly spaced centroids on a mask.
  15. Parameters
  16. ----------
  17. mask : 3D ndarray
  18. The mask within which the centroids must be positioned.
  19. n_centroids : int
  20. The number of centroids to be returned.
  21. Returns
  22. -------
  23. centroids : 2D ndarray
  24. The coordinates of the centroids with shape (n_centroids, 3).
  25. steps : 1D ndarray
  26. The approximate distance between two seeds in all dimensions.
  27. """
  28. # Get tight ROI around the mask to optimize
  29. coord = np.array(np.nonzero(mask), dtype=float).T
  30. # Fix random seed to ensure repeatability
  31. # Keep old-style RandomState here as expected results in tests depend on it
  32. rng = random.RandomState(123)
  33. # select n_centroids randomly distributed points from within the mask
  34. idx_full = np.arange(len(coord), dtype=int)
  35. idx = np.sort(rng.choice(idx_full, min(n_centroids, len(coord)), replace=False))
  36. # To save time, when n_centroids << len(coords), use only a subset of the
  37. # coordinates when calling k-means. Rather than the full set of coords,
  38. # we will use a substantially larger subset than n_centroids. Here we
  39. # somewhat arbitrarily choose dense_factor=10 to make the samples
  40. # 10 times closer together along each axis than the n_centroids samples.
  41. dense_factor = 10
  42. ndim_spatial = mask.ndim - 1 if multichannel else mask.ndim
  43. n_dense = int((dense_factor**ndim_spatial) * n_centroids)
  44. if len(coord) > n_dense:
  45. # subset of points to use for the k-means calculation
  46. # (much denser than idx, but less than the full set)
  47. idx_dense = np.sort(rng.choice(idx_full, n_dense, replace=False))
  48. else:
  49. idx_dense = Ellipsis
  50. centroids, _ = kmeans2(coord[idx_dense], coord[idx], iter=5)
  51. # Compute the minimum distance of each centroid to the others
  52. dist = squareform(pdist(centroids))
  53. np.fill_diagonal(dist, np.inf)
  54. closest_pts = dist.argmin(-1)
  55. steps = abs(centroids - centroids[closest_pts, :]).mean(0)
  56. return centroids, steps
  57. def _get_grid_centroids(image, n_centroids):
  58. """Find regularly spaced centroids on the image.
  59. Parameters
  60. ----------
  61. image : 2D, 3D or 4D ndarray
  62. Input image, which can be 2D or 3D, and grayscale or
  63. multichannel.
  64. n_centroids : int
  65. The (approximate) number of centroids to be returned.
  66. Returns
  67. -------
  68. centroids : 2D ndarray
  69. The coordinates of the centroids with shape (~n_centroids, 3).
  70. steps : 1D ndarray
  71. The approximate distance between two seeds in all dimensions.
  72. """
  73. d, h, w = image.shape[:3]
  74. grid_z, grid_y, grid_x = np.mgrid[:d, :h, :w]
  75. slices = regular_grid(image.shape[:3], n_centroids)
  76. centroids_z = grid_z[slices].ravel()[..., np.newaxis]
  77. centroids_y = grid_y[slices].ravel()[..., np.newaxis]
  78. centroids_x = grid_x[slices].ravel()[..., np.newaxis]
  79. centroids = np.concatenate([centroids_z, centroids_y, centroids_x], axis=-1)
  80. steps = np.asarray([float(s.step) if s.step is not None else 1.0 for s in slices])
  81. return centroids, steps
  82. @utils.channel_as_last_axis(multichannel_output=False)
  83. def slic(
  84. image,
  85. n_segments=100,
  86. compactness=10.0,
  87. max_num_iter=10,
  88. sigma=0,
  89. spacing=None,
  90. convert2lab=None,
  91. enforce_connectivity=True,
  92. min_size_factor=0.5,
  93. max_size_factor=3,
  94. slic_zero=False,
  95. start_label=1,
  96. mask=None,
  97. *,
  98. channel_axis=-1,
  99. ):
  100. """Segments image using k-means clustering in Color-(x,y,z) space.
  101. Parameters
  102. ----------
  103. image : (M, N[, P][, C]) ndarray
  104. Input image. Can be 2D or 3D, and grayscale or multichannel
  105. (see `channel_axis` parameter).
  106. Input image must either be NaN-free or the NaN's must be masked out.
  107. n_segments : int, optional
  108. The (approximate) number of labels in the segmented output image.
  109. compactness : float, optional
  110. Balances color proximity and space proximity. Higher values give
  111. more weight to space proximity, making superpixel shapes more
  112. square/cubic. In SLICO mode, this is the initial compactness.
  113. This parameter depends strongly on image contrast and on the
  114. shapes of objects in the image. We recommend exploring possible
  115. values on a log scale, e.g., 0.01, 0.1, 1, 10, 100, before
  116. refining around a chosen value.
  117. max_num_iter : int, optional
  118. Maximum number of iterations of k-means.
  119. sigma : float or array-like of floats, optional
  120. Width of Gaussian smoothing kernel for pre-processing for each
  121. dimension of the image. The same sigma is applied to each dimension in
  122. case of a scalar value. Zero means no smoothing.
  123. Note that `sigma` is automatically scaled if it is scalar and
  124. if a manual voxel spacing is provided (see Notes section). If
  125. sigma is array-like, its size must match ``image``'s number
  126. of spatial dimensions.
  127. spacing : array-like of floats, optional
  128. The voxel spacing along each spatial dimension. By default,
  129. `slic` assumes uniform spacing (same voxel resolution along
  130. each spatial dimension).
  131. This parameter controls the weights of the distances along the
  132. spatial dimensions during k-means clustering.
  133. convert2lab : bool, optional
  134. Whether the input should be converted to Lab colorspace prior to
  135. segmentation. The input image *must* be RGB. Highly recommended.
  136. This option defaults to ``True`` when ``channel_axis` is not None *and*
  137. ``image.shape[-1] == 3``.
  138. enforce_connectivity : bool, optional
  139. Whether the generated segments are connected or not
  140. min_size_factor : float, optional
  141. Proportion of the minimum segment size to be removed with respect
  142. to the supposed segment size ```depth*width*height/n_segments```
  143. max_size_factor : float, optional
  144. Proportion of the maximum connected segment size. A value of 3 works
  145. in most of the cases.
  146. slic_zero : bool, optional
  147. Run SLIC-zero, the zero-parameter mode of SLIC. [2]_
  148. start_label : int, optional
  149. The labels' index start. Should be 0 or 1.
  150. .. versionadded:: 0.17
  151. ``start_label`` was introduced in 0.17
  152. mask : ndarray, optional
  153. If provided, superpixels are computed only where mask is True,
  154. and seed points are homogeneously distributed over the mask
  155. using a k-means clustering strategy. Mask number of dimensions
  156. must be equal to image number of spatial dimensions.
  157. .. versionadded:: 0.17
  158. ``mask`` was introduced in 0.17
  159. channel_axis : int or None, optional
  160. If None, the image is assumed to be a grayscale (single channel) image.
  161. Otherwise, this parameter indicates which axis of the array corresponds
  162. to channels.
  163. .. versionadded:: 0.19
  164. ``channel_axis`` was added in 0.19.
  165. Returns
  166. -------
  167. labels : 2D or 3D array
  168. Integer mask indicating segment labels.
  169. Raises
  170. ------
  171. ValueError
  172. If ``convert2lab`` is set to ``True`` but the last array
  173. dimension is not of length 3.
  174. ValueError
  175. If ``start_label`` is not 0 or 1.
  176. ValueError
  177. If ``image`` contains unmasked NaN values.
  178. ValueError
  179. If ``image`` contains unmasked infinite values.
  180. ValueError
  181. If ``image`` is 2D but ``channel_axis`` is -1 (the default).
  182. Notes
  183. -----
  184. * If `sigma > 0`, the image is smoothed using a Gaussian kernel prior to
  185. segmentation.
  186. * If `sigma` is scalar and `spacing` is provided, the kernel width is
  187. divided along each dimension by the spacing. For example, if ``sigma=1``
  188. and ``spacing=[5, 1, 1]``, the effective `sigma` is ``[0.2, 1, 1]``. This
  189. ensures sensible smoothing for anisotropic images.
  190. * The image is rescaled to be in [0, 1] prior to processing (masked
  191. values are ignored).
  192. * Images of shape (M, N, 3) are interpreted as 2D RGB images by default. To
  193. interpret them as 3D with the last dimension having length 3, use
  194. `channel_axis=None`.
  195. * `start_label` is introduced to handle the issue [4]_. Label indexing
  196. starts at 1 by default.
  197. References
  198. ----------
  199. .. [1] Radhakrishna Achanta, Appu Shaji, Kevin Smith, Aurelien Lucchi,
  200. Pascal Fua, and Sabine Süsstrunk, SLIC Superpixels Compared to
  201. State-of-the-art Superpixel Methods, TPAMI, May 2012.
  202. :DOI:`10.1109/TPAMI.2012.120`
  203. .. [2] https://www.epfl.ch/labs/ivrl/research/slic-superpixels/#SLICO
  204. .. [3] Irving, Benjamin. "maskSLIC: regional superpixel generation with
  205. application to local pathology characterisation in medical images.",
  206. 2016, :arXiv:`1606.09518`
  207. .. [4] https://github.com/scikit-image/scikit-image/issues/3722
  208. Examples
  209. --------
  210. >>> from skimage.segmentation import slic
  211. >>> from skimage.data import astronaut
  212. >>> img = astronaut()
  213. >>> segments = slic(img, n_segments=100, compactness=10)
  214. Increasing the compactness parameter yields more square regions:
  215. >>> segments = slic(img, n_segments=100, compactness=20)
  216. """
  217. if image.ndim == 2 and channel_axis is not None:
  218. raise ValueError(
  219. f"channel_axis={channel_axis} indicates multichannel, which is not "
  220. "supported for a two-dimensional image; use channel_axis=None if "
  221. "the image is grayscale"
  222. )
  223. image = img_as_float(image)
  224. float_dtype = utils._supported_float_type(image.dtype)
  225. # copy=True so subsequent in-place operations do not modify the
  226. # function input
  227. image = image.astype(float_dtype, copy=True)
  228. if mask is not None:
  229. # Create masked_image to rescale while ignoring masked values
  230. mask = np.ascontiguousarray(mask, dtype=bool)
  231. if channel_axis is not None:
  232. mask_ = np.expand_dims(mask, axis=channel_axis)
  233. mask_ = np.broadcast_to(mask_, image.shape)
  234. else:
  235. mask_ = mask
  236. image_values = image[mask_]
  237. else:
  238. image_values = image
  239. # Rescale image to [0, 1] to make choice of compactness insensitive to
  240. # input image scale.
  241. imin = image_values.min()
  242. imax = image_values.max()
  243. if np.isnan(imin):
  244. raise ValueError("unmasked NaN values in image are not supported")
  245. if np.isinf(imin) or np.isinf(imax):
  246. raise ValueError("unmasked infinite values in image are not supported")
  247. image -= imin
  248. if imax != imin:
  249. image /= imax - imin
  250. use_mask = mask is not None
  251. dtype = image.dtype
  252. is_2d = False
  253. multichannel = channel_axis is not None
  254. if image.ndim == 2:
  255. # 2D grayscale image
  256. image = image[np.newaxis, ..., np.newaxis]
  257. is_2d = True
  258. elif image.ndim == 3 and multichannel:
  259. # Make 2D multichannel image 3D with depth = 1
  260. image = image[np.newaxis, ...]
  261. is_2d = True
  262. elif image.ndim == 3 and not multichannel:
  263. # Add channel as single last dimension
  264. image = image[..., np.newaxis]
  265. if multichannel and (convert2lab or convert2lab is None):
  266. if image.shape[channel_axis] != 3 and convert2lab:
  267. raise ValueError("Lab colorspace conversion requires a RGB image.")
  268. elif image.shape[channel_axis] == 3:
  269. image = rgb2lab(image)
  270. if start_label not in [0, 1]:
  271. raise ValueError("start_label should be 0 or 1.")
  272. # initialize cluster centroids for desired number of segments
  273. update_centroids = False
  274. if use_mask:
  275. mask = mask.view('uint8')
  276. if mask.ndim == 2:
  277. mask = np.ascontiguousarray(mask[np.newaxis, ...])
  278. if mask.shape != image.shape[:3]:
  279. raise ValueError("image and mask should have the same shape.")
  280. centroids, steps = _get_mask_centroids(mask, n_segments, multichannel)
  281. update_centroids = True
  282. else:
  283. centroids, steps = _get_grid_centroids(image, n_segments)
  284. if spacing is None:
  285. spacing = np.ones(3, dtype=dtype)
  286. elif isinstance(spacing, Iterable):
  287. spacing = np.asarray(spacing, dtype=dtype)
  288. if is_2d:
  289. if spacing.size != 2:
  290. if spacing.size == 3:
  291. warn(
  292. "Input image is 2D: spacing number of "
  293. "elements must be 2. In the future, a ValueError "
  294. "will be raised.",
  295. FutureWarning,
  296. stacklevel=2,
  297. )
  298. else:
  299. raise ValueError(
  300. f"Input image is 2D, but spacing has "
  301. f"{spacing.size} elements (expected 2)."
  302. )
  303. else:
  304. spacing = np.insert(spacing, 0, 1)
  305. elif spacing.size != 3:
  306. raise ValueError(
  307. f"Input image is 3D, but spacing has "
  308. f"{spacing.size} elements (expected 3)."
  309. )
  310. spacing = np.ascontiguousarray(spacing, dtype=dtype)
  311. else:
  312. raise TypeError("spacing must be None or iterable.")
  313. if np.isscalar(sigma):
  314. sigma = np.array([sigma, sigma, sigma], dtype=dtype)
  315. sigma /= spacing
  316. elif isinstance(sigma, Iterable):
  317. sigma = np.asarray(sigma, dtype=dtype)
  318. if is_2d:
  319. if sigma.size != 2:
  320. if spacing.size == 3:
  321. warn(
  322. "Input image is 2D: sigma number of "
  323. "elements must be 2. In the future, a ValueError "
  324. "will be raised.",
  325. FutureWarning,
  326. stacklevel=2,
  327. )
  328. else:
  329. raise ValueError(
  330. f"Input image is 2D, but sigma has "
  331. f"{sigma.size} elements (expected 2)."
  332. )
  333. else:
  334. sigma = np.insert(sigma, 0, 0)
  335. elif sigma.size != 3:
  336. raise ValueError(
  337. f"Input image is 3D, but sigma has "
  338. f"{sigma.size} elements (expected 3)."
  339. )
  340. if (sigma > 0).any():
  341. # add zero smoothing for channel dimension
  342. sigma = list(sigma) + [0]
  343. image = gaussian(image, sigma=sigma, mode='reflect')
  344. n_centroids = centroids.shape[0]
  345. segments = np.ascontiguousarray(
  346. np.concatenate([centroids, np.zeros((n_centroids, image.shape[3]))], axis=-1),
  347. dtype=dtype,
  348. )
  349. # Scaling of ratio in the same way as in the SLIC paper so the
  350. # values have the same meaning
  351. step = max(steps)
  352. ratio = 1.0 / compactness
  353. image = np.ascontiguousarray(image * ratio, dtype=dtype)
  354. if update_centroids:
  355. # Step 2 of the algorithm [3]_
  356. _slic_cython(
  357. image,
  358. mask,
  359. segments,
  360. step,
  361. max_num_iter,
  362. spacing,
  363. slic_zero,
  364. ignore_color=True,
  365. start_label=start_label,
  366. )
  367. labels = _slic_cython(
  368. image,
  369. mask,
  370. segments,
  371. step,
  372. max_num_iter,
  373. spacing,
  374. slic_zero,
  375. ignore_color=False,
  376. start_label=start_label,
  377. )
  378. if enforce_connectivity:
  379. if use_mask:
  380. segment_size = mask.sum() / n_centroids
  381. else:
  382. segment_size = math.prod(image.shape[:3]) / n_centroids
  383. min_size = int(min_size_factor * segment_size)
  384. max_size = int(max_size_factor * segment_size)
  385. labels = _enforce_label_connectivity_cython(
  386. labels, min_size, max_size, start_label=start_label
  387. )
  388. if is_2d:
  389. labels = labels[0]
  390. return labels