image_matting_pipeline.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os.path as osp
  3. from typing import Any, Dict
  4. import cv2
  5. import numpy as np
  6. import tensorflow as tf
  7. from modelscope.metainfo import Pipelines
  8. from modelscope.outputs import OutputKeys
  9. from modelscope.pipelines.base import Input, Pipeline
  10. from modelscope.pipelines.builder import PIPELINES
  11. from modelscope.preprocessors import LoadImage
  12. from modelscope.utils.constant import ModelFile, Tasks
  13. from modelscope.utils.device import device_placement
  14. from modelscope.utils.logger import get_logger
  # The pipeline below is written against the TF1 graph/session API. On
  # TensorFlow 2.x, route every `tf.*` reference through `tf.compat.v1`.
  # Compare the integer major version: a plain string comparison is
  # lexicographic and would misclassify e.g. '10.0' as older than '2.0'.
  if int(tf.__version__.split('.')[0]) >= 2:
      tf = tf.compat.v1
  logger = get_logger()
  @PIPELINES.register_module(
      Tasks.portrait_matting, module_name=Pipelines.portrait_matting)
  @PIPELINES.register_module(
      Tasks.universal_matting, module_name=Pipelines.universal_matting)
  class ImageMattingPipeline(Pipeline):
      """Image matting pipeline backed by a frozen TensorFlow graph.

      Registered for both the portrait-matting and universal-matting tasks.
      The model directory must contain a serialized ``GraphDef`` file
      (``ModelFile.TF_GRAPH_FILE``) exposing the tensors ``input_image:0``
      (fed with the image) and ``output_png:0`` (fetched as the result).
      """

      def __init__(self, model: str, **kwargs):
          """Use `model` to create an image matting pipeline for prediction.

          Args:
              model: model id on modelscope hub.
          """
          super().__init__(model=model, **kwargs)
          model_path = osp.join(self.model, ModelFile.TF_GRAPH_FILE)

          # Each pipeline owns its own graph. Importing into the process-wide
          # default graph with name='' makes a second pipeline's nodes get
          # uniquified names, so looking up 'output_png:0' would silently
          # return the FIRST model's tensor.
          self._graph = tf.Graph()
          # Enter the graph before device_placement: if device_placement
          # installs a graph-scoped tf.device scope (presumably it does for
          # TF1; verify), that scope is then recorded on self._graph.
          with self._graph.as_default(), device_placement(
                  self.framework, self.device_name):
              config = tf.ConfigProto(allow_soft_placement=True)
              config.gpu_options.allow_growth = True
              self._session = tf.Session(graph=self._graph, config=config)

              logger.info(f'loading model from {model_path}')
              with tf.gfile.GFile(model_path, 'rb') as f:
                  graph_def = tf.GraphDef()
                  graph_def.ParseFromString(f.read())
              # Imports into self._graph, the current default graph here.
              tf.import_graph_def(graph_def, name='')

              self.output = self._graph.get_tensor_by_name('output_png:0')
              self.input_name = 'input_image:0'
          logger.info('load model done')

      def preprocess(self, input: Input) -> Dict[str, Any]:
          """Load the input image and cast it to float64.

          Args:
              input: anything ``LoadImage.convert_to_ndarray`` accepts
                  (presumably a path/URL, PIL image or ndarray -- verify).

          Returns:
              ``{'img': ndarray}``, with the pixel data as ``np.float64``.
          """
          img = LoadImage.convert_to_ndarray(input)
          img = img.astype(float)
          result = {'img': img}
          return result

      def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
          """Run the matting graph on the preprocessed image.

          Args:
              input: dict produced by :meth:`preprocess`; ``input['img']``
                  is fed to the ``input_image:0`` tensor.

          Returns:
              ``{OutputKeys.OUTPUT_IMG: ndarray}`` holding the graph output
              converted from RGBA to BGRA channel order.
          """
          # Session.run uses this session's own graph; no default-session
          # context is needed.
          feed_dict = {self.input_name: input['img']}
          output_img = self._session.run(self.output, feed_dict=feed_dict)
          # COLOR_RGBA2BGRA requires a 4-channel image, so the graph output
          # is assumed to be RGBA (HxWx4) -- TODO confirm against the model.
          output_img = cv2.cvtColor(output_img, cv2.COLOR_RGBA2BGRA)
          return {OutputKeys.OUTPUT_IMG: output_img}

      def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
          """Return the forward outputs unchanged."""
          return inputs