video_deinterlace_pipeline.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. # The implementation here is modified based on RealBasicVSR,
  2. # originally Apache 2.0 License and publicly available at
  3. # https://github.com/ckkelvinchan/RealBasicVSR/blob/master/inference_realbasicvsr.py
  4. import math
  5. import os
  6. import subprocess
  7. import tempfile
  8. from typing import Any, Dict, Optional, Union
  9. import cv2
  10. import numpy as np
  11. import torch
  12. from torchvision.utils import make_grid
  13. from modelscope.metainfo import Pipelines
  14. from modelscope.models.cv.video_deinterlace.UNet_for_video_deinterlace import \
  15. UNetForVideoDeinterlace
  16. from modelscope.outputs import OutputKeys
  17. from modelscope.pipelines.base import Input, Pipeline
  18. from modelscope.pipelines.builder import PIPELINES
  19. from modelscope.preprocessors.cv import VideoReader
  20. from modelscope.utils.constant import Tasks
  21. from modelscope.utils.logger import get_logger
  22. VIDEO_EXTENSIONS = ('.mp4', '.mov')
  23. logger = get_logger()
  24. def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
  25. """Convert torch Tensors into image numpy arrays.
  26. After clamping to (min, max), image values will be normalized to [0, 1].
  27. For different tensor shapes, this function will have different behaviors:
  28. 1. 4D mini-batch Tensor of shape (N x 3/1 x H x W):
  29. Use `make_grid` to stitch images in the batch dimension, and then
  30. convert it to numpy array.
  31. 2. 3D Tensor of shape (3/1 x H x W) and 2D Tensor of shape (H x W):
  32. Directly change to numpy array.
  33. Note that the image channel in input tensors should be RGB order. This
  34. function will convert it to cv2 convention, i.e., (H x W x C) with BGR
  35. order.
  36. Args:
  37. tensor (Tensor | list[Tensor]): Input tensors.
  38. out_type (numpy type): Output types. If ``np.uint8``, transform outputs
  39. to uint8 type with range [0, 255]; otherwise, float type with
  40. range [0, 1]. Default: ``np.uint8``.
  41. min_max (tuple): min and max values for clamp.
  42. Returns:
  43. (Tensor | list[Tensor]): 3D ndarray of shape (H x W x C) or 2D ndarray
  44. of shape (H x W).
  45. """
  46. condition = torch.is_tensor(tensor) or (isinstance(tensor, list) and all(
  47. torch.is_tensor(t) for t in tensor))
  48. if not condition:
  49. raise TypeError(
  50. f'tensor or list of tensors expected, got {type(tensor)}')
  51. if torch.is_tensor(tensor):
  52. tensor = [tensor]
  53. result = []
  54. for _tensor in tensor:
  55. # Squeeze two times so that:
  56. # 1. (1, 1, h, w) -> (h, w) or
  57. # 3. (1, 3, h, w) -> (3, h, w) or
  58. # 2. (n>1, 3/1, h, w) -> (n>1, 3/1, h, w)
  59. _tensor = _tensor.squeeze(0).squeeze(0)
  60. _tensor = _tensor.float().detach().cpu().clamp_(*min_max)
  61. _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
  62. n_dim = _tensor.dim()
  63. if n_dim == 4:
  64. img_np = make_grid(
  65. _tensor, nrow=int(math.sqrt(_tensor.size(0))),
  66. normalize=False).numpy()
  67. img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))
  68. elif n_dim == 3:
  69. img_np = _tensor.numpy()
  70. img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))
  71. elif n_dim == 2:
  72. img_np = _tensor.numpy()
  73. else:
  74. raise ValueError('Only support 4D, 3D or 2D tensor. '
  75. f'But received with dimension: {n_dim}')
  76. if out_type == np.uint8:
  77. # Unlike MATLAB, numpy.unit8() WILL NOT round by default.
  78. img_np = (img_np * 255.0).round()
  79. img_np = img_np.astype(out_type)
  80. result.append(img_np)
  81. result = result[0] if len(result) == 1 else result
  82. return result
  83. @PIPELINES.register_module(
  84. Tasks.video_deinterlace, module_name=Pipelines.video_deinterlace)
  85. class VideoDeinterlacePipeline(Pipeline):
  86. def __init__(self,
  87. model: Union[UNetForVideoDeinterlace, str],
  88. preprocessor=None,
  89. **kwargs):
  90. """The inference pipeline for all the video deinterlace sub-tasks.
  91. Args:
  92. model (`str` or `Model` or module instance): A model instance or a model local dir
  93. or a model id in the model hub.
  94. preprocessor (`Preprocessor`, `optional`): A Preprocessor instance.
  95. kwargs (dict, `optional`):
  96. Extra kwargs passed into the preprocessor's constructor.
  97. Example:
  98. >>> from modelscope.pipelines import pipeline
  99. >>> pipeline_ins = pipeline('video-deinterlace',
  100. model='damo/cv_unet_video-deinterlace')
  101. >>> input = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/videos/video_deinterlace_test.mp4'
  102. >>> print(pipeline_ins(input)[OutputKeys.OUTPUT_VIDEO])
  103. """
  104. super().__init__(model=model, preprocessor=preprocessor, **kwargs)
  105. if torch.cuda.is_available():
  106. self._device = torch.device('cuda')
  107. else:
  108. self._device = torch.device('cpu')
  109. self.net = self.model.model
  110. self.net.to(self._device)
  111. self.net.eval()
  112. logger.info('load video deinterlace model done')
  113. def preprocess(self, input: Input) -> Dict[str, Any]:
  114. # input is a video file
  115. video_reader = VideoReader(input)
  116. inputs = []
  117. for frame in video_reader:
  118. inputs.append(np.flip(frame, axis=2))
  119. fps = video_reader.fps
  120. for i, img in enumerate(inputs):
  121. img = torch.from_numpy(img / 255.).permute(2, 0, 1).float()
  122. inputs[i] = img.unsqueeze(0)
  123. inputs = torch.stack(inputs, dim=1)
  124. return {'video': inputs, 'fps': fps}
  125. def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
  126. inputs = input['video'][0]
  127. frenet = self.net.frenet
  128. enhnet = self.net.enhnet
  129. with torch.no_grad():
  130. outputs = []
  131. frames = []
  132. for i in range(0, inputs.size(0)):
  133. frames.append(frenet(inputs[i:i + 1, ...].to(self._device)))
  134. if i == 0:
  135. frames = [frames[-1]] * 2
  136. continue
  137. outputs.append(enhnet(frames).cpu().unsqueeze(1))
  138. frames = frames[1:]
  139. frames.append(frames[-1])
  140. outputs.append(enhnet(frames).cpu().unsqueeze(1))
  141. outputs = torch.cat(outputs, dim=1)
  142. return {'output': outputs, 'fps': input['fps']}
  143. def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
  144. output_video_path = kwargs.get('output_video', None)
  145. demo_service = kwargs.get('demo_service', False)
  146. if output_video_path is None:
  147. output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
  148. h, w = inputs['output'].shape[-2:]
  149. fourcc = cv2.VideoWriter_fourcc(*'mp4v')
  150. video_writer = cv2.VideoWriter(output_video_path, fourcc,
  151. inputs['fps'], (w, h))
  152. for i in range(0, inputs['output'].size(1)):
  153. img = tensor2img(inputs['output'][:, i, :, :, :])
  154. video_writer.write(img.astype(np.uint8))
  155. video_writer.release()
  156. if demo_service:
  157. assert os.system(
  158. 'ffmpeg -version'
  159. ) == 0, 'ffmpeg is not installed correctly, please refer to https://trac.ffmpeg.org/wiki/CompilationGuide.'
  160. output_video_path_for_web = output_video_path[:-4] + '_web.mp4'
  161. convert_cmd = f'ffmpeg -i {output_video_path} -vcodec h264 -crf 5 {output_video_path_for_web}'
  162. subprocess.call(convert_cmd, shell=True)
  163. return {OutputKeys.OUTPUT_VIDEO: output_video_path_for_web}
  164. else:
  165. return {OutputKeys.OUTPUT_VIDEO: output_video_path}