image_deblur_pipeline.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Any, Dict, Optional, Union
  3. import torch
  4. from torchvision import transforms
  5. from modelscope.metainfo import Pipelines
  6. from modelscope.models.cv.image_deblur import NAFNetForImageDeblur
  7. from modelscope.outputs import OutputKeys
  8. from modelscope.pipelines.base import Input, Pipeline
  9. from modelscope.pipelines.builder import PIPELINES
  10. from modelscope.preprocessors import ImageDeblurPreprocessor, LoadImage
  11. from modelscope.utils.constant import Tasks
  12. from modelscope.utils.logger import get_logger
# Module-level logger obtained from modelscope's logging helper.
logger = get_logger()
# Public API of this module: only the pipeline class is exported.
__all__ = ['ImageDeblurPipeline']
  15. @PIPELINES.register_module(
  16. Tasks.image_deblurring, module_name=Pipelines.image_deblur)
  17. class ImageDeblurPipeline(Pipeline):
  18. """
  19. Examples:
  20. >>> from modelscope.pipelines import pipeline
  21. >>> from modelscope.utils.constant import Tasks
  22. >>> from modelscope.outputs import OutputKeys
  23. >>> import cv2
  24. >>>
  25. >>> img = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/blurry.jpg'
  26. >>> image_deblur_pipeline = pipeline(Tasks.image_deblurring, 'damo/cv_nafnet_image-deblur_gopro')
  27. >>> result = image_deblur_pipeline(img)[OutputKeys.OUTPUT_IMG]
  28. >>> cv2.imwrite('result.png', result)
  29. """
  30. def __init__(self,
  31. model: Union[NAFNetForImageDeblur, str],
  32. preprocessor: Optional[ImageDeblurPreprocessor] = None,
  33. **kwargs):
  34. """
  35. use `model` and `preprocessor` to create a cv image deblur pipeline for prediction
  36. Args:
  37. model: model id on modelscope hub.
  38. """
  39. super().__init__(model=model, preprocessor=preprocessor, **kwargs)
  40. self.model.eval()
  41. self.config = self.model.config
  42. if torch.cuda.is_available():
  43. self._device = torch.device('cuda')
  44. else:
  45. self._device = torch.device('cpu')
  46. logger.info('load image denoise model done')
  47. def preprocess(self, input: Input) -> Dict[str, Any]:
  48. img = LoadImage.convert_to_img(input)
  49. test_transforms = transforms.Compose([transforms.ToTensor()])
  50. img = test_transforms(img)
  51. result = {'img': img.unsqueeze(0).to(self._device)}
  52. return result
  53. def crop_process(self, input):
  54. output = torch.zeros_like(input) # [1, C, H, W]
  55. # determine crop_h and crop_w
  56. ih, iw = input.shape[-2:]
  57. crop_rows, crop_cols = max(ih // 512, 1), max(iw // 512, 1)
  58. overlap = 16
  59. step_h, step_w = ih // crop_rows, iw // crop_cols
  60. for y in range(crop_rows):
  61. for x in range(crop_cols):
  62. crop_y = step_h * y
  63. crop_x = step_w * x
  64. crop_h = step_h if y < crop_rows - 1 else ih - crop_y
  65. crop_w = step_w if x < crop_cols - 1 else iw - crop_x
  66. crop_frames = input[:, :,
  67. max(0, crop_y - overlap
  68. ):min(crop_y + crop_h + overlap, ih),
  69. max(0, crop_x - overlap
  70. ):min(crop_x + crop_w
  71. + overlap, iw)].contiguous()
  72. h_start = overlap if max(0, crop_y - overlap) > 0 else 0
  73. w_start = overlap if max(0, crop_x - overlap) > 0 else 0
  74. h_end = h_start + crop_h if min(crop_y + crop_h
  75. + overlap, ih) < ih else ih
  76. w_end = w_start + crop_w if min(crop_x + crop_w
  77. + overlap, iw) < iw else iw
  78. output[:, :, crop_y:crop_y + crop_h,
  79. crop_x:crop_x + crop_w] = self.model._inference_forward(
  80. crop_frames)['outputs'][:, :, h_start:h_end,
  81. w_start:w_end]
  82. return output
  83. def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
  84. def set_phase(model, is_train):
  85. if is_train:
  86. model.train()
  87. else:
  88. model.eval()
  89. is_train = False
  90. set_phase(self.model, is_train)
  91. with torch.no_grad():
  92. output = self.crop_process(input['img']) # output Tensor
  93. return {'output_tensor': output}
  94. def postprocess(self, input: Dict[str, Any]) -> Dict[str, Any]:
  95. output_img = (input['output_tensor'].squeeze(0) * 255).cpu().permute(
  96. 1, 2, 0).numpy().astype('uint8')
  97. return {OutputKeys.OUTPUT_IMG: output_img[:, :, ::-1]}