| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- from typing import Dict
- import tensorflow as tf
- from modelscope.models.base import Model, Tensor
- from .loss import content_loss, guided_filter, style_loss, total_variation_loss
- from .network import unet_generator
class CartoonModel(Model):
    """Model wrapper that produces a cartoonized image and optional losses.

    The photo input is run through ``unet_generator`` and the result is
    smoothed with ``guided_filter``. When a cartoon reference is supplied,
    content, style and total-variation losses are computed as well.
    """

    def __init__(self, model_dir, *args, **kwargs):
        """Initialize the base model and remember the model directory.

        Args:
            model_dir: forwarded to the base ``Model`` constructor and later
                passed to ``content_loss`` (presumably the location of
                pretrained weights -- TODO confirm against ``.loss``).
            *args: extra positional arguments forwarded to the base class.
            **kwargs: extra keyword arguments forwarded to the base class.
        """
        super().__init__(model_dir, *args, **kwargs)
        self.model_dir = model_dir

    def __call__(
            self,
            input_photo: Dict[str, Tensor],
            input_cartoon: Dict[str, Tensor] = None,
            input_superpixel: Dict[str, Tensor] = None
    ) -> Dict[str, Tensor]:
        """Generate the cartoon output and, if requested, the losses.

        Args:
            input_photo: the preprocessed input photo image; handed directly
                to ``unet_generator`` and used as the guide of
                ``guided_filter``.
            input_cartoon: the preprocessed input cartoon image. When it is
                ``None`` only the generated output is returned.
            input_superpixel: the computed input superpixel image; only
                forwarded to ``content_loss`` when ``input_cartoon`` is given.

        Returns:
            A dict that always contains ``'output_cartoon'``. When
            ``input_cartoon`` is not ``None`` it additionally contains
            ``'g_loss'`` (weighted sum of the first style loss, the content
            loss and the total-variation loss) and ``'d_loss'`` (the second
            value returned by ``style_loss``).
        """
        # Both the inference-only and the loss-computing paths share the
        # same generator + guided-filter front end.
        generated = unet_generator(input_photo)
        output_cartoon = guided_filter(input_photo, generated, r=1)
        result = {'output_cartoon': output_cartoon}

        # Without a cartoon reference there is nothing to compare against,
        # so no losses are built.
        if input_cartoon is None:
            return result

        con_loss = content_loss(self.model_dir, input_photo, output_cartoon,
                                input_superpixel)
        sty_g_loss, sty_d_loss = style_loss(input_cartoon, output_cartoon)
        tv_loss = total_variation_loss(output_cartoon)

        # Fixed weights of the individual terms in the combined ``g_loss``.
        style_weight, content_weight, tv_weight = 1e-1, 2e2, 1e4
        result['g_loss'] = (style_weight * sty_g_loss
                            + content_weight * con_loss
                            + tv_weight * tv_loss)
        result['d_loss'] = sty_d_loss
        return result

    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Run the forward pass for a model.

        NOTE(review): this method has no body beyond its docstring, so it
        implicitly returns ``None``; the actual computation of this class
        lives in ``__call__``.

        Args:
            input (Dict[str, Tensor]): the dict of the model inputs for the
                forward method

        Returns:
            Dict[str, Tensor]: output from the model forward pass
        """
|