random_walker_segmentation.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593
  1. """
  2. Random walker segmentation algorithm
  3. from *Random walks for image segmentation*, Leo Grady, IEEE Trans
  4. Pattern Anal Mach Intell. 2006 Nov;28(11):1768-83.
Installing pyamg and using the 'cg_mg' mode of random_walker significantly
improves performance.
  7. """
  8. import numpy as np
  9. from scipy import sparse, ndimage as ndi
  10. from .._shared import utils
  11. from .._shared.utils import warn
  12. from .._shared.compat import SCIPY_CG_TOL_PARAM_NAME
# UMFPACK is an optional dependency: try to import it from SciPy, but do not
# raise if that fails, since it is only needed to speed up a few cases.
# See discussions at:
# https://groups.google.com/d/msg/scikit-image/FrM5IGP6wh4/1hp-FtVZmfcJ
# https://stackoverflow.com/questions/13977970/ignore-exceptions-printed-to-stderr-in-del/13977992?noredirect=1#comment28386412_13977992
# NOTE(review): `scipy.sparse.linalg.dsolve` looks like a legacy import path
# in recent SciPy releases -- confirm it still resolves; if it does not, the
# ImportError branch below silently sets UmfpackContext to None.
try:
    from scipy.sparse.linalg.dsolve.linsolve import umfpack

    # Wrap UmfpackContext.__del__ so that an AttributeError raised while an
    # instance is being destroyed is swallowed instead of surfacing (see the
    # Stack Overflow link above).
    old_del = umfpack.UmfpackContext.__del__

    def new_del(self):
        try:
            old_del(self)
        except AttributeError:
            pass

    umfpack.UmfpackContext.__del__ = new_del
    # Despite its name, this global holds an *instance* of UmfpackContext.
    # In the visible code it is only used in an `is None` availability check.
    UmfpackContext = umfpack.UmfpackContext()
except ImportError:
    # UMFPACK is not available.
    UmfpackContext = None
  31. try:
  32. from pyamg import ruge_stuben_solver
  33. amg_loaded = True
  34. except ImportError:
  35. amg_loaded = False
  36. from ..util import img_as_float
  37. from scipy.sparse.linalg import cg, spsolve
  38. def _make_graph_edges_3d(n_x, n_y, n_z):
  39. """Returns a list of edges for a 3D image.
  40. Parameters
  41. ----------
  42. n_x : integer
  43. The size of the grid in the x direction.
  44. n_y : integer
  45. The size of the grid in the y direction
  46. n_z : integer
  47. The size of the grid in the z direction
  48. Returns
  49. -------
  50. edges : (2, N) ndarray
  51. with the total number of edges::
  52. N = n_x * n_y * (nz - 1) +
  53. n_x * (n_y - 1) * nz +
  54. (n_x - 1) * n_y * nz
  55. Graph edges with each column describing a node-id pair.
  56. """
  57. vertices = np.arange(n_x * n_y * n_z).reshape((n_x, n_y, n_z))
  58. edges_deep = np.vstack((vertices[..., :-1].ravel(), vertices[..., 1:].ravel()))
  59. edges_right = np.vstack((vertices[:, :-1].ravel(), vertices[:, 1:].ravel()))
  60. edges_down = np.vstack((vertices[:-1].ravel(), vertices[1:].ravel()))
  61. edges = np.hstack((edges_deep, edges_right, edges_down))
  62. return edges
  63. def _compute_weights_3d(data, spacing, beta, eps, multichannel):
  64. # Weight calculation is main difference in multispectral version
  65. # Original gradient**2 replaced with sum of gradients ** 2
  66. gradients = (
  67. np.concatenate(
  68. [
  69. np.diff(data[..., 0], axis=ax).ravel() / spacing[ax]
  70. for ax in [2, 1, 0]
  71. if data.shape[ax] > 1
  72. ],
  73. axis=0,
  74. )
  75. ** 2
  76. )
  77. for channel in range(1, data.shape[-1]):
  78. gradients += (
  79. np.concatenate(
  80. [
  81. np.diff(data[..., channel], axis=ax).ravel() / spacing[ax]
  82. for ax in [2, 1, 0]
  83. if data.shape[ax] > 1
  84. ],
  85. axis=0,
  86. )
  87. ** 2
  88. )
  89. # All channels considered together in this standard deviation
  90. scale_factor = -beta / (10 * data.std())
  91. if multichannel:
  92. # New final term in beta to give == results in trivial case where
  93. # multiple identical spectra are passed.
  94. scale_factor /= np.sqrt(data.shape[-1])
  95. weights = np.exp(scale_factor * gradients)
  96. weights += eps
  97. return -weights
  98. def _build_laplacian(data, spacing, mask, beta, multichannel):
  99. l_x, l_y, l_z = data.shape[:3]
  100. edges = _make_graph_edges_3d(l_x, l_y, l_z)
  101. weights = _compute_weights_3d(
  102. data, spacing, beta=beta, eps=1.0e-10, multichannel=multichannel
  103. )
  104. if mask is not None:
  105. # Remove edges of the graph connected to masked nodes, as well
  106. # as corresponding weights of the edges.
  107. mask0 = np.hstack(
  108. [mask[..., :-1].ravel(), mask[:, :-1].ravel(), mask[:-1].ravel()]
  109. )
  110. mask1 = np.hstack(
  111. [mask[..., 1:].ravel(), mask[:, 1:].ravel(), mask[1:].ravel()]
  112. )
  113. ind_mask = np.logical_and(mask0, mask1)
  114. edges, weights = edges[:, ind_mask], weights[ind_mask]
  115. # Reassign edges labels to 0, 1, ... edges_number - 1
  116. _, inv_idx = np.unique(edges, return_inverse=True)
  117. edges = inv_idx.reshape(edges.shape)
  118. # Build the sparse linear system
  119. pixel_nb = l_x * l_y * l_z
  120. i_indices = edges.ravel()
  121. j_indices = edges[::-1].ravel()
  122. data = np.hstack((weights, weights))
  123. lap = sparse.csr_array((data, (i_indices, j_indices)), shape=(pixel_nb, pixel_nb))
  124. lap.setdiag(-np.ravel(lap.sum(axis=0)))
  125. return lap
  126. def _build_linear_system(data, spacing, labels, nlabels, mask, beta, multichannel):
  127. """
  128. Build the matrix A and rhs B of the linear system to solve.
  129. A and B are two block of the laplacian of the image graph.
  130. """
  131. if mask is None:
  132. labels = labels.ravel()
  133. else:
  134. labels = labels[mask]
  135. indices = np.arange(labels.size)
  136. seeds_mask = labels > 0
  137. unlabeled_indices = indices[~seeds_mask]
  138. seeds_indices = indices[seeds_mask]
  139. lap_sparse = _build_laplacian(
  140. data, spacing, mask=mask, beta=beta, multichannel=multichannel
  141. )
  142. rows = lap_sparse[unlabeled_indices, :]
  143. lap_sparse = rows[:, unlabeled_indices]
  144. B = -rows[:, seeds_indices]
  145. seeds = labels[seeds_mask]
  146. seeds_mask = sparse.csc_array(
  147. np.hstack([np.atleast_2d(seeds == lab).T for lab in range(1, nlabels + 1)])
  148. )
  149. rhs = B @ seeds_mask
  150. return lap_sparse, rhs
  151. def _solve_linear_system(lap_sparse, B, tol, mode):
  152. if mode is None:
  153. mode = 'cg_j'
  154. if mode == 'cg_mg' and not amg_loaded:
  155. warn(
  156. '"cg_mg" not available, it requires pyamg to be installed. '
  157. 'The "cg_j" mode will be used instead.',
  158. stacklevel=2,
  159. )
  160. mode = 'cg_j'
  161. if mode == 'bf':
  162. X = spsolve(lap_sparse, B.toarray()).T
  163. else:
  164. maxiter = None
  165. if mode == 'cg':
  166. if UmfpackContext is None:
  167. warn(
  168. '"cg" mode may be slow because UMFPACK is not available. '
  169. 'Consider building Scipy with UMFPACK or use a '
  170. 'preconditioned version of CG ("cg_j" or "cg_mg" modes).',
  171. stacklevel=2,
  172. )
  173. M = None
  174. elif mode == 'cg_j':
  175. n = lap_sparse.shape[-1]
  176. M = sparse.dia_array((1.0 / lap_sparse.diagonal(), 0), shape=(n, n))
  177. else:
  178. # mode == 'cg_mg'
  179. lap_sparse.indices, lap_sparse.indptr = _safe_downcast_indices(
  180. lap_sparse, np.int32, "index values too large for int32 mode 'cg_mg'"
  181. )
  182. ml = ruge_stuben_solver(lap_sparse, coarse_solver='pinv')
  183. M = ml.aspreconditioner(cycle='V')
  184. maxiter = 30
  185. rtol = {SCIPY_CG_TOL_PARAM_NAME: tol}
  186. cg_out = [
  187. cg(lap_sparse, B[:, [i]].toarray(), **rtol, atol=0, M=M, maxiter=maxiter)
  188. for i in range(B.shape[1])
  189. ]
  190. if np.any([info > 0 for _, info in cg_out]):
  191. warn(
  192. "Conjugate gradient convergence to tolerance not achieved. "
  193. "Consider decreasing beta to improve system conditionning.",
  194. stacklevel=2,
  195. )
  196. X = np.asarray([x for x, _ in cg_out])
  197. return X
  198. def _safe_downcast_indices(A, itype, msg):
  199. # check for safe downcasting
  200. max_value = np.iinfo(itype).max
  201. if A.indptr[-1] > max_value: # indptr[-1] is max b/c indptr always sorted
  202. raise ValueError(msg)
  203. if max(*A.shape) > max_value: # only check large enough arrays
  204. if np.any(A.indices > max_value):
  205. raise ValueError(msg)
  206. indices = A.indices.astype(itype, copy=False)
  207. indptr = A.indptr.astype(itype, copy=False)
  208. return indices, indptr
def _preprocess(labels):
    """Validate the seed labels and prepare them for the random walker.

    Parameters
    ----------
    labels : ndarray of ints
        Seed image: positive values are seeds, 0 marks pixels to segment and
        negative values mark pixels excluded from the graph. This array may
        be modified in place: unlabeled pixels that no seed can reach are set
        to -1.

    Returns
    -------
    labels : ndarray
        On the early-exit paths, the (possibly modified) input array.
        Otherwise, ``np.atleast_3d`` of the labels renumbered to consecutive
        integers: 0 stays 0, positive values become 1..nlabels and negative
        values stay negative.
    nlabels : int or None
        Number of distinct positive label values.
    mask : bool ndarray or None
        ``np.atleast_3d`` mask of the pixels kept in the graph. None when no
        negative label and no seed cut off from all unlabeled pixels exist.
    inds_isolated_seeds : tuple of ndarrays or None
        ``np.nonzero`` indices, in the dimensionality of the input labels.
        Despite the name, when `mask` is not None these are *unlabeled*
        pixels that no seed can reach. When `mask` is None the tuple is
        empty.
    isolated_values : ndarray or None
        Renumbered label values at those indices, used by the caller to
        restore them in the output.

    Raises
    ------
    ValueError
        If `labels` contains no positive value.

    Notes
    -----
    ``(labels, None, None, None, None)`` is returned, with a warning, when
    there is no unlabeled (0) pixel, or when every unlabeled pixel is
    unreachable from the seeds.
    """
    label_values, inv_idx = np.unique(labels, return_inverse=True)

    if max(label_values) <= 0:
        raise ValueError(
            'No seeds provided in label image: please ensure '
            'it contains at least one positive value'
        )

    if not (label_values == 0).any():
        warn(
            'Random walker only segments unlabeled areas, where '
            'labels == 0. No zero valued areas in labels were '
            'found. Returning provided labels.',
            stacklevel=2,
        )

        return labels, None, None, None, None

    # If some labeled pixels are isolated inside pruned zones, prune them
    # as well and keep the labels for the final output

    null_mask = labels == 0
    pos_mask = labels > 0
    mask = labels >= 0

    # Pixels of `mask` connected to at least one unlabeled pixel.
    fill = ndi.binary_propagation(null_mask, mask=mask)
    # Seeds with no path to any unlabeled pixel: they are not used as
    # diffusion sources below. NOTE: they stay True in `mask`.
    isolated = np.logical_and(pos_mask, np.logical_not(fill))

    pos_mask[isolated] = False

    # If the array has pruned zones, be sure that no isolated pixels
    # exist between pruned zones (they could not be determined)
    if label_values[0] < 0 or np.any(isolated):
        # `isolated` is reassigned: unlabeled pixels that no remaining seed
        # can reach.
        isolated = np.logical_and(
            np.logical_not(ndi.binary_propagation(pos_mask, mask=mask)), null_mask
        )

        # In-place marking; the renumbering below is recomputed from
        # `inv_idx`, so these pixels are 0 again in the returned labels.
        labels[isolated] = -1
        if np.all(isolated[null_mask]):
            warn(
                'All unlabeled pixels are isolated, they could not be '
                'determined by the random walker algorithm.',
                stacklevel=2,
            )
            return labels, None, None, None, None

        mask[isolated] = False
        mask = np.atleast_3d(mask)
    else:
        mask = None

    # Reorder label values to have consecutive integers (no gaps).
    # 0 is known to be in `label_values` here, so `zero_idx` is its position.
    zero_idx = np.searchsorted(label_values, 0)
    labels = np.atleast_3d(inv_idx.reshape(labels.shape) - zero_idx)

    # Number of strictly positive label values.
    nlabels = label_values[zero_idx + 1 :].shape[0]

    # `isolated` has the input's dimensionality while `labels` is now at
    # least 3-D; indexing with the nonzero tuple still addresses the same
    # pixels.
    inds_isolated_seeds = np.nonzero(isolated)
    isolated_values = labels[inds_isolated_seeds]

    return labels, nlabels, mask, inds_isolated_seeds, isolated_values
@utils.channel_as_last_axis(multichannel_output=False)
def random_walker(
    data,
    labels,
    beta=130,
    mode='cg_j',
    tol=1.0e-3,
    copy=True,
    return_full_prob=False,
    spacing=None,
    *,
    prob_tol=1e-3,
    channel_axis=None,
):
    """Random walker algorithm for segmentation from markers.

    Random walker algorithm is implemented for gray-level or multichannel
    images.

    Parameters
    ----------
    data : (M, N[, P][, C]) ndarray
        Image to be segmented in phases. Gray-level `data` can be two- or
        three-dimensional; multichannel data can be three- or four-
        dimensional with `channel_axis` specifying the dimension containing
        channels. Data spacing is assumed isotropic unless the `spacing`
        keyword argument is used.
    labels : (M, N[, P]) array of ints
        Array of seed markers labeled with different positive integers
        for different phases. Zero-labeled pixels are unlabeled pixels.
        Negative labels correspond to inactive pixels that are not taken
        into account (they are removed from the graph). If labels are not
        consecutive integers, the labels array will be transformed so that
        labels are consecutive. In the multichannel case, `labels` should
        have the same shape as a single channel of `data`, i.e. without the
        final dimension denoting channels.
    beta : float, optional
        Penalization coefficient for the random walker motion
        (the greater `beta`, the more difficult the diffusion).
    mode : {'cg', 'cg_j', 'cg_mg', 'bf', None}, optional
        Mode for solving the linear system in the random walker algorithm.
        None is treated as 'cg_j'.

        - 'bf' (brute force): an LU factorization of the Laplacian is
          computed. This is fast for small images (<1024x1024), but very
          slow and memory-intensive for large images (e.g., 3-D volumes).
        - 'cg' (conjugate gradient): the linear system is solved
          iteratively using the Conjugate Gradient method from
          scipy.sparse.linalg. This is less memory-consuming than the brute
          force method for large images, but it is quite slow.
        - 'cg_j' (conjugate gradient with Jacobi preconditioner): the
          Jacobi preconditioner is applied during the Conjugate gradient
          method iterations. This may accelerate the convergence of the
          'cg' method.
        - 'cg_mg' (conjugate gradient with multigrid preconditioner): a
          preconditioner is computed using a multigrid solver, then the
          solution is computed with the Conjugate Gradient method. This
          mode requires that the pyamg module is installed; otherwise a
          warning is emitted and 'cg_j' is used instead.
    tol : float, optional
        Tolerance to achieve when solving the linear system using
        the conjugate gradient based modes ('cg', 'cg_j' and 'cg_mg').
    copy : bool, optional
        If True (default), `labels` is copied before processing. If False,
        the caller's `labels` array may be modified in place: unlabeled
        pixels that no seed can reach are set to -1. Use copy=False to save
        memory. Note that on the main path the segmentation is returned in
        a new array rather than written back into `labels`.
    return_full_prob : bool, optional
        If True, the probability that a pixel belongs to each of the
        labels will be returned, instead of only the most likely
        label.
    spacing : iterable of floats, optional
        Spacing between voxels in each spatial dimension. If `None`, then
        the spacing between pixels/voxels in each dimension is assumed 1.
    prob_tol : float, optional
        Tolerance on the resulting probability to be in the interval [0, 1].
        If the tolerance is not satisfied, a warning is displayed.
    channel_axis : int or None, optional
        If None, the image is assumed to be a grayscale (single channel)
        image. Otherwise, this parameter indicates which axis of the array
        corresponds to channels.

        .. versionadded:: 0.19
           ``channel_axis`` was added in 0.19.

    Returns
    -------
    output : ndarray
        * If `return_full_prob` is False, array of ints of same shape
          and data type as `labels` (after relabeling to consecutive
          integers), in which each pixel has been labeled according to the
          marker that reached the pixel first by anisotropic diffusion.
        * If `return_full_prob` is True, array of floats of shape
          ``(nlabels,) + labels.shape``. ``output[label_nb, i, j]`` is the
          probability that label ``label_nb + 1`` reaches the pixel
          ``(i, j)`` first.
        * If nothing needs segmenting, a warning is emitted and `labels`
          itself is returned. This happens when there is no 0-labeled pixel,
          or when no unlabeled pixel is reachable from a seed (in that case
          those pixels are set to -1). With `return_full_prob` True, this
          path instead returns the boolean masks of each positive label
          concatenated along the last axis, e.g. ``(M, N, nlabels)`` for
          2-D labels.

    See Also
    --------
    skimage.segmentation.watershed
        A segmentation algorithm based on mathematical morphology
        and "flooding" of regions from markers.

    Notes
    -----
    Multichannel inputs are scaled with all channel data combined. Ensure
    all channels are separately normalized prior to running this algorithm.

    The `spacing` argument is specifically for anisotropic datasets, where
    data points are spaced differently in one or more spatial dimensions.
    Anisotropic data is commonly encountered in medical imaging.

    The algorithm was first proposed in [1]_.

    The algorithm solves the diffusion equation at infinite times for
    sources placed on markers of each phase in turn. A pixel is labeled with
    the phase that has the greatest probability to diffuse first to the
    pixel.

    The diffusion equation is solved by minimizing x.T L x for each phase,
    where L is the Laplacian of the weighted graph of the image, and x is
    the probability that a marker of the given phase arrives first at a
    pixel by diffusion (x=1 on markers of the phase, x=0 on the other
    markers, and the other coefficients are looked for). Each pixel is
    attributed the label for which it has a maximal value of x. The
    Laplacian L of the image is defined as:

    - L_ii = d_i, the weighted degree of pixel i (the sum of the weights
      w_ij of its edges)
    - L_ij = -w_ij if i and j are adjacent pixels

    The weight w_ij is a decreasing function of the norm of the local
    gradient. This ensures that diffusion is easier between pixels of
    similar values.

    When the Laplacian is decomposed into blocks of marked and unmarked
    pixels::

        L = M B.T
            B A

    with first indices corresponding to marked pixels, and then to unmarked
    pixels, minimizing x.T L x for one phase amounts to solving::

        A x = - B x_m

    where x_m = 1 on markers of the given phase, and 0 on other markers.
    This linear system is solved in the algorithm using a direct method for
    small images, and an iterative method for larger images.

    References
    ----------
    .. [1] Leo Grady, Random walks for image segmentation, IEEE Trans
           Pattern Anal Mach Intell. 2006 Nov;28(11):1768-83.
           :DOI:`10.1109/TPAMI.2006.233`.

    Examples
    --------
    >>> rng = np.random.default_rng()
    >>> a = np.zeros((10, 10)) + 0.2 * rng.random((10, 10))
    >>> a[5:8, 5:8] += 1
    >>> b = np.zeros_like(a, dtype=np.int32)
    >>> b[3, 3] = 1  # Marker for first phase
    >>> b[6, 6] = 2  # Marker for second phase
    >>> random_walker(a, b)  # doctest: +SKIP
    array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
           [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
           [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
           [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
           [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
           [1, 1, 1, 1, 1, 2, 2, 2, 1, 1],
           [1, 1, 1, 1, 1, 2, 2, 2, 1, 1],
           [1, 1, 1, 1, 1, 2, 2, 2, 1, 1],
           [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
           [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=int32)

    """
    # Parse input data
    if mode not in ('cg_mg', 'cg', 'bf', 'cg_j', None):
        raise ValueError(
            f"{mode} is not a valid mode. Valid modes are 'cg_mg', "
            f"'cg', 'cg_j', 'bf', and None"
        )

    if data.dtype == np.float16:
        # SciPy sparse, which is used later on, doesn't officially support float16
        # This led to failures when testing with NumPy 1.26 (see gh-7635).
        data = data.astype(np.float32, casting="safe")

    # Spacing kwarg checks: internally spacing always has 3 entries (a dummy
    # 1.0 is appended for 2-D labels).
    if spacing is None:
        spacing = np.ones(3)
    elif len(spacing) == labels.ndim:
        if len(spacing) == 2:
            # Need a dummy spacing for singleton 3rd dim
            spacing = np.r_[spacing, 1.0]
        spacing = np.asarray(spacing)
    else:
        raise ValueError(
            'Input argument `spacing` incorrect, should be an '
            'iterable with one number per spatial dimension.'
        )

    # This algorithm expects 4-D arrays of floats, where the first three
    # dimensions are spatial and the final denotes channels. 2-D images have
    # a singleton placeholder dimension added for the third spatial dimension,
    # and single channel images likewise have a singleton added for channels.
    # The following block ensures valid input and coerces it to the correct
    # form.
    multichannel = channel_axis is not None
    if not multichannel:
        if data.ndim not in (2, 3):
            raise ValueError(
                'For non-multichannel input, data must be of ' 'dimension 2 or 3.'
            )
        if data.shape != labels.shape:
            raise ValueError('Incompatible data and labels shapes.')
        data = np.atleast_3d(img_as_float(data))[..., np.newaxis]
    else:
        if data.ndim not in (3, 4):
            raise ValueError(
                'For multichannel input, data must have 3 or 4 ' 'dimensions.'
            )
        if data.shape[:-1] != labels.shape:
            raise ValueError('Incompatible data and labels shapes.')
        data = img_as_float(data)
        if data.ndim == 3:  # 2D multispectral, needs singleton in 3rd axis
            data = data[:, :, np.newaxis, :]

    labels_shape = labels.shape
    labels_dtype = labels.dtype

    if copy:
        labels = np.copy(labels)

    (labels, nlabels, mask, inds_isolated_seeds, isolated_values) = _preprocess(labels)

    if isolated_values is None:
        # No non isolated zero valued areas in labels were
        # found. Returning provided labels.
        if return_full_prob:
            # Return the concatenation of the masks of each unique label.
            # NOTE(review): this stacks along the last axis, unlike the
            # (nlabels, ...) layout of the main path below; kept as-is in
            # case callers rely on it -- confirm the intended layout.
            return np.concatenate(
                [np.atleast_3d(labels == lab) for lab in np.unique(labels) if lab > 0],
                axis=-1,
            )
        return labels

    # Build the linear system (lap_sparse, B)
    lap_sparse, B = _build_linear_system(
        data, spacing, labels, nlabels, mask, beta, multichannel
    )

    # Solve the linear system lap_sparse X = B
    # where X[i, j] is the probability that a marker of label i arrives
    # first at pixel j by anisotropic diffusion.
    X = _solve_linear_system(lap_sparse, B, tol, mode)

    if X.min() < -prob_tol or X.max() > 1 + prob_tol:
        warn(
            'The probability range is outside [0, 1] given the tolerance '
            '`prob_tol`. Consider decreasing `beta` and/or decreasing '
            '`tol`.'
        )

    # Build the output according to return_full_prob value
    # Put back labels of isolated seeds
    labels[inds_isolated_seeds] = isolated_values
    labels = labels.reshape(labels_shape)

    # Pixels that received a solution: unlabeled and not excluded as
    # unreachable. Their C-order ordering matches the columns of X.
    mask = labels == 0
    mask[inds_isolated_seeds] = False

    if return_full_prob:
        out = np.zeros((nlabels,) + labels_shape)
        for lab, (label_prob, prob) in enumerate(zip(out, X), start=1):
            label_prob[mask] = prob
            # Seeds of label `lab` have probability 1 for their own label.
            label_prob[labels == lab] = 1
    else:
        # Most likely label per solved pixel (labels start at 1).
        X = np.argmax(X, axis=0) + 1
        out = labels.astype(labels_dtype)
        out[mask] = X

    return out