video_stabilization_pipeline.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. # Modified from https://github.com/Annbless/DUTCode
  2. # Copyright (c) Alibaba, Inc. and its affiliates.
  3. import glob
  4. import math
  5. import os
  6. import subprocess
  7. import tempfile
  8. from typing import Any, Dict, Optional, Union
  9. import cv2
  10. import numpy as np
  11. import torch
  12. from modelscope.metainfo import Pipelines
  13. from modelscope.metrics.video_stabilization_metric import warpprocess
  14. from modelscope.models.cv.video_stabilization.DUTRAFTStabilizer import \
  15. DUTRAFTStabilizer
  16. from modelscope.outputs import OutputKeys
  17. from modelscope.pipelines.base import Input, Pipeline
  18. from modelscope.pipelines.builder import PIPELINES
  19. from modelscope.preprocessors import LoadImage
  20. from modelscope.preprocessors.cv import VideoReader
  21. from modelscope.utils.constant import ModelFile, Tasks
  22. from modelscope.utils.logger import get_logger
  # Module-level logger shared by everything in this file; obtained from
  # modelscope's ``get_logger`` helper imported above.
  logger = get_logger()
  def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
      """Raise ``FileNotFoundError`` unless ``filename`` is an existing file.

      Args:
          filename: Path to check.
          msg_tmpl: Template for the error message. It is formatted with
              ``filename`` as its single positional argument.

      Raises:
          FileNotFoundError: If ``filename`` does not point to a regular file.
      """
      # This module imports ``os`` but never defines an ``osp`` alias, so go
      # through ``os.path`` directly instead of the undefined ``osp`` name.
      if not os.path.isfile(filename):
          raise FileNotFoundError(msg_tmpl.format(filename))
  # Public API of this module: only the pipeline class is exported.
  __all__ = ['VideoStabilizationPipeline']
  @PIPELINES.register_module(
      Tasks.video_stabilization, module_name=Pipelines.video_stabilization)
  class VideoStabilizationPipeline(Pipeline):
      """ Video Stabilization Pipeline.

      The pipeline reads basic metadata from the input video, runs the
      stabilization model on it, crops a border off every stabilized frame,
      resizes the frames back to their original size and writes them to an
      mp4 file. Optionally the result is re-encoded to h264 with ffmpeg.

      Examples:

      >>> import cv2
      >>> from modelscope.outputs import OutputKeys
      >>> from modelscope.pipelines import pipeline
      >>> from modelscope.utils.constant import Tasks

      >>> test_video = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/videos/video_stabilization_test_video.avi'
      >>> video_stabilization = pipeline(Tasks.video_stabilization, model='damo/cv_dut-raft_video-stabilization_base')
      >>> out_video_path = video_stabilization(test_video)[OutputKeys.OUTPUT_VIDEO]
      >>> print('Pipeline: the output video path is {}'.format(out_video_path))
      """

      def __init__(self,
                   model: Union[DUTRAFTStabilizer, str],
                   preprocessor=None,
                   **kwargs):
          """Create the pipeline.

          Args:
              model: A ``DUTRAFTStabilizer`` instance or a string identifying
                  the model; forwarded unchanged to the base ``Pipeline``.
              preprocessor: Optional preprocessor forwarded to the base class.
              **kwargs: Extra keyword arguments forwarded to the base class.
          """
          super().__init__(model=model, preprocessor=preprocessor, **kwargs)
          logger.info('load video stabilization model done')

      def preprocess(self, input: Input) -> Dict[str, Any]:
          """Collect the video path and its fps / width / height.

          Args:
              input: The video source handed to ``VideoReader``.

          Returns:
              A dict with the keys ``vid_path``, ``fps``, ``width`` and
              ``height``.
          """
          # The reader is only used to query metadata; the model receives the
          # path itself in ``forward``.
          video_reader = VideoReader(input)
          return {
              'vid_path': input,
              'fps': video_reader.fps,
              'width': video_reader.width,
              'height': video_reader.height
          }

      def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
          """Run the stabilizer and convert its output to a list of frames.

          Args:
              input: The dict produced by ``preprocess``; ``vid_path`` and
                  ``fps`` are read from it.

          Returns:
              A dict with ``output`` (list of uint8 numpy frames), ``fps``
              and ``base_crop_width``.
          """
          results = self.model._inference_forward(input['vid_path'])
          results = warpprocess(results)

          # NOTE(review): assumes results['output'] is a CPU tensor whose
          # first axis indexes frames and whose frames are channel-first;
          # confirm against ``warpprocess``.
          frames = results['output'].numpy().astype(np.uint8)
          # Move axis 0 of every frame to the end (e.g. C,H,W -> H,W,C).
          out_images = [np.transpose(frame, (1, 2, 0)) for frame in frames]

          return {
              'output': out_images,
              'fps': input['fps'],
              'base_crop_width': results['base_crop_width']
          }

      def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
          """Crop, resize and write the stabilized frames to a video file.

          Args:
              inputs: The dict produced by ``forward``.
              **kwargs: Supported keys are ``output_video`` (target path; a
                  temporary ``.mp4`` path is created when omitted) and
                  ``is_cvt_h264`` (when true, additionally re-encode the
                  result to h264 with ffmpeg).

          Returns:
              ``{OutputKeys.OUTPUT_VIDEO: path}`` where ``path`` is the h264
              file when ``is_cvt_h264`` is set and the mp4v file otherwise.

          Raises:
              AssertionError: If ``is_cvt_h264`` is set but ``ffmpeg -version``
                  does not exit with status 0.
          """
          output_video_path = kwargs.get('output_video', None)
          is_cvt_h264 = kwargs.get('is_cvt_h264', False)

          if output_video_path is None:
              # delete=False keeps the reserved path on disk, so it cannot be
              # claimed by someone else before the writer opens it.
              with tempfile.NamedTemporaryFile(
                      suffix='.mp4', delete=False) as tmp_file:
                  output_video_path = tmp_file.name

          frames = inputs['output']
          h, w = frames[0].shape[-3:-1]

          # The crop border is identical for every frame: base_crop_width is
          # scaled from a 1280-wide reference to the actual width, and the
          # vertical border follows the frame's aspect ratio.
          horizontal_border = int(inputs['base_crop_width'] * w / 1280)
          vertical_border = int(horizontal_border * h / w)

          fourcc = cv2.VideoWriter_fourcc(*'mp4v')
          video_writer = cv2.VideoWriter(output_video_path, fourcc,
                                         inputs['fps'], (w, h))
          try:
              for frame in frames:
                  # Explicit end indices: a ``-border`` stop would select
                  # nothing when the border is 0, because -0 == 0.
                  new_frame = frame[vertical_border:h - vertical_border,
                                    horizontal_border:w - horizontal_border]
                  new_frame = cv2.resize(new_frame, (w, h))
                  video_writer.write(new_frame)
          finally:
              # Always finalize the container, even if a frame fails.
              video_writer.release()

          if not is_cvt_h264:
              return {OutputKeys.OUTPUT_VIDEO: output_video_path}

          # Raised explicitly (not via ``assert``) so the check still runs
          # under ``python -O``; the exception type is unchanged.
          if os.system('ffmpeg -version') != 0:
              raise AssertionError(
                  'ffmpeg is not installed correctly, please refer to https://trac.ffmpeg.org/wiki/CompilationGuide.'
              )

          output_video_path_for_web = (
              os.path.splitext(output_video_path)[0] + '_web.mp4')
          # Argument list without a shell: paths containing spaces or shell
          # metacharacters are passed to ffmpeg verbatim.
          convert_cmd = [
              'ffmpeg', '-i', output_video_path, '-vcodec', 'h264', '-crf',
              '5', output_video_path_for_web
          ]
          return_code = subprocess.call(convert_cmd)
          if return_code != 0:
              logger.warning(
                  'ffmpeg exited with status %s while converting %s',
                  return_code, output_video_path)
          return {OutputKeys.OUTPUT_VIDEO: output_video_path_for_web}