| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393 |
- '''
- Part of the implementation is borrowed and modified from LaMa, publicly available at
- https://github.com/saic-mdal/lama
- '''
- import cv2
- import numpy as np
- import torch
- import torch.nn as nn
- from kornia.filters import gaussian_blur2d
- from kornia.geometry.transform import resize
- from kornia.morphology import erosion
- from torch.nn import functional as F
- from torch.optim import SGD, Adam
- from tqdm import tqdm
- from .modules.ffc import FFCResnetBlock
def move_to_device(obj, device):
    """Recursively transfer a module, a tensor, or a container of them.

    Parameters
    ----------
    obj : nn.Module, torch.Tensor, tuple, list or dict
        object to transfer; tuples, lists and dict values are walked recursively
    device :
        target device, forwarded unchanged to ``.to``

    Returns
    -------
    object
        ``obj.to(device)`` for modules and tensors, a new list for tuples and
        lists (a tuple comes back as a list), a new dict with the same keys
        for dicts

    Raises
    ------
    ValueError
        if ``obj`` or any nested element has another type
    """
    if isinstance(obj, nn.Module) or torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, dict):
        return {key: move_to_device(item, device) for key, item in obj.items()}
    if isinstance(obj, (tuple, list)):
        return [move_to_device(item, device) for item in obj]
    raise ValueError(f'Unexpected type {type(obj)}')
def ceil_modulo(x, mod):
    """Round ``x`` up to the next multiple of ``mod``.

    ``x`` is returned unchanged when it is already a multiple of ``mod``.
    """
    quotient, remainder = divmod(x, mod)
    if remainder == 0:
        return x
    return (quotient + 1) * mod
def pad_tensor_to_modulo(img, mod):
    """Reflect-pad a 4D tensor so its last two dimensions are multiples of ``mod``.

    Parameters
    ----------
    img : torch.Tensor
        tensor whose shape unpacks into four values (batch, channels, height, width)
    mod : int
        the padded height and width are rounded up to a multiple of this value

    Returns
    -------
    torch.Tensor
        ``img`` padded at the end of the height and width dimensions only
    """
    _, _, height, width = img.shape
    extra_cols = ceil_modulo(width, mod) - width
    extra_rows = ceil_modulo(height, mod) - height
    # F.pad lists the last dimension first: (left, right, top, bottom).
    return F.pad(img, pad=(0, extra_cols, 0, extra_rows), mode='reflect')
def _pyrdown(im: torch.Tensor, downsize: tuple = None):
    """Blur an image with a 5x5 Gaussian, then resize it bilinearly.

    Parameters
    ----------
    im : torch.Tensor
        image of size (n, 3, height, width)
    downsize : tuple, optional
        target (height, width). If None, both sides are halved (integer
        division), by default None

    Returns
    -------
    torch.Tensor
        the downscaled image
    """
    target_size = downsize
    if target_size is None:
        target_size = (im.shape[2] // 2, im.shape[3] // 2)
    assert im.shape[
        1] == 3, 'Expected shape for the input to be (n,3,height,width)'
    # Low-pass filter before resampling.
    blurred = gaussian_blur2d(im, kernel_size=(5, 5), sigma=(1.0, 1.0))
    return F.interpolate(
        blurred, size=target_size, mode='bilinear', align_corners=False)
def _pyrdown_mask(mask: torch.Tensor,
                  downsize: tuple = None,
                  eps: float = 1e-8,
                  blur_mask: bool = True,
                  round_up: bool = True):
    """Downscale a mask tensor and binarize the result.

    Parameters
    ----------
    mask : torch.Tensor
        mask of size (B, 1, H, W)
    downsize : tuple, optional
        size to downscale to. If None, the mask is downscaled to half its
        size, by default None
    eps : float, optional
        threshold value for binarizing the mask, by default 1e-8
    blur_mask : bool, optional
        if True, apply a 5x5 gaussian filter before downscaling, by default True
    round_up : bool, optional
        if True, values >= eps become 1 and the rest 0; otherwise values
        >= 1 - eps become 1 and the rest 0, by default True

    Returns
    -------
    torch.Tensor
        downscaled, binarized mask
    """
    if downsize is None:
        downsize = (mask.shape[2] // 2, mask.shape[3] // 2)
    assert mask.shape[
        1] == 1, 'Expected shape for the input to be (n,1,height,width)'

    # Blurring is optional; the bilinear resize is applied either way.
    if blur_mask is True:
        mask = gaussian_blur2d(mask, kernel_size=(5, 5), sigma=(1.0, 1.0))
    mask = F.interpolate(
        mask, size=downsize, mode='bilinear', align_corners=False)

    # round_up marks any partially covered pixel as 1; otherwise only
    # (almost) fully covered pixels survive.
    threshold = eps if round_up else 1.0 - eps
    mask[mask >= threshold] = 1
    mask[mask < threshold] = 0
    return mask
def _erode_mask(mask: torch.Tensor,
                ekernel: torch.Tensor = None,
                eps: float = 1e-8):
    """Erode the mask and binarize it, setting gray pixels to 0.

    Parameters
    ----------
    mask : torch.Tensor
        mask to process
    ekernel : torch.Tensor, optional
        structuring element handed to ``erosion``. If None, no erosion is
        applied and the mask is only binarized, by default None
    eps : float, optional
        values >= 1 - eps become 1, everything below becomes 0, by default 1e-8

    Returns
    -------
    torch.Tensor
        the binarized mask. The thresholding is applied to the erosion result
        or, when no kernel is given, to a copy, so the tensor passed in by
        the caller is not overwritten by this function's thresholding.
    """
    if ekernel is not None:
        mask = erosion(mask, ekernel)
    else:
        # Without erosion ``mask`` is still the caller's tensor; thresholding
        # it in place would silently overwrite the caller's data.
        mask = mask.clone()
    threshold = 1.0 - eps
    mask[mask >= threshold] = 1
    mask[mask < threshold] = 0
    return mask
def _l1_loss(pred: torch.Tensor,
             pred_downscaled: torch.Tensor,
             ref: torch.Tensor,
             mask: torch.Tensor,
             mask_downscaled: torch.Tensor,
             image: torch.Tensor,
             on_pred: bool = True):
    """L1 loss on source pixels, plus downscaled predictions if on_pred=True.

    Parameters
    ----------
    pred : torch.Tensor
        prediction at the current scale
    pred_downscaled : torch.Tensor
        downscaled prediction, compared against ``ref``
    ref : torch.Tensor
        reference compared with ``pred_downscaled``
    mask : torch.Tensor
        mask indexed like ``pred`` and ``image``; entries < 1e-8 select the
        pixels where ``pred`` is compared with ``image``
    mask_downscaled : torch.Tensor
        mask indexed like ``pred_downscaled`` and ``ref``; entries >= 1e-8
        select the pixels where those two are compared
    image : torch.Tensor
        input image
    on_pred : bool, optional
        if True, add the term on the downscaled prediction, by default True

    Returns
    -------
    torch.Tensor
        scalar loss (mean absolute difference over the selected pixels)
    """
    outside_hole = mask < 1e-8
    loss = torch.mean(torch.abs(pred[outside_hole] - image[outside_hole]))
    if on_pred:
        inside_hole = mask_downscaled >= 1e-8
        loss += torch.mean(
            torch.abs(pred_downscaled[inside_hole] - ref[inside_hole]))
    return loss
def _infer(image: torch.Tensor,
           mask: torch.Tensor,
           forward_front: nn.Module,
           forward_rears: nn.Module,
           ref_lower_res: torch.Tensor,
           orig_shape: tuple,
           devices: list,
           scale_ind: int,
           n_iters: int = 15,
           lr: float = 0.002):
    """Run inference at one scale, refining the intermediate features.

    The front part of the network is evaluated once without gradients. Its two
    output feature tensors are then optimized with Adam so that the prediction
    matches the input image outside the mask and, once downscaled, matches
    the lower-resolution reference inside the (eroded) mask.

    Parameters
    ----------
    image : torch.Tensor
        input image to be inpainted, of size (1,3,H,W)
    mask : torch.Tensor
        input inpainting mask, of size (1,1,H,W)
    forward_front : nn.Module
        the front part of the inpainting network
    forward_rears : nn.Module
        the rear part of the inpainting network; iterated stage by stage,
        stage ``i`` feeding its output to ``devices[i + 1]``
    ref_lower_res : torch.Tensor
        the inpainting at previous scale, used as reference image. If None,
        a single forward pass is done and no refinement takes place
    orig_shape : tuple
        shape of the original input image before padding
    devices : list
        list of available devices
    scale_ind : int
        the scale index
    n_iters : int, optional
        number of iterations of refinement, by default 15
    lr : float, optional
        learning rate, by default 0.002

    Returns
    -------
    torch.Tensor
        inpainted image, detached and moved to the CPU
    """
    first_device, last_device = devices[0], devices[-1]
    last_stage = len(devices) - 1

    # Network input: image with the masked region zeroed, concatenated with
    # the mask along the channel dimension.
    net_input = torch.cat([image * (1 - mask), mask], dim=1)
    mask = mask.repeat(1, 3, 1, 1)
    if ref_lower_res is not None:
        ref_lower_res = ref_lower_res.detach()

    # The front part is not optimized, only its outputs are.
    with torch.no_grad():
        z1, z2 = forward_front(net_input)

    # Inference
    mask = mask.to(last_device)
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
    ekernel = torch.from_numpy(ellipse.astype(bool)).float()
    ekernel = ekernel.to(last_device)
    image = image.to(last_device)

    z1 = z1.detach().to(first_device)
    z2 = z2.detach().to(first_device)
    z1.requires_grad = True
    z2.requires_grad = True
    optimizer = Adam([z1, z2], lr=lr)

    pbar = tqdm(range(n_iters), leave=False)
    for step in pbar:
        optimizer.zero_grad()

        # Push the features through the rear stages, hopping devices.
        feats = (z1, z2)
        for stage, forward_rear in enumerate(forward_rears):
            stage_out = forward_rear(feats)
            if stage < last_stage:
                next_device = devices[stage + 1]
                feat_a, feat_b = stage_out
                feats = (feat_a.to(next_device), feat_b.to(next_device))
            else:
                pred = stage_out

        # No reference available: keep the plain prediction.
        if ref_lower_res is None:
            break

        # scaled loss with downsampler
        losses = {}
        pred_downscaled = _pyrdown(pred[:, :, :orig_shape[0], :orig_shape[1]])
        mask_downscaled = _pyrdown_mask(
            mask[:, :1, :orig_shape[0], :orig_shape[1]],
            blur_mask=False,
            round_up=False)
        mask_downscaled = _erode_mask(mask_downscaled, ekernel=ekernel)
        mask_downscaled = mask_downscaled.repeat(1, 3, 1, 1)
        losses['ms_l1'] = _l1_loss(
            pred,
            pred_downscaled,
            ref_lower_res,
            mask,
            mask_downscaled,
            image,
            on_pred=True)
        loss = sum(losses.values())

        pbar.set_description(
            'Refining scale {} using scale {} ...current loss: {:.4f}'.format(
                scale_ind + 1, scale_ind, loss.item()))

        # The last iteration does not step, so ``pred`` stays consistent with
        # the final features and is available after the loop.
        if step < n_iters - 1:
            loss.backward()
            optimizer.step()
            del pred_downscaled
            del loss
            del pred

    # "pred" is the prediction after Plug-n-Play module
    inpainted = mask * pred + (1 - mask) * image
    return inpainted.detach().cpu()
def _get_image_mask_pyramid(batch: dict, min_side: int, max_scales: int,
                            px_budget: int):
    """Build the image-mask pyramid, lowest resolution first.

    Parameters
    ----------
    batch : dict
        batch holding 'image', 'mask' and 'unpad_to_size'; the image and mask
        are cropped to the unpadded size before anything else
    min_side : int
        minimum side length to limit the number of scales of the pyramid
    max_scales : int
        maximum number of scales allowed
    px_budget : int
        the product H*W cannot exceed this budget, because of resource
        constraints; larger inputs are resized down first

    Returns
    -------
    tuple
        image-mask pyramid in the form of list of images and list of masks,
        both ordered from the lowest to the highest resolution
    """
    assert batch['image'].shape[
        0] == 1, 'refiner works on only batches of size 1!'

    h, w = (side[0].item() for side in batch['unpad_to_size'])
    image = batch['image'][..., :h, :w]
    mask = batch['mask'][..., :h, :w]

    if h * w > px_budget:
        # resize so that the area fits the pixel budget
        ratio = np.sqrt(px_budget / float(h * w))
        h_orig, w_orig = h, w
        h, w = int(h * ratio), int(w * ratio)
        print(
            f'Original image too large for refinement! Resizing {(h_orig,w_orig)} to {(h,w)}...'
        )
        image = resize(
            image, (h, w), interpolation='bilinear', align_corners=False)
        mask = resize(
            mask, (h, w), interpolation='bilinear', align_corners=False)
        # Any pixel touched by the interpolated mask counts as masked.
        mask[mask > 1e-8] = 1

    breadth = min(h, w)
    halvings = int(round(max(0, np.log2(breadth / min_side))))
    n_scales = min(1 + halvings, max_scales)

    # Start from full resolution and repeatedly downscale the previous level.
    ls_images = [image]
    ls_masks = [mask]
    for _ in range(n_scales - 1):
        ls_images.append(_pyrdown(ls_images[-1]))
        ls_masks.append(_pyrdown_mask(ls_masks[-1]))

    # reverse the lists because we want the lowest resolution image as index 0
    return ls_images[::-1], ls_masks[::-1]
def refine_predict(batch: dict, inpainter: nn.Module, gpu_ids: str,
                   modulo: int, n_iters: int, lr: float, min_side: int,
                   max_scales: int, px_budget: int):
    """Refine the inpainting of the network over an image pyramid.

    The generator is split into a front part (everything before the first
    FFCResnetBlock) and one rear part per GPU. Each pyramid scale, from the
    lowest resolution up, is inpainted using the previous scale's result as
    the reference.

    Parameters
    ----------
    batch : dict
        image-mask batch, currently we assume the batchsize to be 1
    inpainter : nn.Module
        wrapper whose ``model`` attribute is the inpainting neural network
    gpu_ids : str
        the GPU ids of the machine to use. If only single GPU, use: "0,"
    modulo : int
        pad the image to ensure dimension % modulo == 0
    n_iters : int
        number of iterations of refinement for each scale
    lr : float
        learning rate
    min_side : int
        all sides of image on all scales should be >= min_side / sqrt(2)
    max_scales : int
        max number of downscaling scales for the image-mask pyramid
    px_budget : int
        pixels budget. Any image will be resized to satisfy height*width <= px_budget

    Returns
    -------
    torch.Tensor
        inpainted image of size (1,3,H,W)
    """
    inpainter = inpainter.model
    assert not inpainter.training
    assert not inpainter.add_noise_kwargs
    assert inpainter.concat_mask

    # Keep only the numeric entries of the comma-separated id string.
    id_tokens = gpu_ids.replace(' ', '').split(',')
    gpu_ids = [f'cuda:{token}' for token in id_tokens if token.isdigit()]

    # Count the resnet blocks and locate the first one.
    layers = inpainter.generator.model
    is_resblock = [isinstance(layer, FFCResnetBlock) for layer in layers]
    n_resnet_blocks = sum(is_resblock)
    if True in is_resblock:
        first_resblock_ind = is_resblock.index(True)
    else:
        first_resblock_ind = len(is_resblock)
    resblocks_per_gpu = n_resnet_blocks // len(gpu_ids)

    devices = [torch.device(gpu_id) for gpu_id in gpu_ids]

    # split the model into front, and rear parts
    forward_front = layers[0:first_resblock_ind]
    forward_front.to(devices[0])
    last_rear = len(gpu_ids) - 1
    forward_rears = []
    for idd, device in enumerate(devices):
        start = first_resblock_ind + resblocks_per_gpu * idd
        # The last device takes every remaining layer.
        stop = start + resblocks_per_gpu if idd < last_rear else None
        rear = layers[start:stop]
        forward_rears.append(rear)
        rear.to(device)

    ls_images, ls_masks = _get_image_mask_pyramid(batch, min_side, max_scales,
                                                  px_budget)

    image_inpainted = None
    for ids, (image, mask) in enumerate(zip(ls_images, ls_masks)):
        orig_shape = image.shape[2:]
        image = pad_tensor_to_modulo(image, modulo)
        mask = pad_tensor_to_modulo(mask, modulo)
        # Re-binarize the padded mask.
        mask[mask >= 1e-8] = 1.0
        mask[mask < 1e-8] = 0.0
        image = move_to_device(image, devices[0])
        mask = move_to_device(mask, devices[0])
        if image_inpainted is not None:
            image_inpainted = move_to_device(image_inpainted, devices[-1])
        image_inpainted = _infer(image, mask, forward_front, forward_rears,
                                 image_inpainted, orig_shape, devices, ids,
                                 n_iters, lr)
        # Drop the padding again.
        image_inpainted = image_inpainted[:, :, :orig_shape[0], :orig_shape[1]]
        # detach everything to save resources
        image = image.detach().cpu()
        mask = mask.detach().cpu()

    return image_inpainted
|