DUTRAFTStabilizer.py 3.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. # Part of the implementation is borrowed and modified from DUTCode,
  2. # publicly available at https://github.com/Annbless/DUTCode
  3. import math
  4. import os
  5. import sys
  6. import tempfile
  7. from typing import Any, Dict, Optional, Union
  8. import cv2
  9. import numpy as np
  10. import torch
  11. import torch.nn as nn
  12. from modelscope.metainfo import Models
  13. from modelscope.models.base import Tensor
  14. from modelscope.models.base.base_torch_model import TorchModel
  15. from modelscope.models.builder import MODELS
  16. from modelscope.models.cv.video_stabilization.DUT.config import cfg
  17. from modelscope.models.cv.video_stabilization.DUT.DUT_raft import DUT
  18. from modelscope.preprocessors.cv import VideoReader, stabilization_preprocessor
  19. from modelscope.utils.config import Config
  20. from modelscope.utils.constant import ModelFile, Tasks
  21. from modelscope.utils.logger import get_logger
  22. __all__ = ['DUTRAFTStabilizer']
  23. @MODELS.register_module(
  24. Tasks.video_stabilization, module_name=Models.video_stabilization)
  25. class DUTRAFTStabilizer(TorchModel):
  26. def __init__(self, model_dir: str, *args, **kwargs):
  27. """initialize the video stabilization model from the `model_dir` path.
  28. Args:
  29. model_dir (str): the model path.
  30. """
  31. super().__init__(model_dir, *args, **kwargs)
  32. self.model_dir = model_dir
  33. self.config = Config.from_file(
  34. os.path.join(self.model_dir, ModelFile.CONFIGURATION))
  35. SmootherPath = os.path.join(self.model_dir,
  36. self.config.modelsetting.SmootherPath)
  37. RFDetPath = os.path.join(self.model_dir,
  38. self.config.modelsetting.RFDetPath)
  39. RAFTPath = os.path.join(self.model_dir,
  40. self.config.modelsetting.RAFTPath)
  41. MotionProPath = os.path.join(self.model_dir,
  42. self.config.modelsetting.MotionProPath)
  43. homo = self.config.modelsetting.homo
  44. args = self.config.modelsetting.args
  45. self.base_crop_width = self.config.modelsetting.base_crop_width
  46. self.net = DUT(
  47. SmootherPath=SmootherPath,
  48. RFDetPath=RFDetPath,
  49. RAFTPath=RAFTPath,
  50. MotionProPath=MotionProPath,
  51. homo=homo,
  52. args=args)
  53. self.net.cuda()
  54. self.net.eval()
  55. def _inference_forward(self, input: str) -> Dict[str, Any]:
  56. data = stabilization_preprocessor(input, cfg)
  57. with torch.no_grad():
  58. origin_motion, smooth_path = self.net.inference(
  59. data['x'].cuda(), data['x_rgb'].cuda(), repeat=50)
  60. origin_motion = origin_motion.cpu().numpy()
  61. smooth_path = smooth_path.cpu().numpy()
  62. origin_motion = np.transpose(origin_motion[0], (2, 3, 1, 0))
  63. smooth_path = np.transpose(smooth_path[0], (2, 3, 1, 0))
  64. return {
  65. 'origin_motion': origin_motion,
  66. 'smooth_path': smooth_path,
  67. 'ori_images': data['ori_images'],
  68. 'fps': data['fps'],
  69. 'width': data['width'],
  70. 'height': data['height'],
  71. 'base_crop_width': self.base_crop_width
  72. }
  73. def forward(self, inputs: Dict[str, str]) -> Dict[str, Any]:
  74. """return the result by the model
  75. Args:
  76. inputs (str): the input video path
  77. Returns:
  78. Dict[str, str]: results
  79. """
  80. return self._inference_forward(inputs['input'][0])