refinement.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393
  1. '''
  2. Part of the implementation is borrowed and modified from LaMa, publicly available at
  3. https://github.com/saic-mdal/lama
  4. '''
  5. import cv2
  6. import numpy as np
  7. import torch
  8. import torch.nn as nn
  9. from kornia.filters import gaussian_blur2d
  10. from kornia.geometry.transform import resize
  11. from kornia.morphology import erosion
  12. from torch.nn import functional as F
  13. from torch.optim import SGD, Adam
  14. from tqdm import tqdm
  15. from .modules.ffc import FFCResnetBlock
  16. def move_to_device(obj, device):
  17. if isinstance(obj, nn.Module):
  18. return obj.to(device)
  19. if torch.is_tensor(obj):
  20. return obj.to(device)
  21. if isinstance(obj, (tuple, list)):
  22. return [move_to_device(el, device) for el in obj]
  23. if isinstance(obj, dict):
  24. return {name: move_to_device(val, device) for name, val in obj.items()}
  25. raise ValueError(f'Unexpected type {type(obj)}')
  26. def ceil_modulo(x, mod):
  27. if x % mod == 0:
  28. return x
  29. return (x // mod + 1) * mod
  30. def pad_tensor_to_modulo(img, mod):
  31. batch_size, channels, height, width = img.shape
  32. out_height = ceil_modulo(height, mod)
  33. out_width = ceil_modulo(width, mod)
  34. return F.pad(
  35. img,
  36. pad=(0, out_width - width, 0, out_height - height),
  37. mode='reflect')
  38. def _pyrdown(im: torch.Tensor, downsize: tuple = None):
  39. """downscale the image"""
  40. if downsize is None:
  41. downsize = (im.shape[2] // 2, im.shape[3] // 2)
  42. assert im.shape[
  43. 1] == 3, 'Expected shape for the input to be (n,3,height,width)'
  44. im = gaussian_blur2d(im, kernel_size=(5, 5), sigma=(1.0, 1.0))
  45. im = F.interpolate(im, size=downsize, mode='bilinear', align_corners=False)
  46. return im
  47. def _pyrdown_mask(mask: torch.Tensor,
  48. downsize: tuple = None,
  49. eps: float = 1e-8,
  50. blur_mask: bool = True,
  51. round_up: bool = True):
  52. """downscale the mask tensor
  53. Parameters
  54. ----------
  55. mask : torch.Tensor
  56. mask of size (B, 1, H, W)
  57. downsize : tuple, optional
  58. size to downscale to. If None, image is downscaled to half, by default None
  59. eps : float, optional
  60. threshold value for binarizing the mask, by default 1e-8
  61. blur_mask : bool, optional
  62. if True, apply gaussian filter before downscaling, by default True
  63. round_up : bool, optional
  64. if True, values above eps are marked 1, else, values below 1-eps are marked 0, by default True
  65. Returns
  66. -------
  67. torch.Tensor
  68. downscaled mask
  69. """
  70. if downsize is None:
  71. downsize = (mask.shape[2] // 2, mask.shape[3] // 2)
  72. assert mask.shape[
  73. 1] == 1, 'Expected shape for the input to be (n,1,height,width)'
  74. if blur_mask is True:
  75. mask = gaussian_blur2d(mask, kernel_size=(5, 5), sigma=(1.0, 1.0))
  76. mask = F.interpolate(
  77. mask, size=downsize, mode='bilinear', align_corners=False)
  78. else:
  79. mask = F.interpolate(
  80. mask, size=downsize, mode='bilinear', align_corners=False)
  81. if round_up:
  82. mask[mask >= eps] = 1
  83. mask[mask < eps] = 0
  84. else:
  85. mask[mask >= 1.0 - eps] = 1
  86. mask[mask < 1.0 - eps] = 0
  87. return mask
  88. def _erode_mask(mask: torch.Tensor,
  89. ekernel: torch.Tensor = None,
  90. eps: float = 1e-8):
  91. """erode the mask, and set gray pixels to 0"""
  92. if ekernel is not None:
  93. mask = erosion(mask, ekernel)
  94. mask[mask >= 1.0 - eps] = 1
  95. mask[mask < 1.0 - eps] = 0
  96. return mask
  97. def _l1_loss(pred: torch.Tensor,
  98. pred_downscaled: torch.Tensor,
  99. ref: torch.Tensor,
  100. mask: torch.Tensor,
  101. mask_downscaled: torch.Tensor,
  102. image: torch.Tensor,
  103. on_pred: bool = True):
  104. """l1 loss on src pixels, and downscaled predictions if on_pred=True"""
  105. loss = torch.mean(torch.abs(pred[mask < 1e-8] - image[mask < 1e-8]))
  106. if on_pred:
  107. loss += torch.mean(
  108. torch.abs(pred_downscaled[mask_downscaled >= 1e-8]
  109. - ref[mask_downscaled >= 1e-8]))
  110. return loss
  111. def _infer(image: torch.Tensor,
  112. mask: torch.Tensor,
  113. forward_front: nn.Module,
  114. forward_rears: nn.Module,
  115. ref_lower_res: torch.Tensor,
  116. orig_shape: tuple,
  117. devices: list,
  118. scale_ind: int,
  119. n_iters: int = 15,
  120. lr: float = 0.002):
  121. """Performs inference with refinement at a given scale.
  122. Parameters
  123. ----------
  124. image : torch.Tensor
  125. input image to be inpainted, of size (1,3,H,W)
  126. mask : torch.Tensor
  127. input inpainting mask, of size (1,1,H,W)
  128. forward_front : nn.Module
  129. the front part of the inpainting network
  130. forward_rears : nn.Module
  131. the rear part of the inpainting network
  132. ref_lower_res : torch.Tensor
  133. the inpainting at previous scale, used as reference image
  134. orig_shape : tuple
  135. shape of the original input image before padding
  136. devices : list
  137. list of available devices
  138. scale_ind : int
  139. the scale index
  140. n_iters : int, optional
  141. number of iterations of refinement, by default 15
  142. lr : float, optional
  143. learning rate, by default 0.002
  144. Returns
  145. -------
  146. torch.Tensor
  147. inpainted image
  148. """
  149. masked_image = image * (1 - mask)
  150. masked_image = torch.cat([masked_image, mask], dim=1)
  151. mask = mask.repeat(1, 3, 1, 1)
  152. if ref_lower_res is not None:
  153. ref_lower_res = ref_lower_res.detach()
  154. with torch.no_grad():
  155. z1, z2 = forward_front(masked_image)
  156. # Inference
  157. mask = mask.to(devices[-1])
  158. ekernel = torch.from_numpy(
  159. cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
  160. (15, 15)).astype(bool)).float()
  161. ekernel = ekernel.to(devices[-1])
  162. image = image.to(devices[-1])
  163. z1, z2 = z1.detach().to(devices[0]), z2.detach().to(devices[0])
  164. z1.requires_grad, z2.requires_grad = True, True
  165. optimizer = Adam([z1, z2], lr=lr)
  166. pbar = tqdm(range(n_iters), leave=False)
  167. for idi in pbar:
  168. optimizer.zero_grad()
  169. input_feat = (z1, z2)
  170. for idd, forward_rear in enumerate(forward_rears):
  171. output_feat = forward_rear(input_feat)
  172. if idd < len(devices) - 1:
  173. midz1, midz2 = output_feat
  174. midz1, midz2 = midz1.to(devices[idd + 1]), midz2.to(
  175. devices[idd + 1])
  176. input_feat = (midz1, midz2)
  177. else:
  178. pred = output_feat
  179. if ref_lower_res is None:
  180. break
  181. losses = {}
  182. # scaled loss with downsampler
  183. pred_downscaled = _pyrdown(pred[:, :, :orig_shape[0], :orig_shape[1]])
  184. mask_downscaled = _pyrdown_mask(
  185. mask[:, :1, :orig_shape[0], :orig_shape[1]],
  186. blur_mask=False,
  187. round_up=False)
  188. mask_downscaled = _erode_mask(mask_downscaled, ekernel=ekernel)
  189. mask_downscaled = mask_downscaled.repeat(1, 3, 1, 1)
  190. losses['ms_l1'] = _l1_loss(
  191. pred,
  192. pred_downscaled,
  193. ref_lower_res,
  194. mask,
  195. mask_downscaled,
  196. image,
  197. on_pred=True)
  198. loss = sum(losses.values())
  199. pbar.set_description(
  200. 'Refining scale {} using scale {} ...current loss: {:.4f}'.format(
  201. scale_ind + 1, scale_ind, loss.item()))
  202. if idi < n_iters - 1:
  203. loss.backward()
  204. optimizer.step()
  205. del pred_downscaled
  206. del loss
  207. del pred
  208. # "pred" is the prediction after Plug-n-Play module
  209. inpainted = mask * pred + (1 - mask) * image
  210. inpainted = inpainted.detach().cpu()
  211. return inpainted
  212. def _get_image_mask_pyramid(batch: dict, min_side: int, max_scales: int,
  213. px_budget: int):
  214. """Build the image mask pyramid
  215. Parameters
  216. ----------
  217. batch : dict
  218. batch containing image, mask, etc
  219. min_side : int
  220. minimum side length to limit the number of scales of the pyramid
  221. max_scales : int
  222. maximum number of scales allowed
  223. px_budget : int
  224. the product H*W cannot exceed this budget, because of resource constraints
  225. Returns
  226. -------
  227. tuple
  228. image-mask pyramid in the form of list of images and list of masks
  229. """
  230. assert batch['image'].shape[
  231. 0] == 1, 'refiner works on only batches of size 1!'
  232. h, w = batch['unpad_to_size']
  233. h, w = h[0].item(), w[0].item()
  234. image = batch['image'][..., :h, :w]
  235. mask = batch['mask'][..., :h, :w]
  236. if h * w > px_budget:
  237. # resize
  238. ratio = np.sqrt(px_budget / float(h * w))
  239. h_orig, w_orig = h, w
  240. h, w = int(h * ratio), int(w * ratio)
  241. print(
  242. f'Original image too large for refinement! Resizing {(h_orig,w_orig)} to {(h,w)}...'
  243. )
  244. image = resize(
  245. image, (h, w), interpolation='bilinear', align_corners=False)
  246. mask = resize(
  247. mask, (h, w), interpolation='bilinear', align_corners=False)
  248. mask[mask > 1e-8] = 1
  249. breadth = min(h, w)
  250. n_scales = min(1 + int(round(max(0, np.log2(breadth / min_side)))),
  251. max_scales)
  252. ls_images = []
  253. ls_masks = []
  254. ls_images.append(image)
  255. ls_masks.append(mask)
  256. for _ in range(n_scales - 1):
  257. image_p = _pyrdown(ls_images[-1])
  258. mask_p = _pyrdown_mask(ls_masks[-1])
  259. ls_images.append(image_p)
  260. ls_masks.append(mask_p)
  261. # reverse the lists because we want the lowest resolution image as index 0
  262. return ls_images[::-1], ls_masks[::-1]
  263. def refine_predict(batch: dict, inpainter: nn.Module, gpu_ids: str,
  264. modulo: int, n_iters: int, lr: float, min_side: int,
  265. max_scales: int, px_budget: int):
  266. """Refines the inpainting of the network
  267. Parameters
  268. ----------
  269. batch : dict
  270. image-mask batch, currently we assume the batchsize to be 1
  271. inpainter : nn.Module
  272. the inpainting neural network
  273. gpu_ids : str
  274. the GPU ids of the machine to use. If only single GPU, use: "0,"
  275. modulo : int
  276. pad the image to ensure dimension % modulo == 0
  277. n_iters : int
  278. number of iterations of refinement for each scale
  279. lr : float
  280. learning rate
  281. min_side : int
  282. all sides of image on all scales should be >= min_side / sqrt(2)
  283. max_scales : int
  284. max number of downscaling scales for the image-mask pyramid
  285. px_budget : int
  286. pixels budget. Any image will be resized to satisfy height*width <= px_budget
  287. Returns
  288. -------
  289. torch.Tensor
  290. inpainted image of size (1,3,H,W)
  291. """
  292. inpainter = inpainter.model
  293. assert not inpainter.training
  294. assert not inpainter.add_noise_kwargs
  295. assert inpainter.concat_mask
  296. gpu_ids = [
  297. f'cuda:{gpuid}' for gpuid in gpu_ids.replace(' ', '').split(',')
  298. if gpuid.isdigit()
  299. ]
  300. n_resnet_blocks = 0
  301. first_resblock_ind = 0
  302. found_first_resblock = False
  303. for idl in range(len(inpainter.generator.model)):
  304. if isinstance(inpainter.generator.model[idl], FFCResnetBlock):
  305. n_resnet_blocks += 1
  306. found_first_resblock = True
  307. elif not found_first_resblock:
  308. first_resblock_ind += 1
  309. resblocks_per_gpu = n_resnet_blocks // len(gpu_ids)
  310. devices = [torch.device(gpu_id) for gpu_id in gpu_ids]
  311. # split the model into front, and rear parts
  312. forward_front = inpainter.generator.model[0:first_resblock_ind]
  313. forward_front.to(devices[0])
  314. forward_rears = []
  315. for idd in range(len(gpu_ids)):
  316. if idd < len(gpu_ids) - 1:
  317. forward_rears.append(
  318. inpainter.generator.model[first_resblock_ind
  319. + resblocks_per_gpu
  320. * (idd):first_resblock_ind
  321. + resblocks_per_gpu * (idd + 1)])
  322. else:
  323. forward_rears.append(
  324. inpainter.generator.model[first_resblock_ind
  325. + resblocks_per_gpu * (idd):])
  326. forward_rears[idd].to(devices[idd])
  327. ls_images, ls_masks = _get_image_mask_pyramid(batch, min_side, max_scales,
  328. px_budget)
  329. image_inpainted = None
  330. for ids, (image, mask) in enumerate(zip(ls_images, ls_masks)):
  331. orig_shape = image.shape[2:]
  332. image = pad_tensor_to_modulo(image, modulo)
  333. mask = pad_tensor_to_modulo(mask, modulo)
  334. mask[mask >= 1e-8] = 1.0
  335. mask[mask < 1e-8] = 0.0
  336. image, mask = move_to_device(image, devices[0]), move_to_device(
  337. mask, devices[0])
  338. if image_inpainted is not None:
  339. image_inpainted = move_to_device(image_inpainted, devices[-1])
  340. image_inpainted = _infer(image, mask, forward_front, forward_rears,
  341. image_inpainted, orig_shape, devices, ids,
  342. n_iters, lr)
  343. image_inpainted = image_inpainted[:, :, :orig_shape[0], :orig_shape[1]]
  344. # detach everything to save resources
  345. image = image.detach().cpu()
  346. mask = mask.detach().cpu()
  347. return image_inpainted