# _regular_grid.py (3.8 KB)
  1. import numpy as np
  2. def regular_grid(ar_shape, n_points):
  3. """Find `n_points` regularly spaced along `ar_shape`.
  4. The returned points (as slices) should be as close to cubically-spaced as
  5. possible. Essentially, the points are spaced by the Nth root of the input
  6. array size, where N is the number of dimensions. However, if an array
  7. dimension cannot fit a full step size, it is "discarded", and the
  8. computation is done for only the remaining dimensions.
  9. Parameters
  10. ----------
  11. ar_shape : array-like of ints
  12. The shape of the space embedding the grid. ``len(ar_shape)`` is the
  13. number of dimensions.
  14. n_points : int
  15. The (approximate) number of points to embed in the space.
  16. Returns
  17. -------
  18. slices : tuple of slice objects
  19. A slice along each dimension of `ar_shape`, such that the intersection
  20. of all the slices give the coordinates of regularly spaced points.
  21. .. versionchanged:: 0.14.1
  22. In scikit-image 0.14.1 and 0.15, the return type was changed from a
  23. list to a tuple to ensure `compatibility with Numpy 1.15`_ and
  24. higher. If your code requires the returned result to be a list, you
  25. may convert the output of this function to a list with:
  26. >>> result = list(regular_grid(ar_shape=(3, 20, 40), n_points=8))
  27. .. _compatibility with NumPy 1.15: https://github.com/numpy/numpy/blob/master/doc/release/1.15.0-notes.rst#deprecations
  28. Examples
  29. --------
  30. >>> ar = np.zeros((20, 40))
  31. >>> g = regular_grid(ar.shape, 8)
  32. >>> g
  33. (slice(5, None, 10), slice(5, None, 10))
  34. >>> ar[g] = 1
  35. >>> ar.sum()
  36. 8.0
  37. >>> ar = np.zeros((20, 40))
  38. >>> g = regular_grid(ar.shape, 32)
  39. >>> g
  40. (slice(2, None, 5), slice(2, None, 5))
  41. >>> ar[g] = 1
  42. >>> ar.sum()
  43. 32.0
  44. >>> ar = np.zeros((3, 20, 40))
  45. >>> g = regular_grid(ar.shape, 8)
  46. >>> g
  47. (slice(1, None, 3), slice(5, None, 10), slice(5, None, 10))
  48. >>> ar[g] = 1
  49. >>> ar.sum()
  50. 8.0
  51. """
  52. ar_shape = np.asanyarray(ar_shape)
  53. ndim = len(ar_shape)
  54. unsort_dim_idxs = np.argsort(np.argsort(ar_shape))
  55. sorted_dims = np.sort(ar_shape)
  56. space_size = float(np.prod(ar_shape))
  57. if space_size <= n_points:
  58. return (slice(None),) * ndim
  59. stepsizes = np.full(ndim, (space_size / n_points) ** (1.0 / ndim), dtype='float64')
  60. if (sorted_dims < stepsizes).any():
  61. for dim in range(ndim):
  62. stepsizes[dim] = sorted_dims[dim]
  63. space_size = float(np.prod(sorted_dims[dim + 1 :]))
  64. stepsizes[dim + 1 :] = (space_size / n_points) ** (1.0 / (ndim - dim - 1))
  65. if (sorted_dims >= stepsizes).all():
  66. break
  67. starts = (stepsizes // 2).astype(int)
  68. stepsizes = np.round(stepsizes).astype(int)
  69. slices = [slice(start, None, step) for start, step in zip(starts, stepsizes)]
  70. slices = tuple(slices[i] for i in unsort_dim_idxs)
  71. return slices
  72. def regular_seeds(ar_shape, n_points, dtype=int):
  73. """Return an image with ~`n_points` regularly-spaced nonzero pixels.
  74. Parameters
  75. ----------
  76. ar_shape : tuple of int
  77. The shape of the desired output image.
  78. n_points : int
  79. The desired number of nonzero points.
  80. dtype : numpy data type, optional
  81. The desired data type of the output.
  82. Returns
  83. -------
  84. seed_img : array of int or bool
  85. The desired image.
  86. Examples
  87. --------
  88. >>> regular_seeds((5, 5), 4)
  89. array([[0, 0, 0, 0, 0],
  90. [0, 1, 0, 2, 0],
  91. [0, 0, 0, 0, 0],
  92. [0, 3, 0, 4, 0],
  93. [0, 0, 0, 0, 0]])
  94. """
  95. grid = regular_grid(ar_shape, n_points)
  96. seed_img = np.zeros(ar_shape, dtype=dtype)
  97. seed_img[grid] = 1 + np.reshape(
  98. np.arange(seed_img[grid].size), seed_img[grid].shape
  99. )
  100. return seed_img