_optical_flow_utils.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. """Common tools to optical flow algorithms."""
  2. import numpy as np
  3. from scipy import ndimage as ndi
  4. from ..transform import pyramid_reduce
  5. from ..util.dtype import _convert
  6. def _get_warp_points(grid, flow):
  7. """Compute warp point coordinates.
  8. Parameters
  9. ----------
  10. grid : iterable
  11. The sparse grid to be warped (obtained using
  12. ``np.meshgrid(..., sparse=True)).``)
  13. flow : ndarray
  14. The warping motion field.
  15. Returns
  16. -------
  17. out : ndarray
  18. The warp point coordinates.
  19. """
  20. out = flow.copy()
  21. for idx, g in enumerate(grid):
  22. out[idx, ...] += g
  23. return out
  24. def _resize_flow(flow, shape):
  25. """Rescale the values of the vector field (u, v) to the desired shape.
  26. The values of the output vector field are scaled to the new
  27. resolution.
  28. Parameters
  29. ----------
  30. flow : ndarray
  31. The motion field to be processed.
  32. shape : iterable
  33. Couple of integers representing the output shape.
  34. Returns
  35. -------
  36. rflow : ndarray
  37. The resized and rescaled motion field.
  38. """
  39. scale = [n / o for n, o in zip(shape, flow.shape[1:])]
  40. scale_factor = np.array(scale, dtype=flow.dtype)
  41. for _ in shape:
  42. scale_factor = scale_factor[..., np.newaxis]
  43. rflow = scale_factor * ndi.zoom(
  44. flow, [1] + scale, order=0, mode='nearest', prefilter=False
  45. )
  46. return rflow
  47. def _get_pyramid(I, downscale=2.0, nlevel=10, min_size=16):
  48. """Construct image pyramid.
  49. Parameters
  50. ----------
  51. I : ndarray
  52. The image to be preprocessed (Grayscale or RGB).
  53. downscale : float
  54. The pyramid downscale factor.
  55. nlevel : int
  56. The maximum number of pyramid levels.
  57. min_size : int
  58. The minimum size for any dimension of the pyramid levels.
  59. Returns
  60. -------
  61. pyramid : list[ndarray]
  62. The coarse to fine images pyramid.
  63. """
  64. pyramid = [I]
  65. size = min(I.shape)
  66. count = 1
  67. while (count < nlevel) and (size > downscale * min_size):
  68. J = pyramid_reduce(pyramid[-1], downscale, channel_axis=None)
  69. pyramid.append(J)
  70. size = min(J.shape)
  71. count += 1
  72. return pyramid[::-1]
  73. def _coarse_to_fine(
  74. I0, I1, solver, downscale=2, nlevel=10, min_size=16, dtype=np.float32
  75. ):
  76. """Generic coarse to fine solver.
  77. Parameters
  78. ----------
  79. I0 : ndarray
  80. The first grayscale image of the sequence.
  81. I1 : ndarray
  82. The second grayscale image of the sequence.
  83. solver : callable
  84. The solver applied at each pyramid level.
  85. downscale : float
  86. The pyramid downscale factor.
  87. nlevel : int
  88. The maximum number of pyramid levels.
  89. min_size : int
  90. The minimum size for any dimension of the pyramid levels.
  91. dtype : dtype
  92. Output data type.
  93. Returns
  94. -------
  95. flow : ndarray
  96. The estimated optical flow components for each axis.
  97. """
  98. if I0.shape != I1.shape:
  99. raise ValueError("Input images should have the same shape")
  100. if np.dtype(dtype).char not in 'efdg':
  101. raise ValueError("Only floating point data type are valid" " for optical flow")
  102. pyramid = list(
  103. zip(
  104. _get_pyramid(_convert(I0, dtype), downscale, nlevel, min_size),
  105. _get_pyramid(_convert(I1, dtype), downscale, nlevel, min_size),
  106. )
  107. )
  108. # Initialization to 0 at coarsest level.
  109. flow = np.zeros((pyramid[0][0].ndim,) + pyramid[0][0].shape, dtype=dtype)
  110. flow = solver(pyramid[0][0], pyramid[0][1], flow)
  111. for J0, J1 in pyramid[1:]:
  112. flow = solver(J0, J1, _resize_flow(flow, J0.shape))
  113. return flow