spath.py 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. import numpy as np
  2. from . import _spath
  3. def shortest_path(arr, reach=1, axis=-1, output_indexlist=False):
  4. """Find the shortest path through an n-d array from one side to another.
  5. Parameters
  6. ----------
  7. arr : ndarray of float64
  8. reach : int, optional
  9. By default (``reach = 1``), the shortest path can only move
  10. one row up or down for every step it moves forward (i.e.,
  11. the path gradient is limited to 1). `reach` defines the
  12. number of elements that can be skipped along each non-axis
  13. dimension at each step.
  14. axis : int, optional
  15. The axis along which the path must always move forward (default -1)
  16. output_indexlist : bool, optional
  17. See return value `p` for explanation.
  18. Returns
  19. -------
  20. p : iterable of int
  21. For each step along `axis`, the coordinate of the shortest path.
  22. If `output_indexlist` is True, then the path is returned as a list of
  23. n-d tuples that index into `arr`. If False, then the path is returned
  24. as an array listing the coordinates of the path along the non-axis
  25. dimensions for each step along the axis dimension. That is,
  26. `p.shape == (arr.shape[axis], arr.ndim-1)` except that p is squeezed
  27. before returning so if `arr.ndim == 2`, then
  28. `p.shape == (arr.shape[axis],)`
  29. cost : float
  30. Cost of path. This is the absolute sum of all the
  31. differences along the path.
  32. """
  33. # First: calculate the valid moves from any given position. Basically,
  34. # always move +1 along the given axis, and then can move anywhere within
  35. # a grid defined by the reach.
  36. if axis < 0:
  37. axis += arr.ndim
  38. offset_ind_shape = (2 * reach + 1,) * (arr.ndim - 1)
  39. offset_indices = np.indices(offset_ind_shape) - reach
  40. offset_indices = np.insert(offset_indices, axis, np.ones(offset_ind_shape), axis=0)
  41. offset_size = np.multiply.reduce(offset_ind_shape)
  42. offsets = np.reshape(offset_indices, (arr.ndim, offset_size), order='F').T
  43. # Valid starting positions are anywhere on the hyperplane defined by
  44. # position 0 on the given axis. Ending positions are anywhere on the
  45. # hyperplane at position -1 along the same.
  46. non_axis_shape = arr.shape[:axis] + arr.shape[axis + 1 :]
  47. non_axis_indices = np.indices(non_axis_shape)
  48. non_axis_size = np.multiply.reduce(non_axis_shape)
  49. start_indices = np.insert(non_axis_indices, axis, np.zeros(non_axis_shape), axis=0)
  50. starts = np.reshape(start_indices, (arr.ndim, non_axis_size), order='F').T
  51. end_indices = np.insert(
  52. non_axis_indices,
  53. axis,
  54. np.full(non_axis_shape, -1, dtype=non_axis_indices.dtype),
  55. axis=0,
  56. )
  57. ends = np.reshape(end_indices, (arr.ndim, non_axis_size), order='F').T
  58. # Find the minimum-cost path to one of the end-points
  59. m = _spath.MCP_Diff(arr, offsets=offsets)
  60. costs, traceback = m.find_costs(starts, ends, find_all_ends=False)
  61. # Figure out which end-point was found
  62. for end in ends:
  63. cost = costs[tuple(end)]
  64. if cost != np.inf:
  65. break
  66. traceback = m.traceback(end)
  67. if not output_indexlist:
  68. traceback = np.array(traceback)
  69. traceback = np.concatenate(
  70. [traceback[:, :axis], traceback[:, axis + 1 :]], axis=1
  71. )
  72. traceback = np.squeeze(traceback)
  73. return traceback, cost