default.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. """
  2. Part of the implementation is borrowed and modified from LaMa, publicly available at
  3. https://github.com/saic-mdal/lama
  4. """
  5. import bisect
  6. import torch
  7. import torch.nn.functional as F
  8. from modelscope.utils.logger import get_logger
  9. from .base import BaseInpaintingTrainingModule
  10. from .modules.feature_matching import feature_matching_loss, masked_l1_loss
# Module-level logger obtained from modelscope's `get_logger` helper.
LOGGER = get_logger()
  12. def set_requires_grad(module, value):
  13. for param in module.parameters():
  14. param.requires_grad = value
  15. def add_prefix_to_keys(dct, prefix):
  16. return {prefix + k: v for k, v in dct.items()}
  17. class LinearRamp:
  18. def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0):
  19. self.start_value = start_value
  20. self.end_value = end_value
  21. self.start_iter = start_iter
  22. self.end_iter = end_iter
  23. def __call__(self, i):
  24. if i < self.start_iter:
  25. return self.start_value
  26. if i >= self.end_iter:
  27. return self.end_value
  28. part = (i - self.start_iter) / (self.end_iter - self.start_iter)
  29. return self.start_value * (1 - part) + self.end_value * part
  30. class LadderRamp:
  31. def __init__(self, start_iters, values):
  32. self.start_iters = start_iters
  33. self.values = values
  34. assert len(values) == len(start_iters) + 1, (len(values),
  35. len(start_iters))
  36. def __call__(self, i):
  37. segment_i = bisect.bisect_right(self.start_iters, i)
  38. return self.values[segment_i]
  39. def get_ramp(kind='ladder', **kwargs):
  40. if kind == 'linear':
  41. return LinearRamp(**kwargs)
  42. if kind == 'ladder':
  43. return LadderRamp(**kwargs)
  44. raise ValueError(f'Unexpected ramp kind: {kind}')
  45. class DefaultInpaintingTrainingModule(BaseInpaintingTrainingModule):
  46. def __init__(self,
  47. model_dir='',
  48. predict_only=False,
  49. concat_mask=True,
  50. rescale_scheduler_kwargs=None,
  51. image_to_discriminator='predicted_image',
  52. add_noise_kwargs=None,
  53. noise_fill_hole=False,
  54. const_area_crop_kwargs=None,
  55. distance_weighter_kwargs=None,
  56. distance_weighted_mask_for_discr=False,
  57. fake_fakes_proba=0,
  58. fake_fakes_generator_kwargs=None,
  59. **kwargs):
  60. super().__init__(model_dir=model_dir, predict_only=predict_only)
  61. self.concat_mask = concat_mask
  62. self.rescale_size_getter = get_ramp(
  63. **rescale_scheduler_kwargs
  64. ) if rescale_scheduler_kwargs is not None else None
  65. self.image_to_discriminator = image_to_discriminator
  66. self.add_noise_kwargs = add_noise_kwargs
  67. self.noise_fill_hole = noise_fill_hole
  68. self.const_area_crop_kwargs = const_area_crop_kwargs
  69. self.refine_mask_for_losses = None
  70. self.distance_weighted_mask_for_discr = distance_weighted_mask_for_discr
  71. self.feature_matching_weight = 100
  72. self.losses_l1_weight_known = 10
  73. self.losses_l1_weight_missing = 0
  74. self.fake_fakes_proba = fake_fakes_proba
  75. def forward(self, batch):
  76. img = batch['image']
  77. mask = batch['mask']
  78. masked_img = img * (1 - mask)
  79. if self.concat_mask:
  80. masked_img = torch.cat([masked_img, mask], dim=1)
  81. batch['predicted_image'] = self.generator(masked_img)
  82. batch['inpainted'] = mask * batch['predicted_image'] + (
  83. 1 - mask) * batch['image']
  84. batch['mask_for_losses'] = mask
  85. return batch
  86. def generator_loss(self, batch):
  87. img = batch['image']
  88. predicted_img = batch[self.image_to_discriminator]
  89. original_mask = batch['mask']
  90. supervised_mask = batch['mask_for_losses']
  91. # L1
  92. l1_value = masked_l1_loss(predicted_img, img, supervised_mask,
  93. self.losses_l1_weight_known,
  94. self.losses_l1_weight_missing)
  95. total_loss = l1_value
  96. metrics = dict(gen_l1=l1_value)
  97. # discriminator
  98. # adversarial_loss calls backward by itself
  99. mask_for_discr = supervised_mask if self.distance_weighted_mask_for_discr else original_mask
  100. self.adversarial_loss.pre_generator_step(
  101. real_batch=img,
  102. fake_batch=predicted_img,
  103. generator=self.generator,
  104. discriminator=self.discriminator)
  105. discr_real_pred, discr_real_features = self.discriminator(img)
  106. discr_fake_pred, discr_fake_features = self.discriminator(
  107. predicted_img)
  108. adv_gen_loss, adv_metrics = self.adversarial_loss.generator_loss(
  109. real_batch=img,
  110. fake_batch=predicted_img,
  111. discr_real_pred=discr_real_pred,
  112. discr_fake_pred=discr_fake_pred,
  113. mask=mask_for_discr)
  114. total_loss = total_loss + adv_gen_loss
  115. metrics['gen_adv'] = adv_gen_loss
  116. metrics.update(add_prefix_to_keys(adv_metrics, 'adv_'))
  117. # feature matching
  118. if self.feature_matching_weight > 0:
  119. need_mask_in_fm = False
  120. mask_for_fm = supervised_mask if need_mask_in_fm else None
  121. fm_value = feature_matching_loss(
  122. discr_fake_features, discr_real_features,
  123. mask=mask_for_fm) * self.feature_matching_weight
  124. total_loss = total_loss + fm_value
  125. metrics['gen_fm'] = fm_value
  126. if self.loss_resnet_pl is not None:
  127. resnet_pl_value = self.loss_resnet_pl(predicted_img, img)
  128. total_loss = total_loss + resnet_pl_value
  129. metrics['gen_resnet_pl'] = resnet_pl_value
  130. return total_loss, metrics
  131. def discriminator_loss(self, batch):
  132. total_loss = 0
  133. metrics = {}
  134. predicted_img = batch[self.image_to_discriminator].detach()
  135. self.adversarial_loss.pre_discriminator_step(
  136. real_batch=batch['image'],
  137. fake_batch=predicted_img,
  138. generator=self.generator,
  139. discriminator=self.discriminator)
  140. discr_real_pred, discr_real_features = self.discriminator(
  141. batch['image'])
  142. discr_fake_pred, discr_fake_features = self.discriminator(
  143. predicted_img)
  144. adv_discr_loss, adv_metrics = self.adversarial_loss.discriminator_loss(
  145. real_batch=batch['image'],
  146. fake_batch=predicted_img,
  147. discr_real_pred=discr_real_pred,
  148. discr_fake_pred=discr_fake_pred,
  149. mask=batch['mask'])
  150. total_loss = (total_loss + adv_discr_loss) * 0.1
  151. metrics['discr_adv'] = adv_discr_loss
  152. metrics.update(add_prefix_to_keys(adv_metrics, 'adv_'))
  153. return total_loss, metrics
  154. def _do_step(self, batch, optimizer_idx=None):
  155. if optimizer_idx == 0: # step for generator
  156. set_requires_grad(self.generator, True)
  157. set_requires_grad(self.discriminator, False)
  158. elif optimizer_idx == 1: # step for discriminator
  159. set_requires_grad(self.generator, False)
  160. set_requires_grad(self.discriminator, True)
  161. batch = self(batch)
  162. total_loss = 0
  163. if optimizer_idx is None or optimizer_idx == 0: # step for generator
  164. total_loss, metrics = self.generator_loss(batch)
  165. elif optimizer_idx is None or optimizer_idx == 1: # step for discriminator
  166. total_loss, metrics = self.discriminator_loss(batch)
  167. result = dict(loss=total_loss)
  168. return result