base.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. """
  2. Part of the implementation is borrowed and modified from LaMa, publicly available at
  3. https://github.com/saic-mdal/lama
  4. """
  5. from typing import Dict, Tuple
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. from modelscope.utils.logger import get_logger
  10. from .modules.adversarial import NonSaturatingWithR1
  11. from .modules.ffc import FFCResNetGenerator
  12. from .modules.perceptual import ResNetPL
  13. from .modules.pix2pixhd import NLayerDiscriminator
# Module-level logger from modelscope's ``get_logger``; BaseInpaintingTrainingModule
# uses it to report when its initialization starts and finishes.
LOGGER = get_logger()
  15. class BaseInpaintingTrainingModule(nn.Module):
  16. def __init__(self,
  17. model_dir='',
  18. use_ddp=True,
  19. predict_only=False,
  20. visualize_each_iters=100,
  21. average_generator=False,
  22. generator_avg_beta=0.999,
  23. average_generator_start_step=30000,
  24. average_generator_period=10,
  25. store_discr_outputs_for_vis=False,
  26. **kwargs):
  27. super().__init__()
  28. LOGGER.info(
  29. f'BaseInpaintingTrainingModule init called, predict_only is {predict_only}'
  30. )
  31. self.generator = FFCResNetGenerator()
  32. self.use_ddp = use_ddp
  33. if not predict_only:
  34. self.discriminator = NLayerDiscriminator()
  35. self.adversarial_loss = NonSaturatingWithR1(
  36. weight=10,
  37. gp_coef=0.001,
  38. mask_as_fake_target=True,
  39. allow_scale_mask=True)
  40. self.average_generator = average_generator
  41. self.generator_avg_beta = generator_avg_beta
  42. self.average_generator_start_step = average_generator_start_step
  43. self.average_generator_period = average_generator_period
  44. self.generator_average = None
  45. self.last_generator_averaging_step = -1
  46. self.store_discr_outputs_for_vis = store_discr_outputs_for_vis
  47. self.loss_l1 = nn.L1Loss(reduction='none')
  48. self.loss_resnet_pl = ResNetPL(weight=30, weights_path=model_dir)
  49. self.visualize_each_iters = visualize_each_iters
  50. LOGGER.info('BaseInpaintingTrainingModule init done')
  51. def forward(self, batch: Dict[str,
  52. torch.Tensor]) -> Dict[str, torch.Tensor]:
  53. """Pass data through generator and obtain at leas 'predicted_image' and 'inpainted' keys"""
  54. raise NotImplementedError()
  55. def generator_loss(self,
  56. batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
  57. raise NotImplementedError()
  58. def discriminator_loss(
  59. self, batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
  60. raise NotImplementedError()