image_debanding_pipeline.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Any, Dict, Optional, Union
  3. import torch
  4. from torchvision import transforms
  5. from modelscope.metainfo import Pipelines
  6. from modelscope.models.base import Model
  7. from modelscope.models.cv.image_debanding import RRDBImageDebanding
  8. from modelscope.outputs import OutputKeys
  9. from modelscope.pipelines.base import Input, Pipeline
  10. from modelscope.pipelines.builder import PIPELINES
  11. from modelscope.preprocessors import LoadImage
  12. from modelscope.utils.constant import Tasks
  13. from modelscope.utils.logger import get_logger
  logger = get_logger()


  @PIPELINES.register_module(
      Tasks.image_debanding, module_name=Pipelines.image_debanding)
  class ImageDebandingPipeline(Pipeline):
      """Inference pipeline for image debanding.

      Loads an image, runs it through the debanding model with autograd
      disabled, and returns the restored image as a ``uint8`` numpy array
      whose channel axis is reversed (RGB -> BGR for 3-channel input).
      """

      def __init__(self, model: Union[RRDBImageDebanding, str], **kwargs):
          """The inference pipeline for image debanding.

          Args:
              model (`str` or `Model` or module instance): A model instance, a
                  model local dir, or a model id in the model hub.
              kwargs (dict, `optional`): Extra kwargs forwarded unchanged to the
                  base ``Pipeline`` constructor (e.g. ``preprocessor``).

          Example:
              >>> import cv2
              >>> from modelscope.outputs import OutputKeys
              >>> from modelscope.pipelines import pipeline
              >>> from modelscope.utils.constant import Tasks
              >>> debanding = pipeline(Tasks.image_debanding,
              ...                      model='damo/cv_rrdb_image-debanding')
              >>> result = debanding(
              ...     'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/debanding.png')
              >>> cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])
          """
          super().__init__(model=model, **kwargs)
          # Inference only: switch layers such as dropout/batch-norm to eval mode.
          self.model.eval()
          # NOTE(review): this replaces whatever device the base Pipeline chose
          # and always prefers CUDA when it is available. Confirm that ignoring
          # a caller-requested device is intended.
          if torch.cuda.is_available():
              self._device = torch.device('cuda')
          else:
              self._device = torch.device('cpu')

      def preprocess(self, input: Input) -> Dict[str, Any]:
          """Load ``input`` and turn it into a batched float tensor.

          Args:
              input: Anything accepted by ``LoadImage.convert_to_img`` (the
                  class example passes an image URL).

          Returns:
              ``{'src': tensor}``, where ``tensor`` has a leading batch dimension
              of size 1 and is placed on ``self._device``.
          """
          img = LoadImage.convert_to_img(input)
          # ToTensor yields a (C, H, W) float tensor scaled to [0, 1] for 8-bit
          # images. A single transform needs no Compose wrapper.
          tensor = transforms.ToTensor()(img)
          return {'src': tensor.unsqueeze(0).to(self._device)}

      @torch.no_grad()
      def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
          """Run the base ``Pipeline.forward`` with gradient tracking disabled."""
          return super().forward(input)

      def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
          """Convert the model output tensor into a ``uint8`` image array.

          Args:
              inputs: Forward results. ``inputs['outputs']`` is expected to be a
                  ``(1, C, H, W)`` float tensor nominally in ``[0, 1]``. This is
                  an assumption that mirrors the scaling done in ``preprocess``.

          Returns:
              ``{OutputKeys.OUTPUT_IMG: ndarray}``: an ``(H, W, C)`` ``uint8``
              array with the channel axis reversed (RGB -> BGR for 3-channel
              input), as consumed by ``cv2.imwrite`` in the class example.
          """
          output = inputs['outputs'].squeeze(0)
          # Clamp before casting. A float -> uint8 conversion of values outside
          # [0, 255] does not saturate, so any network overshoot would otherwise
          # wrap into bright/dark artifacts.
          # Round rather than truncate, because v / 255 * 255 can land just below
          # v in float32 and would otherwise come out off by one.
          output_img = (output.clamp(0., 1.) * 255.).round().to(torch.uint8)
          output_img = output_img.permute(1, 2, 0).cpu().numpy()[:, :, ::-1]
          return {OutputKeys.OUTPUT_IMG: output_img}