image_skychange_pipeline.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import pdb
  3. import time
  4. from typing import Any, Dict, Union
  5. import cv2
  6. import numpy as np
  7. import PIL
  8. from modelscope.metainfo import Pipelines
  9. from modelscope.models.cv.image_skychange import ImageSkyChangePreprocessor
  10. from modelscope.outputs import OutputKeys
  11. from modelscope.pipelines.base import Input, Model, Pipeline
  12. from modelscope.pipelines.builder import PIPELINES
  13. from modelscope.preprocessors import LoadImage
  14. from modelscope.utils.constant import Tasks
  15. from modelscope.utils.logger import get_logger
# Module-level logger; the pipeline class in this file uses it to report
# model-initialization failures and successful model loading.
logger = get_logger()
  17. @PIPELINES.register_module(
  18. Tasks.image_skychange, module_name=Pipelines.image_skychange)
  19. class ImageSkychangePipeline(Pipeline):
  20. """
  21. Image Sky Change Pipeline. Given two images(sky_image and scene_image), pipeline will replace the sky style
  22. of sky_image with the sky style of scene_image.
  23. Examples:
  24. >>> from modelscope.pipelines import pipeline
  25. >>> detector = pipeline('image-skychange', 'damo/cv_hrnetocr_skychange')
  26. >>> detector({
  27. 'sky_image': 'sky_image.jpg', # sky_image path (str)
  28. 'scene_image': 'scene_image.jpg', # scene_image path (str)
  29. })
  30. >>> {"output_img": [H * W * 3] 0~255, we can use cv2.imwrite to save output_img as an image.}
  31. """
  32. def __init__(self, model: str, **kwargs):
  33. """
  34. use `model` to create a image sky change pipeline for image editing
  35. Args:
  36. model (`str` or `Model`): model_id on modelscope hub
  37. preprocessor(`Preprocessor`, *optional*, defaults to None): `ImageSkyChangePreprocessor`.
  38. """
  39. super().__init__(model=model, **kwargs)
  40. if not isinstance(self.model, Model):
  41. logger.error('model object is not initialized.')
  42. raise Exception('model object is not initialized.')
  43. if self.preprocessor is None:
  44. self.preprocessor = ImageSkyChangePreprocessor()
  45. logger.info('load model done')
  46. def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
  47. res = self.model.forward(**input)
  48. return {OutputKeys.OUTPUT_IMG: res}
  49. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  50. return inputs