model_tf.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Dict
  3. import tensorflow as tf
  4. from modelscope.models.base import Model, Tensor
  5. from .loss import content_loss, guided_filter, style_loss, total_variation_loss
  6. from .network import unet_generator
  7. class CartoonModel(Model):
  def __init__(self, model_dir, *args, **kwargs):
      """Initialize the base model and remember the model directory.

      Args:
          model_dir: forwarded to the base ``Model`` initializer and also
              kept on the instance as ``self.model_dir``.
          *args: extra positional arguments, forwarded unchanged to the
              base initializer.
          **kwargs: extra keyword arguments, forwarded unchanged to the
              base initializer.
      """
      super().__init__(model_dir, *args, **kwargs)
      # Assigned after base initialization; __call__ later hands this value
      # to content_loss.
      self.model_dir = model_dir
  def __call__(
          self,
          input_photo: Dict[str, Tensor],
          input_cartoon: Dict[str, Tensor] = None,
          input_superpixel: Dict[str, Tensor] = None) -> Dict[str, Tensor]:
      """Produce the cartoonized output and, optionally, the training losses.

      The photo is always run through ``unet_generator`` and the result is
      passed through ``guided_filter`` (with the photo as the first
      argument and ``r=1``). When ``input_cartoon`` is supplied, the
      generator and discriminator losses are computed as well.

      Args:
          input_photo: the preprocessed input photo image. It is handed
              directly to ``unet_generator`` / ``guided_filter``.
              NOTE(review): annotated as a dict but used as a single
              input here; confirm the expected type against callers.
          input_cartoon: the preprocessed input cartoon image. ``None``
              (the default) selects the output-only path with no losses.
          input_superpixel: the computed input superpixel image. Only
              read on the loss path, where it is passed to
              ``content_loss``.

      Returns:
          A dict that always contains ``'output_cartoon'``. When
          ``input_cartoon`` is not ``None`` it additionally contains
          ``'g_loss'`` (weighted sum of the generator style loss, the
          content loss and the total-variation loss) and ``'d_loss'``
          (the discriminator style loss).
      """
      # Both paths start with the same generator pass and filtering step.
      generated = unet_generator(input_photo)
      cartoonized = guided_filter(input_photo, generated, r=1)

      # Without a cartoon input there is nothing to compute losses against.
      if input_cartoon is None:
          return {'output_cartoon': cartoonized}

      content_term = content_loss(self.model_dir, input_photo, cartoonized,
                                  input_superpixel)
      style_g_term, style_d_term = style_loss(input_cartoon, cartoonized)
      tv_term = total_variation_loss(cartoonized)

      # Fixed weights: 1e-1 for style, 2e2 for content, 1e4 for total variation.
      generator_loss = 1e-1 * style_g_term + 2e2 * content_term + 1e4 * tv_term

      return {
          'output_cartoon': cartoonized,
          'g_loss': generator_loss,
          'd_loss': style_d_term,
      }
  42. def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
  43. """
  44. Run the forward pass for a model.
  45. Args:
  46. input (Dict[str, Tensor]): the dict of the model inputs for the forward method
  47. Returns:
  48. Dict[str, Tensor]: output from the model forward pass
  49. """