image_denoise_pipeline.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Any, Dict, Optional, Union
  3. import torch
  4. from torchvision import transforms
  5. from modelscope.metainfo import Pipelines
  6. from modelscope.models import Model
  7. from modelscope.models.cv.image_denoise import NAFNetForImageDenoise
  8. from modelscope.outputs import OutputKeys
  9. from modelscope.pipelines.base import Input, Pipeline
  10. from modelscope.pipelines.builder import PIPELINES
  11. from modelscope.preprocessors import ImageDenoisePreprocessor, LoadImage
  12. from modelscope.utils.constant import Tasks
  13. from modelscope.utils.logger import get_logger
  # Public API of this module: only the pipeline class is exported.
  __all__ = ['ImageDenoisePipeline']

  # Module-level logger obtained from modelscope.utils.logger.get_logger.
  logger = get_logger()
  @PIPELINES.register_module(
      Tasks.image_denoising, module_name=Pipelines.image_denoise)
  class ImageDenoisePipeline(Pipeline):
      """Image denoising pipeline that restores an image tile by tile.

      The input image is converted to a tensor, split into overlapping
      tiles, each tile is passed through ``self.model._inference_forward``
      and the restored tile interiors are stitched back into a full-size
      output, which is finally converted to a ``uint8`` image array.
      """

      def __init__(self,
                   model: Union[NAFNetForImageDenoise, str],
                   preprocessor: Optional[ImageDenoisePreprocessor] = None,
                   **kwargs):
          """
          use `model` and `preprocessor` to create a cv image denoise pipeline for prediction
          Args:
              model: model id on modelscope hub, or a ``NAFNetForImageDenoise``
                  instance.
              preprocessor: optional preprocessor forwarded to the base
                  ``Pipeline`` constructor.
              **kwargs: extra keyword arguments forwarded unchanged to the
                  base ``Pipeline`` constructor.
          """
          super().__init__(model=model, preprocessor=preprocessor, **kwargs)
          self.model.eval()
          self.config = self.model.config

          # Input tensors are moved to the GPU whenever one is available.
          self._device = torch.device(
              'cuda' if torch.cuda.is_available() else 'cpu')
          logger.info('load image denoise model done')

      def preprocess(self, input: Input) -> Dict[str, Any]:
          """Load the input image and turn it into a batched tensor.

          Args:
              input: anything accepted by ``LoadImage.convert_to_img``.

          Returns:
              ``{'img': tensor}`` where ``tensor`` is the ``ToTensor`` output
              with a leading batch dimension added, placed on ``self._device``.
          """
          img = LoadImage.convert_to_img(input)
          img = transforms.ToTensor()(img)
          return {'img': img.unsqueeze(0).to(self._device)}

      def crop_process(self,
                       input: torch.Tensor,
                       tile_size: int = 512,
                       overlap: int = 16) -> torch.Tensor:
          """Run the model over overlapping tiles and stitch the results.

          Each spatial dimension is split into ``max(dim // tile_size, 1)``
          tiles; the last tile in a row/column absorbs the remainder. Every
          tile is enlarged by ``overlap`` pixels on each side (clamped to the
          image border) before inference, and only the un-padded interior of
          the model output is written to the result.

          Args:
              input: tensor whose last two dimensions are height and width
                  (the pipeline passes a ``[1, C, H, W]`` tensor).
              tile_size: nominal tile side in pixels. Must be >= 1.
              overlap: context border, in pixels, added around each tile.
                  Must be >= 0.

          Returns:
              Tensor with the same shape, dtype and device as ``input``.

          Raises:
              ValueError: if ``tile_size`` < 1 or ``overlap`` < 0.
          """
          if tile_size < 1:
              raise ValueError(f'tile_size must be >= 1, got {tile_size}')
          if overlap < 0:
              raise ValueError(f'overlap must be >= 0, got {overlap}')

          output = torch.zeros_like(input)  # [1, C, H, W]
          ih, iw = input.shape[-2:]
          crop_rows = max(ih // tile_size, 1)
          crop_cols = max(iw // tile_size, 1)
          step_h, step_w = ih // crop_rows, iw // crop_cols

          for y in range(crop_rows):
              crop_y = step_h * y
              # The last row takes whatever is left so the image is covered.
              crop_h = step_h if y < crop_rows - 1 else ih - crop_y
              top = max(0, crop_y - overlap)
              bottom = min(crop_y + crop_h + overlap, ih)
              # Offset of the tile interior inside the padded window.
              h_start = crop_y - top

              for x in range(crop_cols):
                  crop_x = step_w * x
                  # The last column takes whatever is left, likewise.
                  crop_w = step_w if x < crop_cols - 1 else iw - crop_x
                  left = max(0, crop_x - overlap)
                  right = min(crop_x + crop_w + overlap, iw)
                  w_start = crop_x - left

                  crop_frames = input[:, :, top:bottom,
                                      left:right].contiguous()
                  restored = self.model._inference_forward(
                      crop_frames)['outputs']
                  # Drop the overlap border and keep only the tile interior.
                  output[:, :, crop_y:crop_y + crop_h,
                         crop_x:crop_x + crop_w] = restored[
                             :, :, h_start:h_start + crop_h,
                             w_start:w_start + crop_w]
          return output

      def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
          """Denoise ``input['img']`` in eval mode without gradient tracking.

          Args:
              input: dict holding the batched image tensor under ``'img'``.

          Returns:
              ``{'output_tensor': tensor}`` with the stitched model output.
          """
          self.model.eval()
          with torch.no_grad():
              output = self.crop_process(input['img'])  # output Tensor
          return {'output_tensor': output}

      def postprocess(self, input: Dict[str, Any]) -> Dict[str, Any]:
          """Convert the output tensor into an ``H x W x C`` uint8 array.

          Args:
              input: dict holding the model output under ``'output_tensor'``.

          Returns:
              ``{OutputKeys.OUTPUT_IMG: array}`` where ``array`` is the uint8
              image with its last (channel) axis reversed.
          """
          # Clamp to [0, 1] first: values outside that range would wrap around
          # (or be undefined) when the scaled result is cast to uint8.
          restored = input['output_tensor'].squeeze(0).clamp(0, 1)
          output_img = (restored * 255).cpu().permute(
              1, 2, 0).numpy().astype('uint8')
          # Reverse the channel axis (presumably RGB -> BGR; confirm with the
          # consumers of OutputKeys.OUTPUT_IMG).
          return {OutputKeys.OUTPUT_IMG: output_img[:, :, ::-1]}