inpaint.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. import numpy as np
  2. from scipy import sparse
  3. from scipy.sparse.linalg import spsolve
  4. import scipy.ndimage as ndi
  5. from scipy.ndimage import laplace
  6. import skimage
  7. from .._shared import utils
  8. from ..measure import label
  9. from ._inpaint import _build_matrix_inner
  10. def _get_neighborhood(nd_idx, radius, nd_shape):
  11. bounds_lo = np.maximum(nd_idx - radius, 0)
  12. bounds_hi = np.minimum(nd_idx + radius + 1, nd_shape)
  13. return bounds_lo, bounds_hi
  14. def _get_neigh_coef(shape, center, dtype=float):
  15. # Create biharmonic coefficients ndarray
  16. neigh_coef = np.zeros(shape, dtype=dtype)
  17. neigh_coef[center] = 1
  18. neigh_coef = laplace(laplace(neigh_coef))
  19. # extract non-zero locations and values
  20. coef_idx = np.where(neigh_coef)
  21. coef_vals = neigh_coef[coef_idx]
  22. coef_idx = np.stack(coef_idx, axis=0)
  23. return neigh_coef, coef_idx, coef_vals
def _inpaint_biharmonic_single_region(
    image, mask, out, neigh_coef_full, coef_vals, raveled_offsets
):
    """Solve a (sparse) linear system corresponding to biharmonic inpainting.

    This function creates a linear system of the form:

    ``A @ u = b``

    where ``A`` is a sparse matrix, ``b`` is a vector enforcing smoothness and
    boundary constraints and ``u`` is the vector of inpainted values to be
    (uniquely) determined by solving the linear system.

    ``A`` is a sparse matrix of shape (n_mask, n_mask), where ``n_mask`` is
    the number of non-zero values in ``mask`` (i.e. the number of pixels to be
    inpainted). Each row of ``A`` has at most as many non-zero values as the
    biharmonic kernel ``neigh_coef_full``. There is one entry for every kernel
    tap that lands on a masked (unknown) pixel. Taps that land on known pixels
    contribute to ``b`` instead. Biharmonic kernels with reduced extent are
    used at the image borders. ``A`` is the same for all image channels,
    because the same inpainting mask is used for every channel.

    ``u`` is a dense matrix of shape ``(n_mask, n_channels)`` and represents
    the vector of unknown values for each channel.

    ``b`` is a dense matrix of shape ``(n_mask, n_channels)`` and represents
    the desired output of convolving the solution with the biharmonic kernel.
    At mask locations whose kernel does not overlap any known values, ``b`` is
    0; this enforces the biharmonic smoothness constraint in the interior of
    inpainting regions. Near the boundary of a region, where the kernel
    overlaps known values, the entries of ``b`` enforce boundary conditions
    that avoid a discontinuity with the known values.

    Parameters
    ----------
    image : ndarray
        Not read by this function; kept for signature compatibility with the
        callers.
    mask : ndarray of bool
        Spatial mask (no channel axis). True marks pixels to inpaint.
    out : ndarray
        Array with a trailing channel axis whose leading (spatial) shape
        matches ``mask``. Known values are read from it, and inpainted values
        are written into it in place.
    neigh_coef_full : ndarray
        Dense, untruncated biharmonic kernel. Its first-axis length determines
        the kernel radius.
    coef_vals : ndarray
        Non-zero values of ``neigh_coef_full``, forwarded to the Cython
        builder for interior points.
    raveled_offsets : ndarray
        Flat offsets of those non-zero kernel entries within ``out[..., 0]``,
        forwarded to the Cython builder for interior points.

    Returns
    -------
    out : ndarray
        The same ``out`` array, with masked pixels replaced by the solution.
    """
    n_channels = out.shape[-1]
    radius = neigh_coef_full.shape[0] // 2

    # edge_mask is True within `radius` of any array border. Masked points
    # there need a truncated kernel and are handled in Python below. The
    # remaining ("center") masked points use the full kernel in Cython.
    edge_mask = np.ones(mask.shape, dtype=bool)
    edge_mask[(slice(radius, -radius),) * mask.ndim] = 0
    boundary_mask = edge_mask * mask
    center_mask = ~edge_mask * mask

    # Unknowns are ordered "boundary points first, then center points". The
    # same order is used for matrix rows, for the column selection
    # (mask_i), and for the final scatter into `out` (mask_pts).
    boundary_pts = np.where(boundary_mask)
    boundary_i = np.flatnonzero(boundary_mask)
    center_i = np.flatnonzero(center_mask)
    mask_i = np.concatenate((boundary_i, center_i))

    center_pts = np.where(center_mask)
    mask_pts = tuple([np.concatenate((b, c)) for b, c in zip(boundary_pts, center_pts)])

    # Use convolution to predetermine the number of non-zero entries in the
    # sparse system matrix.
    structure = neigh_coef_full != 0
    tmp = ndi.convolve(mask, structure, output=np.uint8, mode='constant')
    nnz_matrix = tmp[mask].sum()

    # Need to estimate the number of non-zero entries in the right hand side
    # vector. A masked point whose whole kernel footprint is masked
    # (tmp == n_struct) has no known neighbors and contributes no RHS entry.
    # The computation below slightly overestimates the true count because of
    # edge effects (the kernel shrinks near the edges, which is not accounted
    # for here). Excess entries are trimmed later.
    n_mask = np.count_nonzero(mask)
    n_struct = np.count_nonzero(structure)
    nnz_rhs_vector_max = n_mask - np.count_nonzero(tmp == n_struct)

    # pre-allocate arrays storing sparse matrix indices and values
    row_idx_known = np.empty(nnz_rhs_vector_max, dtype=np.intp)
    data_known = np.zeros((nnz_rhs_vector_max, n_channels), dtype=out.dtype)
    row_idx_unknown = np.empty(nnz_matrix, dtype=np.intp)
    col_idx_unknown = np.empty(nnz_matrix, dtype=np.intp)
    data_unknown = np.empty(nnz_matrix, dtype=out.dtype)

    # cache the various small, non-square Laplacians used near the boundary,
    # keyed by (truncated window shape, center position within the window)
    coef_cache = {}

    # Iterate over masked points near the boundary
    mask_flat = mask.reshape(-1)
    out_flat = np.ascontiguousarray(out.reshape((-1, n_channels)))
    idx_known = 0
    idx_unknown = 0
    # sentinel: keeps row_start below at 0 when there are no boundary points
    mask_pt_n = -1
    boundary_pts = np.stack(boundary_pts, axis=1)
    for mask_pt_n, nd_idx in enumerate(boundary_pts):
        # Get bounded neighborhood of selected radius
        b_lo, b_hi = _get_neighborhood(nd_idx, radius, mask.shape)

        # Create (truncated) biharmonic coefficients ndarray
        coef_shape = tuple(b_hi - b_lo)
        coef_center = tuple(nd_idx - b_lo)
        coef_idx, coefs = coef_cache.get((coef_shape, coef_center), (None, None))
        if coef_idx is None:
            _, coef_idx, coefs = _get_neigh_coef(
                coef_shape, coef_center, dtype=out.dtype
            )
            coef_cache[(coef_shape, coef_center)] = (coef_idx, coefs)

        # compute corresponding 1d indices into the mask
        coef_idx = coef_idx + b_lo[:, np.newaxis]
        index1d = np.ravel_multi_index(coef_idx, mask.shape)

        # Iterate over masked point's neighborhood: taps on unknown pixels
        # become matrix entries; taps on known pixels are moved to the RHS
        # (hence the subtraction).
        nvals = 0
        for coef, i in zip(coefs, index1d):
            if mask_flat[i]:
                row_idx_unknown[idx_unknown] = mask_pt_n
                col_idx_unknown[idx_unknown] = i
                data_unknown[idx_unknown] = coef
                idx_unknown += 1
            else:
                data_known[idx_known, :] -= coef * out_flat[i, :]
                nvals += 1
        # only rows that actually touched a known pixel get an RHS entry
        if nvals:
            row_idx_known[idx_known] = mask_pt_n
            idx_known += 1

    # Call an efficient Cython-based implementation for all interior points.
    # NOTE(review): the implementation of _build_matrix_inner is not visible
    # here. It presumably continues filling the same output arrays from the
    # given start indices and returns the total number of RHS entries
    # written (used as nnz_rhs below).
    row_start = mask_pt_n + 1
    known_start_idx = idx_known
    unknown_start_idx = idx_unknown
    nnz_rhs = _build_matrix_inner(
        # starting indices
        row_start,
        known_start_idx,
        unknown_start_idx,
        # input arrays
        center_i,
        raveled_offsets,
        coef_vals,
        mask_flat,
        out_flat,
        # output arrays
        row_idx_known,
        data_known,
        row_idx_unknown,
        col_idx_unknown,
        data_unknown,
    )

    # trim RHS vector values and indices to the exact length
    row_idx_known = row_idx_known[:nnz_rhs]
    data_known = data_known[:nnz_rhs, :]

    # Form sparse matrix of unknown values. Column indices are flat spatial
    # indices; out.size is used as a (generous) column count, and the
    # unused columns are discarded by the mask_i column selection below.
    sp_shape = (n_mask, out.size)
    matrix_unknown = sparse.csr_array(
        (data_unknown, (row_idx_unknown, col_idx_unknown)), shape=sp_shape
    )

    # Solve linear system for masked points: keep only the columns of the
    # unknowns, in the same boundary-then-center order as the rows.
    matrix_unknown = matrix_unknown[:, mask_i]

    # dense vectors representing the right hand side for each channel
    rhs = np.zeros((n_mask, n_channels), dtype=out.dtype)
    rhs[row_idx_known, :] = data_known

    # set use_umfpack to False so float32 data is supported
    result = spsolve(matrix_unknown, rhs, use_umfpack=False, permc_spec='MMD_ATA')
    # spsolve returns a 1-D vector for a single right-hand side
    if result.ndim == 1:
        result = result[:, np.newaxis]

    out[mask_pts] = result
    return out
  161. @utils.channel_as_last_axis()
  162. def inpaint_biharmonic(image, mask, *, split_into_regions=False, channel_axis=None):
  163. """Inpaint masked points in image with biharmonic equations.
  164. Parameters
  165. ----------
  166. image : (M[, N[, ..., P]][, C]) ndarray
  167. Input image.
  168. mask : (M[, N[, ..., P]]) ndarray
  169. Array of pixels to be inpainted. Have to be the same shape as one
  170. of the 'image' channels. Unknown pixels have to be represented with 1,
  171. known pixels - with 0.
  172. split_into_regions : bool, optional
  173. If True, inpainting is performed on a region-by-region basis. This is
  174. likely to be slower, but will have reduced memory requirements.
  175. channel_axis : int or None, optional
  176. If None, the image is assumed to be a grayscale (single channel) image.
  177. Otherwise, this parameter indicates which axis of the array corresponds
  178. to channels.
  179. .. versionadded:: 0.19
  180. ``channel_axis`` was added in 0.19.
  181. Returns
  182. -------
  183. out : (M[, N[, ..., P]][, C]) ndarray
  184. Input image with masked pixels inpainted.
  185. References
  186. ----------
  187. .. [1] S.B.Damelin and N.S.Hoang. "On Surface Completion and Image
  188. Inpainting by Biharmonic Functions: Numerical Aspects",
  189. International Journal of Mathematics and Mathematical Sciences,
  190. Vol. 2018, Article ID 3950312
  191. :DOI:`10.1155/2018/3950312`
  192. .. [2] C. K. Chui and H. N. Mhaskar, MRA Contextual-Recovery Extension of
  193. Smooth Functions on Manifolds, Appl. and Comp. Harmonic Anal.,
  194. 28 (2010), 104-113,
  195. :DOI:`10.1016/j.acha.2009.04.004`
  196. Examples
  197. --------
  198. >>> img = np.tile(np.square(np.linspace(0, 1, 5)), (5, 1))
  199. >>> mask = np.zeros_like(img)
  200. >>> mask[2, 2:] = 1
  201. >>> mask[1, 3:] = 1
  202. >>> mask[0, 4:] = 1
  203. >>> out = inpaint_biharmonic(img, mask)
  204. """
  205. if image.ndim < 1:
  206. raise ValueError('Input array has to be at least 1D')
  207. multichannel = channel_axis is not None
  208. img_baseshape = image.shape[:-1] if multichannel else image.shape
  209. if img_baseshape != mask.shape:
  210. raise ValueError('Input arrays have to be the same shape')
  211. if np.ma.isMaskedArray(image):
  212. raise TypeError('Masked arrays are not supported')
  213. image = skimage.img_as_float(image)
  214. # float16->float32 and float128->float64
  215. float_dtype = utils._supported_float_type(image.dtype)
  216. image = image.astype(float_dtype, copy=False)
  217. mask = mask.astype(bool, copy=False)
  218. if not multichannel:
  219. image = image[..., np.newaxis]
  220. out = np.copy(image, order='C')
  221. # Create biharmonic coefficients ndarray
  222. radius = 2
  223. coef_shape = (2 * radius + 1,) * mask.ndim
  224. coef_center = (radius,) * mask.ndim
  225. neigh_coef_full, coef_idx, coef_vals = _get_neigh_coef(
  226. coef_shape, coef_center, dtype=out.dtype
  227. )
  228. # stride for the last spatial dimension
  229. channel_stride_bytes = out.strides[-2]
  230. # offsets to all neighboring non-zero elements in the footprint
  231. offsets = coef_idx - radius
  232. # determine per-channel intensity limits
  233. known_points = image[~mask]
  234. limits = (known_points.min(axis=0), known_points.max(axis=0))
  235. if split_into_regions:
  236. # Split inpainting mask into independent regions
  237. kernel = ndi.generate_binary_structure(mask.ndim, 1)
  238. mask_dilated = ndi.binary_dilation(mask, structure=kernel)
  239. mask_labeled = label(mask_dilated)
  240. mask_labeled *= mask
  241. bbox_slices = ndi.find_objects(mask_labeled)
  242. for idx_region, bb_slice in enumerate(bbox_slices, 1):
  243. # expand object bounding boxes by the biharmonic kernel radius
  244. roi_sl = tuple(
  245. slice(max(sl.start - radius, 0), min(sl.stop + radius, size))
  246. for sl, size in zip(bb_slice, mask_labeled.shape)
  247. )
  248. # extract only the region surrounding the label of interest
  249. mask_region = mask_labeled[roi_sl] == idx_region
  250. # add slice for axes
  251. roi_sl += (slice(None),)
  252. # copy for contiguity and to account for possible ROI overlap
  253. otmp = out[roi_sl].copy()
  254. # compute raveled offsets for the ROI
  255. ostrides = np.array(
  256. [s // channel_stride_bytes for s in otmp[..., 0].strides]
  257. )
  258. raveled_offsets = np.sum(offsets * ostrides[..., np.newaxis], axis=0)
  259. _inpaint_biharmonic_single_region(
  260. image[roi_sl],
  261. mask_region,
  262. otmp,
  263. neigh_coef_full,
  264. coef_vals,
  265. raveled_offsets,
  266. )
  267. # assign output to the
  268. out[roi_sl] = otmp
  269. else:
  270. # compute raveled offsets for output image
  271. ostrides = np.array([s // channel_stride_bytes for s in out[..., 0].strides])
  272. raveled_offsets = np.sum(offsets * ostrides[..., np.newaxis], axis=0)
  273. _inpaint_biharmonic_single_region(
  274. image, mask, out, neigh_coef_full, coef_vals, raveled_offsets
  275. )
  276. # Handle enormous values on a per-channel basis
  277. np.clip(out, a_min=limits[0], a_max=limits[1], out=out)
  278. if not multichannel:
  279. out = out[..., 0]
  280. return out