# coord.py
  1. import numpy as np
  2. from scipy.spatial import cKDTree, distance
  3. def _ensure_spacing(coord, spacing, p_norm, max_out):
  4. """Returns a subset of coord where a minimum spacing is guaranteed.
  5. Parameters
  6. ----------
  7. coord : ndarray
  8. The coordinates of the considered points.
  9. spacing : float
  10. the maximum allowed spacing between the points.
  11. p_norm : float
  12. Which Minkowski p-norm to use. Should be in the range [1, inf].
  13. A finite large p may cause a ValueError if overflow can occur.
  14. ``inf`` corresponds to the Chebyshev distance and 2 to the
  15. Euclidean distance.
  16. max_out : int
  17. If not None, at most the first ``max_out`` candidates are
  18. returned.
  19. Returns
  20. -------
  21. output : ndarray
  22. A subset of coord where a minimum spacing is guaranteed.
  23. """
  24. # Use KDtree to find the peaks that are too close to each other
  25. tree = cKDTree(coord)
  26. indices = tree.query_ball_point(coord, r=spacing, p=p_norm)
  27. rejected_peaks_indices = set()
  28. naccepted = 0
  29. for idx, candidates in enumerate(indices):
  30. if idx not in rejected_peaks_indices:
  31. # keep current point and the points at exactly spacing from it
  32. candidates.remove(idx)
  33. dist = distance.cdist(
  34. [coord[idx]], coord[candidates], "minkowski", p=p_norm
  35. ).reshape(-1)
  36. candidates = [c for c, d in zip(candidates, dist) if d < spacing]
  37. # candidates.remove(keep)
  38. rejected_peaks_indices.update(candidates)
  39. naccepted += 1
  40. if max_out is not None and naccepted >= max_out:
  41. break
  42. # Remove the peaks that are too close to each other
  43. output = np.delete(coord, tuple(rejected_peaks_indices), axis=0)
  44. if max_out is not None:
  45. output = output[:max_out]
  46. return output
  47. def ensure_spacing(
  48. coords,
  49. spacing=1,
  50. p_norm=np.inf,
  51. min_split_size=50,
  52. max_out=None,
  53. *,
  54. max_split_size=2000,
  55. ):
  56. """Returns a subset of coord where a minimum spacing is guaranteed.
  57. Parameters
  58. ----------
  59. coords : array_like
  60. The coordinates of the considered points.
  61. spacing : float
  62. the maximum allowed spacing between the points.
  63. p_norm : float
  64. Which Minkowski p-norm to use. Should be in the range [1, inf].
  65. A finite large p may cause a ValueError if overflow can occur.
  66. ``inf`` corresponds to the Chebyshev distance and 2 to the
  67. Euclidean distance.
  68. min_split_size : int
  69. Minimum split size used to process ``coords`` by batch to save
  70. memory. If None, the memory saving strategy is not applied.
  71. max_out : int
  72. If not None, only the first ``max_out`` candidates are returned.
  73. max_split_size : int
  74. Maximum split size used to process ``coords`` by batch to save
  75. memory. This number was decided by profiling with a large number
  76. of points. Too small a number results in too much looping in
  77. Python instead of C, slowing down the process, while too large
  78. a number results in large memory allocations, slowdowns, and,
  79. potentially, in the process being killed -- see gh-6010. See
  80. benchmark results `here
  81. <https://github.com/scikit-image/scikit-image/pull/6035#discussion_r751518691>`_.
  82. Returns
  83. -------
  84. output : array_like
  85. A subset of coord where a minimum spacing is guaranteed.
  86. """
  87. output = coords
  88. if len(coords):
  89. coords = np.atleast_2d(coords)
  90. if min_split_size is None:
  91. batch_list = [coords]
  92. else:
  93. coord_count = len(coords)
  94. split_idx = [min_split_size]
  95. split_size = min_split_size
  96. while coord_count - split_idx[-1] > max_split_size:
  97. split_size *= 2
  98. split_idx.append(split_idx[-1] + min(split_size, max_split_size))
  99. batch_list = np.array_split(coords, split_idx)
  100. output = np.zeros((0, coords.shape[1]), dtype=coords.dtype)
  101. for batch in batch_list:
  102. output = _ensure_spacing(
  103. np.vstack([output, batch]), spacing, p_norm, max_out
  104. )
  105. if max_out is not None and len(output) >= max_out:
  106. break
  107. return output