image_cartoon_pipeline.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. from typing import Any, Dict
  4. import cv2
  5. import numpy as np
  6. import tensorflow as tf
  7. from modelscope.metainfo import Pipelines
  8. from modelscope.models.cv.cartoon import (FaceAna, get_f5p,
  9. get_reference_facial_points,
  10. padTo16x, resize_size,
  11. warp_and_crop_face)
  12. from modelscope.outputs import OutputKeys
  13. from modelscope.pipelines.base import Input, Pipeline
  14. from modelscope.pipelines.builder import PIPELINES
  15. from modelscope.preprocessors import LoadImage
  16. from modelscope.utils.constant import Tasks
  17. from modelscope.utils.logger import get_logger
  18. from ...utils.device import device_placement
  # This module drives frozen graphs through the TF1 Session API. On TF2,
  # rebind the module-level `tf` alias to the v1 compatibility namespace and
  # turn eager execution off.
  # Compare the integer major version: plain string comparison is
  # lexicographic (e.g. '10.0' >= '2.0' is False).
  if int(tf.__version__.split('.', 1)[0]) >= 2:
      tf = tf.compat.v1
      tf.disable_eager_execution()

  logger = get_logger()
  @PIPELINES.register_module(
      Tasks.image_portrait_stylization,
      module_name=Pipelines.person_image_cartoon)
  class ImageCartoonPipeline(Pipeline):
      """Portrait cartoonization pipeline.

      The whole (resized) image is stylized by a background model. Every face
      found by `FaceAna` is then:
        1. aligned and cropped to a `box_width` x `box_width` patch,
        2. stylized by a dedicated head model,
        3. warped back with the inverse alignment transform, and
        4. alpha-blended into the background result using the mask loaded
           from `alpha.jpg`.
      """

      def __init__(self, model: str, **kwargs):
          """Create an image cartoon pipeline for prediction.

          Args:
              model: model id on the modelscope hub. After the base
                  `Pipeline` initialisation, `self.model` is used as a
                  directory containing `cartoon_h.pb`, `cartoon_bg.pb` and
                  `alpha.jpg`. This assumes the base class resolves the id
                  to a local path; TODO confirm.

          Raises:
              FileNotFoundError: if `alpha.jpg` cannot be read.
          """
          super().__init__(model=model, **kwargs)
          self.facer = FaceAna(self.model)
          # Both frozen models are imported into this one fresh graph, each
          # under its own name scope.
          with tf.Graph().as_default():
              self.sess_anime_head = self.load_sess(
                  os.path.join(self.model, 'cartoon_h.pb'), 'model_anime_head')
              self.sess_anime_bg = self.load_sess(
                  os.path.join(self.model, 'cartoon_bg.pb'), 'model_anime_bg')
          # Side length of the square aligned face crop fed to the head model.
          self.box_width = 288
          mask_path = os.path.join(self.model, 'alpha.jpg')
          global_mask = cv2.imread(mask_path)
          if global_mask is None:
              # cv2.imread signals failure by returning None, not by raising.
              raise FileNotFoundError(f'cannot read face mask: {mask_path}')
          global_mask = cv2.resize(
              global_mask, (self.box_width, self.box_width),
              interpolation=cv2.INTER_AREA)
          # Single-channel blending weights scaled to [0, 1].
          self.global_mask = cv2.cvtColor(
              global_mask, cv2.COLOR_BGR2GRAY).astype(np.float32) / 255.0

      def load_sess(self, model_path, name):
          """Load a frozen GraphDef into a new TF1 session.

          Args:
              model_path: path of the serialized `GraphDef` (.pb) file.
              name: name scope prefixed to every imported op. Tensors are
                  later fetched as '<name>/input_image:0' and
                  '<name>/output_image:0'.

          Returns:
              A session whose graph contains the imported model.
          """
          config = tf.ConfigProto(allow_soft_placement=True)
          # Allocate GPU memory on demand instead of reserving it all upfront.
          config.gpu_options.allow_growth = True
          sess = tf.Session(config=config)
          logger.info(f'loading model from {model_path}')
          with tf.gfile.FastGFile(model_path, 'rb') as f:
              graph_def = tf.GraphDef()
              graph_def.ParseFromString(f.read())
          # `as_default()` is a context manager; it only takes effect inside a
          # `with` statement. Import into the session's own graph explicitly.
          with sess.graph.as_default():
              tf.import_graph_def(graph_def, name=name)
              sess.run(tf.global_variables_initializer())
          logger.info(f'load model {model_path} done.')
          return sess

      def preprocess(self, input: Input) -> Dict[str, Any]:
          """Convert `input` to an ndarray with `LoadImage` and cast it to float.

          Returns:
              {'img': image array}.
          """
          img = LoadImage.convert_to_ndarray(input)
          img = img.astype(float)
          result = {'img': img}
          return result

      def detect_face(self, img):
          """Run the face analyser on `img`.

          Returns:
              The landmarks reported by `FaceAna.run`, or None when no
              bounding box is found.
          """
          boxes, landmarks, _ = self.facer.run(img)
          # boxes is assumed to hold one row per detected face.
          if boxes.shape[0] == 0:
              return None
          return landmarks

      def _run_model(self, sess, scope, image):
          """Feed `image` to the frozen model imported under `scope` and return its output."""
          output = sess.graph.get_tensor_by_name(f'{scope}/output_image:0')
          return sess.run(output, feed_dict={f'{scope}/input_image:0': image})

      def _merge_heads(self, img, img_bgr, landmarks, res):
          """Stylize every detected face and alpha-blend it into `res`.

          Args:
              img: resized input image, in the channel order used for face
                  detection.
              img_bgr: `img` with its channel axis reversed.
              landmarks: per-face landmarks returned by `detect_face`.
              res: stylized background with the same height/width as `img`.

          Returns:
              `res` with each stylized head blended in.
          """
          # cv2.warpAffine takes the destination size as (width, height).
          dsize = (img.shape[1], img.shape[0])
          mask = self.global_mask
          for landmark in landmarks:
              # get facial 5 points
              f5p = get_f5p(landmark, img_bgr)
              # face alignment
              head_img, trans_inv = warp_and_crop_face(
                  img,
                  f5p,
                  ratio=0.75,
                  reference_pts=get_reference_facial_points(default_square=True),
                  crop_size=(self.box_width, self.box_width),
                  return_trans_inv=True)
              # head process: the head model is fed the channel-reversed crop
              head_res = self._run_model(self.sess_anime_head,
                                         'model_anime_head',
                                         head_img[:, :, ::-1])
              # Map the stylized head and its soft mask back to image space;
              # outside the crop the mask is 0, so the background is kept.
              head_trans_inv = cv2.warpAffine(
                  head_res, trans_inv, dsize, borderValue=(0, 0, 0))
              mask_trans_inv = cv2.warpAffine(
                  mask, trans_inv, dsize, borderValue=(0, 0, 0))
              mask_trans_inv = np.expand_dims(mask_trans_inv, 2)
              res = mask_trans_inv * head_trans_inv + (1 - mask_trans_inv) * res
          return res

      def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
          """Cartoonize `input['img']`.

          Returns:
              {OutputKeys.OUTPUT_IMG: stylized image}, resized back to the
              input's original height and width whether or not a face was
              detected.
          """
          img = input['img'].astype(np.uint8)
          ori_h, ori_w, _ = img.shape
          img = resize_size(img, size=720)
          # The networks are fed the channel-reversed image (the input is
          # presumably RGB from LoadImage; TODO confirm).
          img_bgr = img[:, :, ::-1]

          # background process: pad for the network, then crop the padding off
          pad_bg, pad_h, pad_w = padTo16x(img_bgr)
          bg_res = self._run_model(self.sess_anime_bg, 'model_anime_bg', pad_bg)
          res = bg_res[:pad_h, :pad_w, :]

          landmarks = self.detect_face(img)
          if landmarks is None:
              logger.warning('No face detected!')
          else:
              res = self._merge_heads(img, img_bgr, landmarks, res)

          # Restore the caller's resolution on both paths so the output size
          # does not depend on whether a face was found.
          res = cv2.resize(res, (ori_w, ori_h), interpolation=cv2.INTER_AREA)
          return {OutputKeys.OUTPUT_IMG: res}

      def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
          """Return the forward outputs unchanged."""
          return inputs