video_summarization_pipeline.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. # Part of the implementation is borrowed and modified from PGL-SUM,
  2. # publicly available at https://github.com/e-apostolidis/PGL-SUM
  3. import os.path as osp
  4. from typing import Any, Dict
  5. import cv2
  6. import numpy as np
  7. import torch
  8. from tqdm import tqdm
  9. from modelscope.metainfo import Pipelines
  10. from modelscope.models.cv.video_summarization import (PGLVideoSummarization,
  11. summary_format)
  12. from modelscope.models.cv.video_summarization.base_model import bvlc_googlenet
  13. from modelscope.models.cv.video_summarization.summarizer import (
  14. generate_summary, get_change_points)
  15. from modelscope.outputs import OutputKeys
  16. from modelscope.pipelines.base import Input, Pipeline
  17. from modelscope.pipelines.builder import PIPELINES
  18. from modelscope.utils.config import Config
  19. from modelscope.utils.constant import ModelFile, Tasks
  20. from modelscope.utils.logger import get_logger
# Module-level logger obtained from modelscope's ``get_logger``; used for the
# model/config loading messages in the pipeline class.
logger = get_logger()
  22. @PIPELINES.register_module(
  23. Tasks.video_summarization, module_name=Pipelines.video_summarization)
  24. class VideoSummarizationPipeline(Pipeline):
  25. def __init__(self, model: str, **kwargs):
  26. """
  27. use `model` to create a video summarization pipeline for prediction
  28. Args:
  29. model: model id on modelscope hub.
  30. """
  31. super().__init__(model=model, auto_collate=False, **kwargs)
  32. logger.info(f'loading model from {model}')
  33. googlenet_model_path = osp.join(model, 'bvlc_googlenet.pt')
  34. config_path = osp.join(model, ModelFile.CONFIGURATION)
  35. logger.info(f'loading config from {config_path}')
  36. self.cfg = Config.from_file(config_path)
  37. self.googlenet_model = bvlc_googlenet()
  38. self.googlenet_model.model.load_state_dict(
  39. torch.load(
  40. googlenet_model_path,
  41. map_location=torch.device(self.device),
  42. weights_only=True))
  43. self.googlenet_model = self.googlenet_model.to(self.device).eval()
  44. self.pgl_model = PGLVideoSummarization(model)
  45. self.pgl_model = self.pgl_model.to(self.device).eval()
  46. logger.info('load model done')
  47. def preprocess(self, input: Input) -> Dict[str, Any]:
  48. if not isinstance(input, str):
  49. raise TypeError(f'input should be a str,'
  50. f' but got {type(input)}')
  51. frames = []
  52. picks = []
  53. cap = cv2.VideoCapture(input)
  54. self.fps = cap.get(cv2.CAP_PROP_FPS)
  55. self.frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
  56. frame_idx = 0
  57. while (cap.isOpened()):
  58. ret, frame = cap.read()
  59. if not ret:
  60. break
  61. if frame_idx % 15 == 0:
  62. frames.append(frame)
  63. picks.append(frame_idx)
  64. frame_idx += 1
  65. n_frame = frame_idx
  66. result = {
  67. 'video_name': input,
  68. 'video_frames': np.array(frames),
  69. 'n_frame': n_frame,
  70. 'picks': np.array(picks)
  71. }
  72. return result
  73. def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
  74. frame_features = []
  75. for frame in tqdm(input['video_frames']):
  76. feat = self.googlenet_model(frame)
  77. frame_features.append(feat)
  78. change_points, n_frame_per_seg = get_change_points(
  79. frame_features, input['n_frame'])
  80. summary = self.inference(frame_features, input['n_frame'],
  81. input['picks'], change_points)
  82. output = summary_format(summary, self.fps)
  83. return {OutputKeys.OUTPUT: output}
  84. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  85. return inputs
  86. def inference(self, frame_features, n_frames, picks, change_points):
  87. frame_features = torch.from_numpy(np.array(frame_features, np.float32))
  88. picks = np.array(picks, np.int32)
  89. with torch.no_grad():
  90. results = self.pgl_model(dict(frame_features=frame_features))
  91. scores = results['scores']
  92. if not scores.device.type == 'cpu':
  93. scores = scores.cpu()
  94. scores = scores.squeeze(0).numpy().tolist()
  95. summary = generate_summary([change_points], [scores], [n_frames],
  96. [picks])[0]
  97. return summary.tolist()