| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210 |
- """
- Part of the implementation is borrowed and modified from LaMa, publicly available at
- https://github.com/saic-mdal/lama
- """
- import bisect
- import torch
- import torch.nn.functional as F
- from modelscope.utils.logger import get_logger
- from .base import BaseInpaintingTrainingModule
- from .modules.feature_matching import feature_matching_loss, masked_l1_loss
- LOGGER = get_logger()
def set_requires_grad(module, value):
    """Assign ``value`` to ``requires_grad`` on every parameter of ``module``.

    Args:
        module: Object exposing a ``parameters()`` method that yields the
            parameter objects to update.
        value: Value written to each parameter's ``requires_grad`` attribute.
    """
    for parameter in module.parameters():
        parameter.requires_grad = value
def add_prefix_to_keys(dct, prefix):
    """Return a new dict with ``prefix`` prepended to every key of ``dct``.

    Args:
        dct: Mapping whose keys support ``prefix + key`` concatenation.
        prefix: Value placed in front of each key.

    Returns:
        A fresh ``dict`` holding the same values under the prefixed keys;
        ``dct`` itself is not modified.
    """
    prefixed = {}
    for key, value in dct.items():
        prefixed[prefix + key] = value
    return prefixed
class LinearRamp:
    """Schedule that moves linearly from one value to another.

    Calling the instance with an iteration index ``i`` returns
    ``start_value`` while ``i < start_iter``, ``end_value`` once
    ``i >= end_iter``, and the linear interpolation of the two in between.
    """

    def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0):
        """Store the two endpoint values and the iteration range."""
        self.start_value, self.end_value = start_value, end_value
        self.start_iter, self.end_iter = start_iter, end_iter

    def __call__(self, i):
        """Return the scheduled value for iteration ``i``."""
        first, last = self.start_iter, self.end_iter
        if i < first:
            return self.start_value
        if i >= last:
            return self.end_value
        # Only reached when first <= i < last, so (last - first) is positive.
        fraction = (i - first) / (last - first)
        return self.start_value * (1 - fraction) + self.end_value * fraction
class LadderRamp:
    """Step-wise schedule: one constant value per iteration segment.

    ``start_iters`` holds the iteration indices at which the value switches,
    and ``values`` holds one entry per segment, so it must be exactly one
    element longer than ``start_iters``. The lookup uses
    ``bisect.bisect_right``, which expects ``start_iters`` to be sorted.
    """

    def __init__(self, start_iters, values):
        """Store the boundaries and values, then check that their lengths match."""
        self.start_iters = start_iters
        self.values = values
        n_values, n_starts = len(values), len(start_iters)
        assert n_values == n_starts + 1, (n_values, n_starts)

    def __call__(self, i):
        """Return the value of the segment that iteration ``i`` falls into."""
        return self.values[bisect.bisect_right(self.start_iters, i)]
def get_ramp(kind='ladder', **kwargs):
    """Build a ramp schedule object of the requested kind.

    Args:
        kind: ``'linear'`` for ``LinearRamp`` or ``'ladder'`` for
            ``LadderRamp``.
        **kwargs: Forwarded unchanged to the chosen class constructor.

    Returns:
        The constructed ramp instance.

    Raises:
        ValueError: If ``kind`` is neither ``'linear'`` nor ``'ladder'``.
    """
    if kind == 'linear':
        ramp_cls = LinearRamp
    elif kind == 'ladder':
        ramp_cls = LadderRamp
    else:
        raise ValueError(f'Unexpected ramp kind: {kind}')
    return ramp_cls(**kwargs)
class DefaultInpaintingTrainingModule(BaseInpaintingTrainingModule):
    """Training module that computes generator and discriminator losses.

    The methods rely on ``self.generator``, ``self.discriminator``,
    ``self.adversarial_loss`` and ``self.loss_resnet_pl``, none of which are
    assigned in this class (presumably provided by the base class; verify
    against ``BaseInpaintingTrainingModule``).
    """

    def __init__(self,
                 model_dir='',
                 predict_only=False,
                 concat_mask=True,
                 rescale_scheduler_kwargs=None,
                 image_to_discriminator='predicted_image',
                 add_noise_kwargs=None,
                 noise_fill_hole=False,
                 const_area_crop_kwargs=None,
                 distance_weighter_kwargs=None,
                 distance_weighted_mask_for_discr=False,
                 fake_fakes_proba=0,
                 fake_fakes_generator_kwargs=None,
                 **kwargs):
        """Store the training options on the instance.

        Args:
            model_dir: Forwarded to the base class constructor.
            predict_only: Forwarded to the base class constructor.
            concat_mask: When true, ``forward`` concatenates the mask to the
                masked image along dim 1 before calling the generator.
            rescale_scheduler_kwargs: Optional keyword arguments for
                ``get_ramp``; when ``None`` no ramp is created.
            image_to_discriminator: Batch key of the image the losses treat
                as the generated ("fake") sample.
            add_noise_kwargs: Stored as-is; not used in this class.
            noise_fill_hole: Stored as-is; not used in this class.
            const_area_crop_kwargs: Stored as-is; not used in this class.
            distance_weighter_kwargs: Accepted but not used here.
            distance_weighted_mask_for_discr: When true, the generator's
                adversarial loss receives ``batch['mask_for_losses']``
                instead of ``batch['mask']``.
            fake_fakes_proba: Stored as-is; not used in this class.
            fake_fakes_generator_kwargs: Accepted but not used here.
            **kwargs: Accepted and ignored.
        """
        super().__init__(model_dir=model_dir, predict_only=predict_only)
        self.concat_mask = concat_mask
        if rescale_scheduler_kwargs is None:
            self.rescale_size_getter = None
        else:
            self.rescale_size_getter = get_ramp(**rescale_scheduler_kwargs)
        self.image_to_discriminator = image_to_discriminator
        self.add_noise_kwargs = add_noise_kwargs
        self.noise_fill_hole = noise_fill_hole
        self.const_area_crop_kwargs = const_area_crop_kwargs
        self.refine_mask_for_losses = None
        self.distance_weighted_mask_for_discr = distance_weighted_mask_for_discr

        # Fixed loss weights (not configurable through the constructor).
        self.feature_matching_weight = 100
        self.losses_l1_weight_known = 10
        self.losses_l1_weight_missing = 0

        self.fake_fakes_proba = fake_fakes_proba

    def forward(self, batch):
        """Run the generator on the masked image and extend ``batch`` in place.

        Reads ``batch['image']`` and ``batch['mask']``; writes
        ``'predicted_image'`` (raw generator output), ``'inpainted'``
        (prediction where ``mask`` is set, original image elsewhere) and
        ``'mask_for_losses'``. Returns the same ``batch`` object.
        """
        image = batch['image']
        mask = batch['mask']

        # Zero out the masked region before feeding the generator.
        generator_input = image * (1 - mask)
        if self.concat_mask:
            generator_input = torch.cat([generator_input, mask], dim=1)

        prediction = self.generator(generator_input)
        batch['predicted_image'] = prediction
        batch['inpainted'] = mask * prediction + (1 - mask) * image
        batch['mask_for_losses'] = mask
        return batch

    def generator_loss(self, batch):
        """Compute the generator loss and its metrics.

        The total is the sum of the masked L1 term, the adversarial generator
        term, the feature-matching term (when its weight is positive) and the
        ``loss_resnet_pl`` term (when that attribute is not ``None``).

        Returns:
            ``(total_loss, metrics)`` where ``metrics`` maps metric names to
            the individual loss terms.
        """
        real = batch['image']
        fake = batch[self.image_to_discriminator]
        hole_mask = batch['mask']
        loss_mask = batch['mask_for_losses']

        # L1 term
        l1_term = masked_l1_loss(fake, real, loss_mask,
                                 self.losses_l1_weight_known,
                                 self.losses_l1_weight_missing)
        metrics = {'gen_l1': l1_term}
        # Accumulate with `a = a + b` (never `+=`) so the stored metric
        # objects are not modified in place.
        total = l1_term

        # Adversarial term
        if self.distance_weighted_mask_for_discr:
            discr_mask = loss_mask
        else:
            discr_mask = hole_mask
        self.adversarial_loss.pre_generator_step(
            real_batch=real,
            fake_batch=fake,
            generator=self.generator,
            discriminator=self.discriminator)
        real_pred, real_features = self.discriminator(real)
        fake_pred, fake_features = self.discriminator(fake)
        adv_term, adv_metrics = self.adversarial_loss.generator_loss(
            real_batch=real,
            fake_batch=fake,
            discr_real_pred=real_pred,
            discr_fake_pred=fake_pred,
            mask=discr_mask)
        total = total + adv_term
        metrics['gen_adv'] = adv_term
        metrics.update(add_prefix_to_keys(adv_metrics, 'adv_'))

        # Feature-matching term; no mask is passed to it here.
        if self.feature_matching_weight > 0:
            fm_term = feature_matching_loss(
                fake_features, real_features,
                mask=None) * self.feature_matching_weight
            total = total + fm_term
            metrics['gen_fm'] = fm_term

        # Optional term from `loss_resnet_pl`
        if self.loss_resnet_pl is not None:
            pl_term = self.loss_resnet_pl(fake, real)
            total = total + pl_term
            metrics['gen_resnet_pl'] = pl_term

        return total, metrics

    def discriminator_loss(self, batch):
        """Compute the discriminator loss and its metrics.

        The generated image is detached first, so this loss does not
        back-propagate into whatever produced it. The returned total is the
        adversarial discriminator loss scaled by 0.1, while
        ``metrics['discr_adv']`` holds the unscaled value.

        Returns:
            ``(total_loss, metrics)``.
        """
        fake = batch[self.image_to_discriminator].detach()
        real = batch['image']

        self.adversarial_loss.pre_discriminator_step(
            real_batch=real,
            fake_batch=fake,
            generator=self.generator,
            discriminator=self.discriminator)
        # Only the predictions are needed here; the features are discarded.
        real_pred, _ = self.discriminator(real)
        fake_pred, _ = self.discriminator(fake)
        adv_term, adv_metrics = self.adversarial_loss.discriminator_loss(
            real_batch=real,
            fake_batch=fake,
            discr_real_pred=real_pred,
            discr_fake_pred=fake_pred,
            mask=batch['mask'])

        total = 0
        total = (total + adv_term) * 0.1

        metrics = {'discr_adv': adv_term}
        metrics.update(add_prefix_to_keys(adv_metrics, 'adv_'))
        return total, metrics

    def _do_step(self, batch, optimizer_idx=None):
        """Run one forward pass and return the loss for the chosen optimizer.

        Args:
            batch: Batch passed through ``self(batch)``.
            optimizer_idx: ``0`` selects the generator step, ``1`` the
                discriminator step. ``None`` computes the generator loss
                without toggling any ``requires_grad`` flags. Any other value
                yields a loss of ``0``.

        Returns:
            ``dict`` with the single key ``'loss'``.
        """
        # Only the network being optimised keeps trainable parameters.
        if optimizer_idx == 0:
            set_requires_grad(self.generator, True)
            set_requires_grad(self.discriminator, False)
        elif optimizer_idx == 1:
            set_requires_grad(self.generator, False)
            set_requires_grad(self.discriminator, True)

        batch = self(batch)

        if optimizer_idx is None or optimizer_idx == 0:
            loss, _ = self.generator_loss(batch)
        elif optimizer_idx == 1:
            loss, _ = self.discriminator_loss(batch)
        else:
            loss = 0
        return {'loss': loss}
|