video.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
  1. import math
  2. import os
  3. import random
  4. import uuid
  5. from os.path import exists
  6. from tempfile import TemporaryDirectory
  7. from urllib.parse import urlparse
  8. import numpy as np
  9. import torch
  10. import torch.utils.data
  11. import torch.utils.dlpack as dlpack
  12. import torchvision.transforms._transforms_video as transforms
  13. from decord import VideoReader
  14. from torchvision.transforms import Compose
  15. from modelscope.hub.file_download import http_get_file
  16. from modelscope.metainfo import Preprocessors
  17. from modelscope.utils.constant import Fields, ModeKeys
  18. from modelscope.utils.type_assert import type_assert
  19. from .base import Preprocessor
  20. from .builder import PREPROCESSORS
  21. def ReadVideoData(cfg,
  22. video_path,
  23. num_spatial_crops_override=None,
  24. num_temporal_views_override=None):
  25. """ simple interface to load video frames from file
  26. Args:
  27. cfg (Config): The global config object.
  28. video_path (str): video file path
  29. num_spatial_crops_override (int): the spatial crops per clip
  30. num_temporal_views_override (int): the temporal clips per video
  31. Returns:
  32. data (Tensor): the normalized video clips for model inputs
  33. """
  34. url_parsed = urlparse(video_path)
  35. if url_parsed.scheme in ('file', '') and exists(
  36. url_parsed.path): # Possibly a local file
  37. data = _decode_video(cfg, video_path, num_temporal_views_override)
  38. else:
  39. with TemporaryDirectory() as temporary_cache_dir:
  40. random_str = uuid.uuid4().hex
  41. http_get_file(
  42. url=video_path,
  43. local_dir=temporary_cache_dir,
  44. file_name=random_str,
  45. cookies=None)
  46. temp_file_path = os.path.join(temporary_cache_dir, random_str)
  47. data = _decode_video(cfg, temp_file_path,
  48. num_temporal_views_override)
  49. if num_spatial_crops_override is not None:
  50. num_spatial_crops = num_spatial_crops_override
  51. transform = kinetics400_tranform(cfg, num_spatial_crops_override)
  52. else:
  53. num_spatial_crops = cfg.TEST.NUM_SPATIAL_CROPS
  54. transform = kinetics400_tranform(cfg, cfg.TEST.NUM_SPATIAL_CROPS)
  55. data_list = []
  56. for i in range(data.size(0)):
  57. for j in range(num_spatial_crops):
  58. transform.transforms[1].set_spatial_index(j)
  59. data_list.append(transform(data[i]))
  60. return torch.stack(data_list, dim=0)
  61. def kinetics400_tranform(cfg, num_spatial_crops):
  62. """
  63. Configs the transform for the kinetics-400 dataset.
  64. We apply controlled spatial cropping and normalization.
  65. Args:
  66. cfg (Config): The global config object.
  67. num_spatial_crops (int): the spatial crops per clip
  68. Returns:
  69. transform_function (Compose): the transform function for input clips
  70. """
  71. resize_video = KineticsResizedCrop(
  72. short_side_range=[cfg.DATA.TEST_SCALE, cfg.DATA.TEST_SCALE],
  73. crop_size=cfg.DATA.TEST_CROP_SIZE,
  74. num_spatial_crops=num_spatial_crops)
  75. std_transform_list = [
  76. transforms.ToTensorVideo(), resize_video,
  77. transforms.NormalizeVideo(
  78. mean=cfg.DATA.MEAN, std=cfg.DATA.STD, inplace=True)
  79. ]
  80. return Compose(std_transform_list)
  81. def _interval_based_sampling(vid_length, vid_fps, target_fps, clip_idx,
  82. num_clips, num_frames, interval, minus_interval):
  83. """
  84. Generates the frame index list using interval based sampling.
  85. Args:
  86. vid_length (int): the length of the whole video (valid selection range).
  87. vid_fps (int): the original video fps
  88. target_fps (int): the normalized video fps
  89. clip_idx (int):
  90. -1 for random temporal sampling, and positive values for sampling specific
  91. clip from the video
  92. num_clips (int):
  93. the total clips to be sampled from each video. combined with clip_idx,
  94. the sampled video is the "clip_idx-th" video from "num_clips" videos.
  95. num_frames (int): number of frames in each sampled clips.
  96. interval (int): the interval to sample each frame.
  97. minus_interval (bool): control the end index
  98. Returns:
  99. index (tensor): the sampled frame indexes
  100. """
  101. if num_frames == 1:
  102. index = [random.randint(0, vid_length - 1)]
  103. else:
  104. # transform FPS
  105. clip_length = num_frames * interval * vid_fps / target_fps
  106. max_idx = max(vid_length - clip_length, 0)
  107. if num_clips == 1:
  108. start_idx = max_idx / 2
  109. else:
  110. start_idx = clip_idx * math.floor(max_idx / (num_clips - 1))
  111. if minus_interval:
  112. end_idx = start_idx + clip_length - interval
  113. else:
  114. end_idx = start_idx + clip_length - 1
  115. index = torch.linspace(start_idx, end_idx, num_frames)
  116. index = torch.clamp(index, 0, vid_length - 1).long()
  117. return index
  118. def _decode_video_frames_list(cfg,
  119. frames_list,
  120. vid_fps,
  121. num_temporal_views_override=None):
  122. """
  123. Decodes the video given the numpy frames.
  124. Args:
  125. cfg (Config): The global config object.
  126. frames_list (list): all frames for a video, the frames should be numpy array.
  127. vid_fps (int): the fps of this video.
  128. num_temporal_views_override (int): the temporal clips per video
  129. Returns:
  130. frames (Tensor): video tensor data
  131. """
  132. assert isinstance(frames_list, list)
  133. if num_temporal_views_override is not None:
  134. num_clips_per_video = num_temporal_views_override
  135. else:
  136. num_clips_per_video = cfg.TEST.NUM_ENSEMBLE_VIEWS
  137. frame_list = []
  138. for clip_idx in range(num_clips_per_video):
  139. # for each clip in the video,
  140. # a list is generated before decoding the specified frames from the video
  141. list_ = _interval_based_sampling(
  142. len(frames_list),
  143. vid_fps,
  144. cfg.DATA.TARGET_FPS,
  145. clip_idx,
  146. num_clips_per_video,
  147. cfg.DATA.NUM_INPUT_FRAMES,
  148. cfg.DATA.SAMPLING_RATE,
  149. cfg.DATA.MINUS_INTERVAL,
  150. )
  151. frames = None
  152. frames = torch.from_numpy(
  153. np.stack([frames_list[index] for index in list_.tolist()], axis=0))
  154. frame_list.append(frames)
  155. frames = torch.stack(frame_list)
  156. del vr
  157. return frames
  158. def _decode_video(cfg, path, num_temporal_views_override=None):
  159. """
  160. Decodes the video given the numpy frames.
  161. Args:
  162. cfg (Config): The global config object.
  163. path (str): video file path.
  164. num_temporal_views_override (int): the temporal clips per video
  165. Returns:
  166. frames (Tensor): video tensor data
  167. """
  168. vr = VideoReader(path)
  169. if num_temporal_views_override is not None:
  170. num_clips_per_video = num_temporal_views_override
  171. else:
  172. num_clips_per_video = cfg.TEST.NUM_ENSEMBLE_VIEWS
  173. frame_list = []
  174. for clip_idx in range(num_clips_per_video):
  175. # for each clip in the video,
  176. # a list is generated before decoding the specified frames from the video
  177. list_ = _interval_based_sampling(
  178. len(vr),
  179. vr.get_avg_fps(),
  180. cfg.DATA.TARGET_FPS,
  181. clip_idx,
  182. num_clips_per_video,
  183. cfg.DATA.NUM_INPUT_FRAMES,
  184. cfg.DATA.SAMPLING_RATE,
  185. cfg.DATA.MINUS_INTERVAL,
  186. )
  187. frames = None
  188. if path.endswith('.avi'):
  189. append_list = torch.arange(0, list_[0], 4)
  190. frames = dlpack.from_dlpack(
  191. vr.get_batch(torch.cat([append_list,
  192. list_])).to_dlpack()).clone()
  193. frames = frames[append_list.shape[0]:]
  194. else:
  195. frames = dlpack.from_dlpack(
  196. vr.get_batch(list_).to_dlpack()).clone()
  197. frame_list.append(frames)
  198. frames = torch.stack(frame_list)
  199. del vr
  200. return frames
  201. class KineticsResizedCrop(object):
  202. """Perform resize and crop for kinetics-400 dataset
  203. Args:
  204. short_side_range (list): The length of short side range. In inference, this should be [256, 256]
  205. crop_size (int): The cropped size for frames.
  206. num_spatial_crops (int): The number of the cropped spatial regions in each video.
  207. """
  208. def __init__(
  209. self,
  210. short_side_range,
  211. crop_size,
  212. num_spatial_crops=1,
  213. ):
  214. self.idx = -1
  215. self.short_side_range = short_side_range
  216. self.crop_size = int(crop_size)
  217. self.num_spatial_crops = num_spatial_crops
  218. def _get_controlled_crop(self, clip):
  219. """Perform controlled crop for video tensor.
  220. Args:
  221. clip (Tensor): the video data, the shape is [T, C, H, W]
  222. """
  223. _, _, clip_height, clip_width = clip.shape
  224. length = self.short_side_range[0]
  225. if clip_height < clip_width:
  226. new_clip_height = int(length)
  227. new_clip_width = int(clip_width / clip_height * new_clip_height)
  228. new_clip = torch.nn.functional.interpolate(
  229. clip, size=(new_clip_height, new_clip_width), mode='bilinear')
  230. else:
  231. new_clip_width = int(length)
  232. new_clip_height = int(clip_height / clip_width * new_clip_width)
  233. new_clip = torch.nn.functional.interpolate(
  234. clip, size=(new_clip_height, new_clip_width), mode='bilinear')
  235. x_max = int(new_clip_width - self.crop_size)
  236. y_max = int(new_clip_height - self.crop_size)
  237. if self.num_spatial_crops == 1:
  238. x = x_max // 2
  239. y = y_max // 2
  240. elif self.num_spatial_crops == 3:
  241. if self.idx == 0:
  242. if new_clip_width == length:
  243. x = x_max // 2
  244. y = 0
  245. elif new_clip_height == length:
  246. x = 0
  247. y = y_max // 2
  248. elif self.idx == 1:
  249. x = x_max // 2
  250. y = y_max // 2
  251. elif self.idx == 2:
  252. if new_clip_width == length:
  253. x = x_max // 2
  254. y = y_max
  255. elif new_clip_height == length:
  256. x = x_max
  257. y = y_max // 2
  258. return new_clip[:, :, y:y + self.crop_size, x:x + self.crop_size]
  259. def _get_random_crop(self, clip):
  260. _, _, clip_height, clip_width = clip.shape
  261. short_side = min(clip_height, clip_width)
  262. long_side = max(clip_height, clip_width)
  263. new_short_side = int(random.uniform(*self.short_side_range))
  264. new_long_side = int(long_side / short_side * new_short_side)
  265. if clip_height < clip_width:
  266. new_clip_height = new_short_side
  267. new_clip_width = new_long_side
  268. else:
  269. new_clip_height = new_long_side
  270. new_clip_width = new_short_side
  271. new_clip = torch.nn.functional.interpolate(
  272. clip, size=(new_clip_height, new_clip_width), mode='bilinear')
  273. x_max = int(new_clip_width - self.crop_size)
  274. y_max = int(new_clip_height - self.crop_size)
  275. x = int(random.uniform(0, x_max))
  276. y = int(random.uniform(0, y_max))
  277. return new_clip[:, :, y:y + self.crop_size, x:x + self.crop_size]
  278. def set_spatial_index(self, idx):
  279. """Set the spatial cropping index for controlled cropping..
  280. Args:
  281. idx (int): the spatial index. The value should be in [0, 1, 2], means [left, center, right], respectively.
  282. """
  283. self.idx = idx
  284. def __call__(self, clip):
  285. return self._get_controlled_crop(clip)
  286. @PREPROCESSORS.register_module(
  287. Fields.cv, module_name=Preprocessors.movie_scene_segmentation_preprocessor)
  288. class MovieSceneSegmentationPreprocessor(Preprocessor):
  289. def __init__(self, *args, **kwargs):
  290. """
  291. movie scene segmentation preprocessor
  292. """
  293. super().__init__(*args, **kwargs)
  294. self.is_train = kwargs.pop('is_train', True)
  295. self.preprocessor_train_cfg = kwargs.pop(ModeKeys.TRAIN, None)
  296. self.preprocessor_test_cfg = kwargs.pop(ModeKeys.EVAL, None)
  297. self.num_keyframe = kwargs.pop('num_keyframe', 3)
  298. from .movie_scene_segmentation import get_transform
  299. self.train_transform = get_transform(self.preprocessor_train_cfg)
  300. self.test_transform = get_transform(self.preprocessor_test_cfg)
  301. def train(self):
  302. self.is_train = True
  303. return
  304. def eval(self):
  305. self.is_train = False
  306. return
  307. @type_assert(object, object)
  308. def __call__(self, results):
  309. if self.is_train:
  310. transforms = self.train_transform
  311. else:
  312. transforms = self.test_transform
  313. results = torch.stack(transforms(results), dim=0)
  314. results = results.view(-1, self.num_keyframe, 3, 224, 224)
  315. return results