'''
Part of the implementation is borrowed and modified from LaMa, publicly available at
https://github.com/saic-mdal/lama
'''
import cv2
import numpy as np
import torch
import torch.nn as nn
from kornia.filters import gaussian_blur2d
from kornia.geometry.transform import resize
from kornia.morphology import erosion
from torch.nn import functional as F
from torch.optim import SGD, Adam
from tqdm import tqdm

from .modules.ffc import FFCResnetBlock


def move_to_device(obj, device):
    """Recursively move modules/tensors (possibly nested) to ``device``.

    Tuples and lists are both returned as lists; dicts keep their keys.
    Raises ``ValueError`` for any other type.
    """
    if isinstance(obj, nn.Module):
        return obj.to(device)
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, (tuple, list)):
        return [move_to_device(el, device) for el in obj]
    if isinstance(obj, dict):
        return {name: move_to_device(val, device) for name, val in obj.items()}
    raise ValueError(f'Unexpected type {type(obj)}')


def ceil_modulo(x, mod):
    """Return the smallest multiple of ``mod`` that is >= ``x``."""
    if x % mod == 0:
        return x
    return (x // mod + 1) * mod


def pad_tensor_to_modulo(img, mod):
    """Reflect-pad a 4-D tensor on the bottom/right so that its last two
    dimensions become multiples of ``mod``."""
    _, _, height, width = img.shape
    out_height = ceil_modulo(height, mod)
    out_width = ceil_modulo(width, mod)
    return F.pad(
        img,
        pad=(0, out_width - width, 0, out_height - height),
        mode='reflect')


def _pyrdown(im: torch.Tensor, downsize: tuple = None):
    """downscale the image

    The image is blurred with a 5x5 gaussian (sigma 1.0) and then resized
    bilinearly to ``downsize`` (half of the input size when None).
    """
    if downsize is None:
        downsize = (im.shape[2] // 2, im.shape[3] // 2)
    assert im.shape[
        1] == 3, 'Expected shape for the input to be (n,3,height,width)'
    im = gaussian_blur2d(im, kernel_size=(5, 5), sigma=(1.0, 1.0))
    im = F.interpolate(im, size=downsize, mode='bilinear', align_corners=False)
    return im


def _pyrdown_mask(mask: torch.Tensor,
                  downsize: tuple = None,
                  eps: float = 1e-8,
                  blur_mask: bool = True,
                  round_up: bool = True):
    """downscale the mask tensor

    Parameters
    ----------
    mask : torch.Tensor
        mask of size (B, 1, H, W)
    downsize : tuple, optional
        size to downscale to. If None, image is downscaled to half, by default None
    eps : float, optional
        threshold value for binarizing the mask, by default 1e-8
    blur_mask : bool, optional
        if True, apply gaussian filter before downscaling, by default True
    round_up : bool, optional
        if True, values above eps are marked 1, else, values below 1-eps are marked 0, by default True

    Returns
    -------
    torch.Tensor
        downscaled mask
    """
    if downsize is None:
        downsize = (mask.shape[2] // 2, mask.shape[3] // 2)
    assert mask.shape[
        1] == 1, 'Expected shape for the input to be (n,1,height,width)'
    if blur_mask:
        mask = gaussian_blur2d(mask, kernel_size=(5, 5), sigma=(1.0, 1.0))
    # The resize is the same whether or not the mask was blurred first.
    mask = F.interpolate(
        mask, size=downsize, mode='bilinear', align_corners=False)
    # Binarize in place; ``mask`` is the fresh tensor produced by interpolate.
    threshold = eps if round_up else 1.0 - eps
    mask[mask >= threshold] = 1
    mask[mask < threshold] = 0
    return mask


def _erode_mask(mask: torch.Tensor,
                ekernel: torch.Tensor = None,
                eps: float = 1e-8):
    """erode the mask, and set gray pixels to 0

    When ``ekernel`` is None no erosion is applied and the thresholding is
    done in place on the tensor that was passed in.
    """
    if ekernel is not None:
        mask = erosion(mask, ekernel)
    mask[mask >= 1.0 - eps] = 1
    mask[mask < 1.0 - eps] = 0
    return mask


def _l1_loss(pred: torch.Tensor,
             pred_downscaled: torch.Tensor,
             ref: torch.Tensor,
             mask: torch.Tensor,
             mask_downscaled: torch.Tensor,
             image: torch.Tensor,
             on_pred: bool = True):
    """l1 loss on src pixels, and downscaled predictions if on_pred=True

    The first term compares ``pred`` with ``image`` where ``mask`` is
    (numerically) zero; the optional second term compares ``pred_downscaled``
    with ``ref`` where ``mask_downscaled`` is non-zero.
    """
    known = mask < 1e-8
    loss = torch.mean(torch.abs(pred[known] - image[known]))
    if on_pred:
        hole = mask_downscaled >= 1e-8
        loss += torch.mean(torch.abs(pred_downscaled[hole] - ref[hole]))
    return loss


def _infer(image: torch.Tensor,
           mask: torch.Tensor,
           forward_front: nn.Module,
           forward_rears: nn.Module,
           ref_lower_res: torch.Tensor,
           orig_shape: tuple,
           devices: list,
           scale_ind: int,
           n_iters: int = 15,
           lr: float = 0.002):
    """Performs inference with refinement at a given scale.

    Parameters
    ----------
    image : torch.Tensor
        input image to be inpainted, of size (1,3,H,W)
    mask : torch.Tensor
        input inpainting mask, of size (1,1,H,W)
    forward_front : nn.Module
        the front part of the inpainting network
    forward_rears : nn.Module
        the rear part of the inpainting network
    ref_lower_res : torch.Tensor
        the inpainting at previous scale, used as reference image
    orig_shape : tuple
        shape of the original input image before padding
    devices : list
        list of available devices
    scale_ind : int
        the scale index
    n_iters : int, optional
        number of iterations of refinement, by default 15. Values below 1
        still run a single forward pass (without any optimisation step).
    lr : float, optional
        learning rate, by default 0.002

    Returns
    -------
    torch.Tensor
        inpainted image
    """
    masked_image = image * (1 - mask)
    masked_image = torch.cat([masked_image, mask], dim=1)

    mask = mask.repeat(1, 3, 1, 1)
    if ref_lower_res is not None:
        ref_lower_res = ref_lower_res.detach()
    with torch.no_grad():
        z1, z2 = forward_front(masked_image)
    # Inference
    mask = mask.to(devices[-1])
    ekernel = torch.from_numpy(
        cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
                                  (15, 15)).astype(bool)).float()
    ekernel = ekernel.to(devices[-1])
    image = image.to(devices[-1])
    # The intermediate features are the variables being optimised.
    z1, z2 = z1.detach().to(devices[0]), z2.detach().to(devices[0])
    z1.requires_grad, z2.requires_grad = True, True

    optimizer = Adam([z1, z2], lr=lr)

    # Always run at least one forward pass so that ``pred`` is defined below,
    # even when the caller asks for zero refinement iterations.
    n_passes = max(n_iters, 1)
    pbar = tqdm(range(n_passes), leave=False)
    for idi in pbar:
        optimizer.zero_grad()
        input_feat = (z1, z2)
        for idd, forward_rear in enumerate(forward_rears):
            output_feat = forward_rear(input_feat)
            if idd < len(devices) - 1:
                # hand the features over to the device of the next chunk
                midz1, midz2 = output_feat
                midz1, midz2 = midz1.to(devices[idd + 1]), midz2.to(
                    devices[idd + 1])
                input_feat = (midz1, midz2)
            else:
                pred = output_feat

        # Without a lower-resolution reference there is nothing to refine
        # against, so the plain prediction is used.
        if ref_lower_res is None:
            break
        losses = {}
        # scaled loss with downsampler
        pred_downscaled = _pyrdown(pred[:, :, :orig_shape[0], :orig_shape[1]])
        mask_downscaled = _pyrdown_mask(
            mask[:, :1, :orig_shape[0], :orig_shape[1]],
            blur_mask=False,
            round_up=False)
        mask_downscaled = _erode_mask(mask_downscaled, ekernel=ekernel)
        mask_downscaled = mask_downscaled.repeat(1, 3, 1, 1)
        losses['ms_l1'] = _l1_loss(
            pred,
            pred_downscaled,
            ref_lower_res,
            mask,
            mask_downscaled,
            image,
            on_pred=True)

        loss = sum(losses.values())
        pbar.set_description(
            'Refining scale {} using scale {} ...current loss: {:.4f}'.format(
                scale_ind + 1, scale_ind, loss.item()))
        # The last pass only produces the final prediction; it must keep
        # ``pred`` alive, so no update and no cleanup happen there.
        if idi < n_iters - 1:
            loss.backward()
            optimizer.step()
            del pred_downscaled
            del loss
            del pred
    # "pred" is the prediction after Plug-n-Play module
    inpainted = mask * pred + (1 - mask) * image
    inpainted = inpainted.detach().cpu()
    return inpainted


def _get_image_mask_pyramid(batch: dict, min_side: int, max_scales: int,
                            px_budget: int):
    """Build the image mask pyramid

    Parameters
    ----------
    batch : dict
        batch containing image, mask, etc
    min_side : int
        minimum side length to limit the number of scales of the pyramid
    max_scales : int
        maximum number of scales allowed
    px_budget : int
        the product H*W cannot exceed this budget, because of resource constraints

    Returns
    -------
    tuple
        image-mask pyramid in the form of list of images and list of masks
    """

    assert batch['image'].shape[
        0] == 1, 'refiner works on only batches of size 1!'

    h, w = batch['unpad_to_size']
    h, w = h[0].item(), w[0].item()

    image = batch['image'][..., :h, :w]
    mask = batch['mask'][..., :h, :w]
    if h * w > px_budget:
        # resize
        ratio = np.sqrt(px_budget / float(h * w))
        h_orig, w_orig = h, w
        h, w = int(h * ratio), int(w * ratio)
        print(
            f'Original image too large for refinement! Resizing {(h_orig,w_orig)} to {(h,w)}...'
        )
        image = resize(
            image, (h, w), interpolation='bilinear', align_corners=False)
        mask = resize(
            mask, (h, w), interpolation='bilinear', align_corners=False)
        mask[mask > 1e-8] = 1
    breadth = min(h, w)
    n_scales = min(1 + int(round(max(0, np.log2(breadth / min_side)))),
                   max_scales)
    ls_images = [image]
    ls_masks = [mask]

    for _ in range(n_scales - 1):
        ls_images.append(_pyrdown(ls_images[-1]))
        ls_masks.append(_pyrdown_mask(ls_masks[-1]))
    # reverse the lists because we want the lowest resolution image as index 0
    return ls_images[::-1], ls_masks[::-1]


def refine_predict(batch: dict, inpainter: nn.Module, gpu_ids: str,
                   modulo: int, n_iters: int, lr: float, min_side: int,
                   max_scales: int, px_budget: int):
    """Refines the inpainting of the network

    Parameters
    ----------
    batch : dict
        image-mask batch, currently we assume the batchsize to be 1
    inpainter : nn.Module
        the inpainting neural network
    gpu_ids : str
        the GPU ids of the machine to use. If only single GPU, use: "0,"
    modulo : int
        pad the image to ensure dimension % modulo == 0
    n_iters : int
        number of iterations of refinement for each scale
    lr : float
        learning rate
    min_side : int
        all sides of image on all scales should be >= min_side / sqrt(2)
    max_scales : int
        max number of downscaling scales for the image-mask pyramid
    px_budget : int
        pixels budget. Any image will be resized to satisfy height*width <= px_budget

    Returns
    -------
    torch.Tensor
        inpainted image of size (1,3,H,W)

    Raises
    ------
    ValueError
        if ``gpu_ids`` does not contain any numeric GPU id
    """
    inpainter = inpainter.model
    assert not inpainter.training
    assert not inpainter.add_noise_kwargs
    assert inpainter.concat_mask

    device_names = [
        f'cuda:{gpuid}' for gpuid in gpu_ids.replace(' ', '').split(',')
        if gpuid.isdigit()
    ]
    if not device_names:
        # Would otherwise surface as a ZeroDivisionError further down.
        raise ValueError(
            f'gpu_ids must contain at least one numeric GPU id, got {gpu_ids!r}'
        )
    devices = [torch.device(name) for name in device_names]

    generator_layers = inpainter.generator.model
    # Count the FFC resnet blocks and locate the first one.
    n_resnet_blocks = 0
    first_resblock_ind = 0
    found_first_resblock = False
    for layer in generator_layers:
        if isinstance(layer, FFCResnetBlock):
            n_resnet_blocks += 1
            found_first_resblock = True
        elif not found_first_resblock:
            first_resblock_ind += 1
    resblocks_per_gpu = n_resnet_blocks // len(devices)

    # split the model into front, and rear parts
    forward_front = generator_layers[0:first_resblock_ind]
    forward_front.to(devices[0])
    forward_rears = []
    last_device_ind = len(devices) - 1
    for idd, device in enumerate(devices):
        start = first_resblock_ind + resblocks_per_gpu * idd
        # The last device takes every remaining layer of the generator.
        end = None if idd == last_device_ind else start + resblocks_per_gpu
        rear = generator_layers[start:end]
        rear.to(device)
        forward_rears.append(rear)

    ls_images, ls_masks = _get_image_mask_pyramid(batch, min_side, max_scales,
                                                  px_budget)
    image_inpainted = None

    for ids, (image, mask) in enumerate(zip(ls_images, ls_masks)):
        orig_shape = image.shape[2:]
        image = pad_tensor_to_modulo(image, modulo)
        mask = pad_tensor_to_modulo(mask, modulo)
        mask[mask >= 1e-8] = 1.0
        mask[mask < 1e-8] = 0.0
        image, mask = move_to_device(image, devices[0]), move_to_device(
            mask, devices[0])
        if image_inpainted is not None:
            image_inpainted = move_to_device(image_inpainted, devices[-1])
        image_inpainted = _infer(image, mask, forward_front, forward_rears,
                                 image_inpainted, orig_shape, devices, ids,
                                 n_iters, lr)
        # crop the padding away again
        image_inpainted = image_inpainted[:, :, :orig_shape[0], :
                                          orig_shape[1]]
        # detach everything to save resources
        image = image.detach().cpu()
        mask = mask.detach().cpu()

    return image_inpainted