video_stabilization.py 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. # Part of the implementation is borrowed and modified from DUTCode,
  2. # publicly available at https://github.com/Annbless/DUTCode
  3. import cv2
  4. import numpy as np
  5. import torch
  6. import torch.nn as nn
  7. from modelscope.preprocessors.cv import VideoReader
  def stabilization_preprocessor(input, cfg):
      """Decode a video and build the inputs for the stabilization model.

      Every frame is converted to a normalized grayscale image and to a
      colour image, both resized to the model resolution; the
      full-resolution colour frames are kept as well.

      Args:
          input: Video source handed directly to ``VideoReader``
              (presumably a file path -- confirm against ``VideoReader``).
          cfg: Configuration object exposing ``cfg.MODEL.WIDTH`` and
              ``cfg.MODEL.HEIGHT``, the resolution the model works at.

      Returns:
          dict with the following keys (N = number of frames, and frames
          are assumed to be ``(rows, cols, channels)`` arrays, which the
          ``(2, 0, 1)`` transposes below require):

          * ``'ori_images'``: list of N arrays, each
            ``(1, channels, rows, cols)`` at the original resolution.
          * ``'x'``: float32 tensor of shape
            ``(1, 1, N, MODEL.HEIGHT, MODEL.WIDTH)`` holding the grayscale
            frames multiplied by ``1/255``.
          * ``'x_rgb'``: float32 tensor of shape
            ``(1, N, channels, MODEL.HEIGHT, MODEL.WIDTH)`` holding the
            resized colour frames (values are not rescaled).
          * ``'fps'``, ``'width'``, ``'height'``: the corresponding
            attributes reported by the ``VideoReader``.

      Raises:
          ValueError: If no frame could be read from ``input``.
      """
      video_reader = VideoReader(input)
      # Reverse the channel axis of every decoded frame. The channel order
      # produced by VideoReader is not visible here; NOTE(review): together
      # with the RGB2BGR conversion below this looks like a double channel
      # swap -- confirm before simplifying.
      frames = [np.flip(frame, axis=2) for frame in video_reader]
      fps = video_reader.fps
      w = video_reader.width
      h = video_reader.height

      if not frames:
          # Fail with a clear message instead of the opaque
          # "need at least one array to concatenate" from np.concatenate.
          raise ValueError(
              f'No frames could be read from video input: {input!r}')

      # Loop-invariant model resolution; cv2.resize expects (width, height).
      model_w = cfg.MODEL.WIDTH
      model_h = cfg.MODEL.HEIGHT
      model_size = (model_w, model_h)

      rgb_images = []
      images = []
      ori_images = []
      for frame in frames:
          frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

          # Grayscale branch: scale by 1/255, resize, add two leading axes.
          gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
          gray = gray * (1. / 255.)
          gray = cv2.resize(gray, model_size)
          images.append(gray.reshape(1, 1, model_h, model_w))

          # Colour branch: resize, then move channels first and add a
          # leading frame axis.
          resized = cv2.resize(frame, model_size)
          rgb_images.append(
              np.expand_dims(np.transpose(resized, (2, 0, 1)), 0))

          # Full-resolution frame, channels first with a leading axis.
          ori_images.append(
              np.expand_dims(np.transpose(frame, (2, 0, 1)), 0))

      # Grayscale frames are stacked along axis 1, colour frames along
      # axis 0; unsqueeze(0) then prepends one more leading axis to each.
      x = np.concatenate(images, 1).astype(np.float32)
      x = torch.from_numpy(x).unsqueeze(0)
      x_rgb = np.concatenate(rgb_images, 0).astype(np.float32)
      x_rgb = torch.from_numpy(x_rgb).unsqueeze(0)

      return {
          'ori_images': ori_images,
          'x': x,
          'x_rgb': x_rgb,
          'fps': fps,
          'width': w,
          'height': h
      }