video_utils.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878
  1. # coding=utf-8
  2. # Copyright 2025 The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import os
  16. import warnings
  17. from collections.abc import Iterable, Mapping
  18. from contextlib import redirect_stdout
  19. from dataclasses import dataclass, fields
  20. from io import BytesIO
  21. from typing import Callable, NewType, Optional, Union
  22. from urllib.parse import urlparse
  23. import numpy as np
  24. import requests
  25. from .image_transforms import PaddingMode, to_channel_dimension_format
  26. from .image_utils import ChannelDimension, infer_channel_dimension_format, is_valid_image
  27. from .utils import (
  28. is_av_available,
  29. is_cv2_available,
  30. is_decord_available,
  31. is_numpy_array,
  32. is_torch_available,
  33. is_torch_tensor,
  34. is_torchcodec_available,
  35. is_torchvision_available,
  36. is_vision_available,
  37. is_yt_dlp_available,
  38. logging,
  39. requires_backends,
  40. )
  41. if is_vision_available():
  42. import PIL.Image
  43. import PIL.ImageOps
  44. if is_torchvision_available():
  45. from torchvision import io as torchvision_io
  46. if is_torch_available():
  47. import torch
# Module-level logger for this file.
logger = logging.get_logger(__name__)

# `NewType` aliases of `str`, used in type hints to tell a remote video location (`URL`) apart from a local
# file location (`Path`). At runtime, calling either one simply returns the given string unchanged.
URL = NewType("URL", str)
Path = NewType("Path", str)

# Union of the video input forms accepted by the helpers in this module:
# - one video given as a list of `PIL` frames, as a numpy array / torch tensor, or as a list of per-frame
#   arrays/tensors;
# - nested lists of those (presumably a batch of videos — see how `make_batched_videos` flattens them);
# - video locations given as `URL` / `Path` strings, alone or in (nested) lists.
# "PIL.Image.Image" and "torch.Tensor" are string forward references because both libraries are optional imports.
VideoInput = Union[
    list["PIL.Image.Image"],
    np.ndarray,
    "torch.Tensor",
    list[np.ndarray],
    list["torch.Tensor"],
    list[list["PIL.Image.Image"]],
    list[list[np.ndarray]],
    list[list["torch.Tensor"]],
    URL,
    list[URL],
    list[list[URL]],
    Path,
    list[Path],
    list[list[Path]],
]
  67. @dataclass
  68. class VideoMetadata(Mapping):
  69. total_num_frames: int
  70. fps: Optional[float] = None
  71. width: Optional[int] = None
  72. height: Optional[int] = None
  73. duration: Optional[float] = None
  74. video_backend: Optional[str] = None
  75. frames_indices: Optional[list[int]] = None
  76. def __iter__(self):
  77. return (f.name for f in fields(self))
  78. def __len__(self):
  79. return len(fields(self))
  80. def __getitem__(self, item):
  81. return getattr(self, item)
  82. def __setitem__(self, key, value):
  83. return setattr(self, key, value)
  84. @property
  85. def timestamps(self) -> list[float]:
  86. "Timestamps of the sampled frames in seconds."
  87. if self.fps is None or self.frames_indices is None:
  88. raise ValueError("Cannot infer video `timestamps` when `fps` or `frames_indices` is None.")
  89. return [frame_idx / self.fps for frame_idx in self.frames_indices]
  90. def update(self, dictionary):
  91. for key, value in dictionary.items():
  92. if hasattr(self, key):
  93. setattr(self, key, value)
  94. def is_valid_video_frame(frame):
  95. return isinstance(frame, PIL.Image.Image) or (
  96. (is_numpy_array(frame) or is_torch_tensor(frame)) and frame.ndim == 3
  97. )
  98. def is_valid_video(video):
  99. if not isinstance(video, (list, tuple)):
  100. return (is_numpy_array(video) or is_torch_tensor(video)) and video.ndim == 4
  101. return video and all(is_valid_video_frame(frame) for frame in video)
  102. def valid_videos(videos):
  103. # If we have a list of videos, it could be either one video as list of frames or a batch
  104. if isinstance(videos, (list, tuple)):
  105. for video_or_frame in videos:
  106. if not (is_valid_video(video_or_frame) or is_valid_video_frame(video_or_frame)):
  107. return False
  108. # If not a list, then we have a single 4D video or 5D batched tensor
  109. elif not is_valid_video(videos) or videos.ndim == 5:
  110. return False
  111. return True
  112. def is_batched_video(videos):
  113. if isinstance(videos, (list, tuple)):
  114. return is_valid_video(videos[0])
  115. elif (is_numpy_array(videos) or is_torch_tensor(videos)) and videos.ndim == 5:
  116. return True
  117. return False
  118. def is_scaled_video(video: np.ndarray) -> bool:
  119. """
  120. Checks to see whether the pixel values have already been rescaled to [0, 1].
  121. """
  122. # It's possible the video has pixel values in [0, 255] but is of floating type
  123. return np.min(video) >= 0 and np.max(video) <= 1
  124. def convert_pil_frames_to_video(videos: list[VideoInput]) -> list[Union[np.ndarray, "torch.Tensor"]]:
  125. """
  126. Given a batch of videos, converts each video to a 4D array. If video is already in array type,
  127. it is simply returned. We assume that all inputs in the list are in the same format, based on the type of the first element.
  128. Args:
  129. videos (`VideoInput`):
  130. Video inputs to turn into a list of videos.
  131. """
  132. if not (isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0])):
  133. return videos
  134. video_converted = []
  135. for video in videos:
  136. video = [np.array(frame) for frame in video]
  137. video = np.stack(video)
  138. video_converted.append(video)
  139. return video_converted
  140. def make_batched_videos(videos) -> list[Union[np.ndarray, "torch.Tensor", "URL", "Path"]]:
  141. """
  142. Ensure that the input is a list of videos. If the input is a single video, it is converted to a list of length 1.
  143. If the input is a batch of videos, it is converted to a list of 4D video arrays. Videos passed as list `PIL.Image`
  144. frames are converted to 4D arrays.
  145. We assume that all inputs in the list are in the same format, based on the type of the first element.
  146. Args:
  147. videos (`VideoInput`):
  148. Video inputs to turn into a list of videos.
  149. """
  150. # Early exit for deeply nested list of image frame paths. We shouldn't flatten them
  151. try:
  152. if isinstance(videos[0][0], list) and isinstance(videos[0][0][0], str):
  153. return [image_paths for sublist in videos for image_paths in sublist]
  154. except (IndexError, TypeError):
  155. pass
  156. if isinstance(videos, str) or is_valid_video(videos):
  157. return convert_pil_frames_to_video([videos])
  158. # only one frame passed, thus we unsqueeze time dim
  159. elif is_valid_image(videos):
  160. if isinstance(videos, PIL.Image.Image):
  161. videos = np.array(videos)
  162. return [videos[None, ...]]
  163. elif not isinstance(videos, list):
  164. raise ValueError(
  165. f"Invalid video input. Expected either a list of video frames or an input of 4 or 5 dimensions, but got"
  166. f" type {type(videos)}."
  167. )
  168. # Recursively flatten any nested structure
  169. flat_videos_list = []
  170. for item in videos:
  171. if isinstance(item, str) or is_valid_video(item):
  172. flat_videos_list.append(item)
  173. elif isinstance(item, list) and item:
  174. flat_videos_list.extend(make_batched_videos(item))
  175. flat_videos_list = convert_pil_frames_to_video(flat_videos_list)
  176. return flat_videos_list
  177. def make_batched_metadata(videos: VideoInput, video_metadata: Union[VideoMetadata, dict]):
  178. if video_metadata is None:
  179. # Create default metadata and fill attributes we can infer from given video
  180. video_metadata = [
  181. {
  182. "total_num_frames": len(video),
  183. "fps": None,
  184. "duration": None,
  185. "frames_indices": list(range(len(video))),
  186. "height": get_video_size(video)[0] if is_valid_video(video) else None,
  187. "width": get_video_size(video)[1] if is_valid_video(video) else None,
  188. }
  189. for video in videos
  190. ]
  191. if isinstance(video_metadata, list):
  192. # Flatten if nested list
  193. if isinstance(video_metadata[0], list):
  194. video_metadata = [
  195. VideoMetadata(**metadata) for metadata_list in video_metadata for metadata in metadata_list
  196. ]
  197. # Simply wrap in VideoMetadata if simple dict
  198. elif isinstance(video_metadata[0], dict):
  199. video_metadata = [VideoMetadata(**metadata) for metadata in video_metadata]
  200. else:
  201. # Create a batched list from single object
  202. video_metadata = [VideoMetadata(**video_metadata)]
  203. return video_metadata
  204. def get_video_size(video: np.ndarray, channel_dim: Optional[ChannelDimension] = None) -> tuple[int, int]:
  205. """
  206. Returns the (height, width) dimensions of the video.
  207. Args:
  208. video (`np.ndarray`):
  209. The video to get the dimensions of.
  210. channel_dim (`ChannelDimension`, *optional*):
  211. Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the video.
  212. Returns:
  213. A tuple of the video's height and width.
  214. """
  215. if channel_dim is None:
  216. channel_dim = infer_channel_dimension_format(video, num_channels=(1, 3, 4))
  217. if channel_dim == ChannelDimension.FIRST:
  218. return video.shape[-2], video.shape[-1]
  219. elif channel_dim == ChannelDimension.LAST:
  220. return video.shape[-3], video.shape[-2]
  221. else:
  222. raise ValueError(f"Unsupported data format: {channel_dim}")
  223. def get_uniform_frame_indices(total_num_frames: int, num_frames: Optional[int] = None):
  224. """
  225. Creates a numpy array for uniform sampling of `num_frame` frames from `total_num_frames`
  226. when loading a video.
  227. Args:
  228. total_num_frames (`int`):
  229. Total number of frames that a video has.
  230. num_frames (`int`, *optional*):
  231. Number of frames to sample uniformly. If not specified, all frames are sampled.
  232. Returns:
  233. np.ndarray: np array of frame indices that will be sampled.
  234. """
  235. if num_frames is not None:
  236. indices = np.arange(0, total_num_frames, total_num_frames / num_frames).astype(int)
  237. else:
  238. indices = np.arange(0, total_num_frames).astype(int)
  239. return indices
  240. def default_sample_indices_fn(metadata: VideoMetadata, num_frames=None, fps=None, **kwargs):
  241. """
  242. A default sampling function that replicates the logic used in get_uniform_frame_indices,
  243. while optionally handling `fps` if `num_frames` is not provided.
  244. Args:
  245. metadata (`VideoMetadata`):
  246. `VideoMetadata` object containing metadata about the video, such as "total_num_frames" or "fps".
  247. num_frames (`int`, *optional*):
  248. Number of frames to sample uniformly.
  249. fps (`int` or `float`, *optional*):
  250. Desired frames per second. Takes priority over num_frames if both are provided.
  251. Returns:
  252. `np.ndarray`: Array of frame indices to sample.
  253. """
  254. total_num_frames = metadata.total_num_frames
  255. video_fps = metadata.fps
  256. # If num_frames is not given but fps is, calculate num_frames from fps
  257. if num_frames is None and fps is not None:
  258. num_frames = int(total_num_frames / video_fps * fps)
  259. if num_frames > total_num_frames:
  260. raise ValueError(
  261. f"When loading the video with fps={fps}, we computed num_frames={num_frames} "
  262. f"which exceeds total_num_frames={total_num_frames}. Check fps or video metadata."
  263. )
  264. if num_frames is not None:
  265. indices = np.arange(0, total_num_frames, total_num_frames / num_frames, dtype=int)
  266. else:
  267. indices = np.arange(0, total_num_frames, dtype=int)
  268. return indices
  269. def read_video_opencv(
  270. video_path: Union["URL", "Path"],
  271. sample_indices_fn: Callable,
  272. **kwargs,
  273. ) -> tuple[np.ndarray, VideoMetadata]:
  274. """
  275. Decode a video using the OpenCV backend.
  276. Args:
  277. video_path (`str`):
  278. Path to the video file.
  279. sample_indices_fn (`Callable`):
  280. A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
  281. by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
  282. If not provided, simple uniform sampling with fps is performed.
  283. Example:
  284. def sample_indices_fn(metadata, **kwargs):
  285. return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
  286. Returns:
  287. tuple[`np.ndarray`, `VideoMetadata`]: A tuple containing:
  288. - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
  289. - `VideoMetadata` object.
  290. """
  291. # Lazy import cv2
  292. requires_backends(read_video_opencv, ["cv2"])
  293. import cv2
  294. video = cv2.VideoCapture(video_path)
  295. total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
  296. video_fps = video.get(cv2.CAP_PROP_FPS)
  297. duration = total_num_frames / video_fps if video_fps else 0
  298. metadata = VideoMetadata(
  299. total_num_frames=int(total_num_frames),
  300. fps=float(video_fps),
  301. duration=float(duration),
  302. video_backend="opencv",
  303. height=int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)),
  304. width=int(video.get(cv2.CAP_PROP_FRAME_WIDTH)),
  305. )
  306. indices = sample_indices_fn(metadata=metadata, **kwargs)
  307. index = 0
  308. frames = []
  309. while video.isOpened():
  310. success, frame = video.read()
  311. if not success:
  312. break
  313. if index in indices:
  314. height, width, channel = frame.shape
  315. frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
  316. frames.append(frame[0:height, 0:width, 0:channel])
  317. if success:
  318. index += 1
  319. if index >= total_num_frames:
  320. break
  321. video.release()
  322. metadata.frames_indices = indices
  323. return np.stack(frames), metadata
  324. def read_video_decord(
  325. video_path: Union["URL", "Path"],
  326. sample_indices_fn: Callable,
  327. **kwargs,
  328. ):
  329. """
  330. Decode a video using the Decord backend.
  331. Args:
  332. video_path (`str`):
  333. Path to the video file.
  334. sample_indices_fn (`Callable`):
  335. A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
  336. by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
  337. If not provided, simple uniform sampling with fps is performed.
  338. Example:
  339. def sample_indices_fn(metadata, **kwargs):
  340. return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
  341. Returns:
  342. tuple[`np.array`, `VideoMetadata`]: A tuple containing:
  343. - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
  344. - `VideoMetadata` object.
  345. """
  346. # Lazy import from decord
  347. requires_backends(read_video_decord, ["decord"])
  348. from decord import VideoReader, cpu
  349. vr = VideoReader(uri=video_path, ctx=cpu(0)) # decord has problems with gpu
  350. video_fps = vr.get_avg_fps()
  351. total_num_frames = len(vr)
  352. duration = total_num_frames / video_fps if video_fps else 0
  353. metadata = VideoMetadata(
  354. total_num_frames=int(total_num_frames),
  355. fps=float(video_fps),
  356. duration=float(duration),
  357. video_backend="decord",
  358. )
  359. indices = sample_indices_fn(metadata=metadata, **kwargs)
  360. video = vr.get_batch(indices).asnumpy()
  361. metadata.update(
  362. {
  363. "frames_indices": indices,
  364. "height": video.shape[1],
  365. "width": video.shape[2],
  366. }
  367. )
  368. return video, metadata
  369. def read_video_pyav(
  370. video_path: Union["URL", "Path"],
  371. sample_indices_fn: Callable,
  372. **kwargs,
  373. ):
  374. """
  375. Decode the video with PyAV decoder.
  376. Args:
  377. video_path (`str`):
  378. Path to the video file.
  379. sample_indices_fn (`Callable`, *optional*):
  380. A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
  381. by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
  382. If not provided, simple uniform sampling with fps is performed.
  383. Example:
  384. def sample_indices_fn(metadata, **kwargs):
  385. return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
  386. Returns:
  387. tuple[`np.array`, `VideoMetadata`]: A tuple containing:
  388. - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
  389. - `VideoMetadata` object.
  390. """
  391. # Lazy import av
  392. requires_backends(read_video_pyav, ["av"])
  393. import av
  394. container = av.open(video_path)
  395. total_num_frames = container.streams.video[0].frames
  396. video_fps = container.streams.video[0].average_rate # should we better use `av_guess_frame_rate`?
  397. duration = total_num_frames / video_fps if video_fps else 0
  398. metadata = VideoMetadata(
  399. total_num_frames=int(total_num_frames),
  400. fps=float(video_fps),
  401. duration=float(duration),
  402. video_backend="pyav",
  403. height=container.streams.video[0].height,
  404. width=container.streams.video[0].width,
  405. )
  406. indices = sample_indices_fn(metadata=metadata, **kwargs)
  407. frames = []
  408. container.seek(0)
  409. end_index = indices[-1]
  410. for i, frame in enumerate(container.decode(video=0)):
  411. if i > end_index:
  412. break
  413. if i >= 0 and i in indices:
  414. frames.append(frame)
  415. video = np.stack([x.to_ndarray(format="rgb24") for x in frames])
  416. metadata.frames_indices = indices
  417. return video, metadata
  418. def read_video_torchvision(
  419. video_path: Union["URL", "Path"],
  420. sample_indices_fn: Callable,
  421. **kwargs,
  422. ):
  423. """
  424. Decode the video with torchvision decoder.
  425. Args:
  426. video_path (`str`):
  427. Path to the video file.
  428. sample_indices_fn (`Callable`, *optional*):
  429. A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
  430. by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
  431. If not provided, simple uniform sampling with fps is performed.
  432. Example:
  433. def sample_indices_fn(metadata, **kwargs):
  434. return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
  435. Returns:
  436. tuple[`torch.Tensor`, `VideoMetadata`]: A tuple containing:
  437. - Torch tensor of frames in RGB (shape: [num_frames, height, width, 3]).
  438. - `VideoMetadata` object.
  439. """
  440. warnings.warn(
  441. "Using `torchvision` for video decoding is deprecated and will be removed in future versions. "
  442. "Please use `torchcodec` instead."
  443. )
  444. video, _, info = torchvision_io.read_video(
  445. video_path,
  446. start_pts=0.0,
  447. end_pts=None,
  448. pts_unit="sec",
  449. output_format="TCHW",
  450. )
  451. video_fps = info["video_fps"]
  452. total_num_frames = video.size(0)
  453. duration = total_num_frames / video_fps if video_fps else 0
  454. metadata = VideoMetadata(
  455. total_num_frames=int(total_num_frames),
  456. fps=float(video_fps),
  457. duration=float(duration),
  458. video_backend="torchvision",
  459. )
  460. indices = sample_indices_fn(metadata=metadata, **kwargs)
  461. video = video[indices].contiguous()
  462. metadata.update(
  463. {
  464. "frames_indices": indices,
  465. "height": video.shape[2],
  466. "width": video.shape[3],
  467. }
  468. )
  469. return video, metadata
  470. def read_video_torchcodec(
  471. video_path: Union["URL", "Path"],
  472. sample_indices_fn: Callable,
  473. **kwargs,
  474. ):
  475. """
  476. Decode the video with torchcodec decoder.
  477. Args:
  478. video_path (`str`):
  479. Path to the video file.
  480. sample_indices_fn (`Callable`):
  481. A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
  482. by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
  483. If not provided, simple uniform sampling with fps is performed.
  484. Example:
  485. def sample_indices_fn(metadata, **kwargs):
  486. return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
  487. Returns:
  488. Tuple[`torch.Tensor`, `VideoMetadata`]: A tuple containing:
  489. - Torch tensor of frames in RGB (shape: [num_frames, height, width, 3]).
  490. - `VideoMetadata` object.
  491. """
  492. # Lazy import torchcodec
  493. requires_backends(read_video_torchcodec, ["torchcodec"])
  494. from torchcodec.decoders import VideoDecoder
  495. decoder = VideoDecoder(
  496. video_path,
  497. # Interestingly `exact` mode takes less than approximate when we load the whole video
  498. seek_mode="exact",
  499. # Allow FFmpeg decide on the number of threads for efficiency
  500. num_ffmpeg_threads=0,
  501. device=kwargs.get("device"),
  502. )
  503. metadata = VideoMetadata(
  504. total_num_frames=decoder.metadata.num_frames,
  505. fps=decoder.metadata.average_fps,
  506. duration=decoder.metadata.duration_seconds,
  507. video_backend="torchcodec",
  508. height=decoder.metadata.height,
  509. width=decoder.metadata.width,
  510. )
  511. indices = sample_indices_fn(metadata=metadata, **kwargs)
  512. video = decoder.get_frames_at(indices=indices).data.contiguous()
  513. metadata.frames_indices = indices
  514. return video, metadata
  515. VIDEO_DECODERS = {
  516. "decord": read_video_decord,
  517. "opencv": read_video_opencv,
  518. "pyav": read_video_pyav,
  519. "torchvision": read_video_torchvision,
  520. "torchcodec": read_video_torchcodec,
  521. }
  522. def load_video(
  523. video: VideoInput,
  524. num_frames: Optional[int] = None,
  525. fps: Optional[Union[int, float]] = None,
  526. backend: str = "pyav",
  527. sample_indices_fn: Optional[Callable] = None,
  528. **kwargs,
  529. ) -> np.ndarray:
  530. """
  531. Loads `video` to a numpy array.
  532. Args:
  533. video (`VideoInput`):
  534. The video to convert to the numpy array format. Can be a link to video or local path.
  535. num_frames (`int`, *optional*):
  536. Number of frames to sample uniformly. If not passed, the whole video is loaded.
  537. fps (`int` or `float`, *optional*):
  538. Number of frames to sample per second. Should be passed only when `num_frames=None`.
  539. If not specified and `num_frames==None`, all frames are sampled.
  540. backend (`str`, *optional*, defaults to `"pyav"`):
  541. The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision", "torchcodec"]. Defaults to "pyav".
  542. sample_indices_fn (`Callable`, *optional*):
  543. A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
  544. by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
  545. If not provided, simple uniformt sampling with fps is performed, otherwise `sample_indices_fn` has priority over other args.
  546. The function expects at input the all args along with all kwargs passed to `load_video` and should output valid
  547. indices at which the video should be sampled. For example:
  548. Example:
  549. def sample_indices_fn(metadata, **kwargs):
  550. return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
  551. Returns:
  552. tuple[`np.ndarray`, Dict]: A tuple containing:
  553. - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
  554. - Metadata dictionary.
  555. """
  556. # If `sample_indices_fn` is given, we can accept any args as those might be needed by custom `sample_indices_fn`
  557. if fps is not None and num_frames is not None and sample_indices_fn is None:
  558. raise ValueError(
  559. "`num_frames`, `fps`, and `sample_indices_fn` are mutually exclusive arguments, please use only one!"
  560. )
  561. # If user didn't pass a sampling function, create one on the fly with default logic
  562. if sample_indices_fn is None:
  563. def sample_indices_fn_func(metadata, **fn_kwargs):
  564. return default_sample_indices_fn(metadata, num_frames=num_frames, fps=fps, **fn_kwargs)
  565. sample_indices_fn = sample_indices_fn_func
  566. # Early exit if provided an array or `PIL` frames
  567. if not isinstance(video, str):
  568. metadata = [None] * len(video)
  569. return video, metadata
  570. if urlparse(video).netloc in ["www.youtube.com", "youtube.com"]:
  571. if not is_yt_dlp_available():
  572. raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.")
  573. # Lazy import from yt_dlp
  574. requires_backends(load_video, ["yt_dlp"])
  575. from yt_dlp import YoutubeDL
  576. buffer = BytesIO()
  577. with redirect_stdout(buffer), YoutubeDL() as f:
  578. f.download([video])
  579. bytes_obj = buffer.getvalue()
  580. file_obj = BytesIO(bytes_obj)
  581. elif video.startswith("http://") or video.startswith("https://"):
  582. file_obj = BytesIO(requests.get(video).content)
  583. elif os.path.isfile(video):
  584. file_obj = video
  585. else:
  586. raise TypeError("Incorrect format used for video. Should be an url linking to an video or a local path.")
  587. # can also load with decord, but not cv2/torchvision
  588. # both will fail in case of url links
  589. video_is_url = video.startswith("http://") or video.startswith("https://")
  590. if video_is_url and backend == "opencv":
  591. raise ValueError("If you are trying to load a video from URL, you cannot use 'opencv' as backend")
  592. if (
  593. (not is_decord_available() and backend == "decord")
  594. or (not is_av_available() and backend == "pyav")
  595. or (not is_cv2_available() and backend == "opencv")
  596. or (not is_torchvision_available() and backend == "torchvision")
  597. or (not is_torchcodec_available() and backend == "torchcodec")
  598. ):
  599. raise ImportError(
  600. f"You chose backend={backend} for loading the video but the required library is not found in your environment "
  601. f"Make sure to install {backend} before loading the video."
  602. )
  603. video_decoder = VIDEO_DECODERS[backend]
  604. video, metadata = video_decoder(file_obj, sample_indices_fn, **kwargs)
  605. return video, metadata
  606. def convert_to_rgb(
  607. video: np.ndarray,
  608. input_data_format: Optional[Union[str, ChannelDimension]] = None,
  609. ) -> np.ndarray:
  610. """
  611. Convert video to RGB by blending the transparency layer if it's in RGBA format, otherwise simply returns it.
  612. Args:
  613. video (`np.ndarray`):
  614. The video to convert.
  615. input_data_format (`ChannelDimension`, *optional*):
  616. The channel dimension format of the input video. If unset, will use the inferred format from the input.
  617. """
  618. if not isinstance(video, np.ndarray):
  619. raise TypeError(f"Video has to be a numpy array to convert to RGB format, but found {type(video)}")
  620. # np.array usually comes with ChannelDimension.LAST so let's convert it
  621. if input_data_format is None:
  622. input_data_format = infer_channel_dimension_format(video)
  623. video = to_channel_dimension_format(video, ChannelDimension.FIRST, input_channel_dim=input_data_format)
  624. # 3 channels for RGB already
  625. if video.shape[-3] == 3:
  626. return video
  627. # Grayscale video so we repeat it 3 times for each channel
  628. if video.shape[-3] == 1:
  629. return video.repeat(3, -3)
  630. if not (video[..., 3, :, :] < 255).any():
  631. return video
  632. # There is a transparency layer, blend it with a white background.
  633. # Calculate the alpha proportion for blending.
  634. alpha = video[..., 3, :, :] / 255.0
  635. video = (1 - alpha[..., None, :, :]) * 255 + alpha[..., None, :, :] * video[..., 3, :, :]
  636. return video
  637. def pad(
  638. video: np.ndarray,
  639. padding: Union[int, tuple[int, int], Iterable[tuple[int, int]]],
  640. mode: PaddingMode = PaddingMode.CONSTANT,
  641. constant_values: Union[float, Iterable[float]] = 0.0,
  642. data_format: Optional[Union[str, ChannelDimension]] = None,
  643. input_data_format: Optional[Union[str, ChannelDimension]] = None,
  644. ) -> np.ndarray:
  645. """
  646. Pads the `video` with the specified (height, width) `padding` and `mode`.
  647. Args:
  648. video (`np.ndarray`):
  649. The video to pad.
  650. padding (`int` or `tuple[int, int]` or `Iterable[tuple[int, int]]`):
  651. Padding to apply to the edges of the height, width axes. Can be one of three formats:
  652. - `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
  653. - `((before, after),)` yields same before and after pad for height and width.
  654. - `(pad,)` or int is a shortcut for before = after = pad width for all axes.
  655. mode (`PaddingMode`):
  656. The padding mode to use. Can be one of:
  657. - `"constant"`: pads with a constant value.
  658. - `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
  659. vector along each axis.
  660. - `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
  661. - `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
  662. constant_values (`float` or `Iterable[float]`, *optional*):
  663. The value to use for the padding if `mode` is `"constant"`.
  664. data_format (`str` or `ChannelDimension`, *optional*):
  665. The channel dimension format for the output video. Can be one of:
  666. - `"channels_first"` or `ChannelDimension.FIRST`: video in (num_frames, num_channels, height, width) format.
  667. - `"channels_last"` or `ChannelDimension.LAST`: video in (num_frames, height, width, num_channels) format.
  668. If unset, will use same as the input video.
  669. input_data_format (`str` or `ChannelDimension`, *optional*):
  670. The channel dimension format for the input video. Can be one of:
  671. - `"channels_first"` or `ChannelDimension.FIRST`: video in (num_frames, num_channels, height, width) format.
  672. - `"channels_last"` or `ChannelDimension.LAST`: video in (num_frames, height, width, num_channels) format.
  673. If unset, will use the inferred format of the input video.
  674. Returns:
  675. `np.ndarray`: The padded video.
  676. """
  677. if input_data_format is None:
  678. input_data_format = infer_channel_dimension_format(video)
  679. def _expand_for_data_format(values):
  680. """
  681. Convert values to be in the format expected by np.pad based on the data format.
  682. """
  683. if isinstance(values, (int, float)):
  684. values = ((values, values), (values, values))
  685. elif isinstance(values, tuple) and len(values) == 1:
  686. values = ((values[0], values[0]), (values[0], values[0]))
  687. elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], int):
  688. values = (values, values)
  689. elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], tuple):
  690. pass
  691. else:
  692. raise ValueError(f"Unsupported format: {values}")
  693. # add 0 for channel dimension
  694. values = (
  695. ((0, 0), (0, 0), *values) if input_data_format == ChannelDimension.FIRST else ((0, 0), *values, (0, 0))
  696. )
  697. # Add additional padding if there's a batch dimension
  698. values = (0, *values) if video.ndim == 5 else values
  699. return values
  700. padding_map = {
  701. PaddingMode.CONSTANT: "constant",
  702. PaddingMode.REFLECT: "reflect",
  703. PaddingMode.REPLICATE: "replicate",
  704. PaddingMode.SYMMETRIC: "symmetric",
  705. }
  706. padding = _expand_for_data_format(padding)
  707. pad_kwargs = {}
  708. if mode not in padding_map:
  709. raise ValueError(f"Invalid padding mode: {mode}")
  710. elif mode == PaddingMode.CONSTANT:
  711. pad_kwargs["constant_values"] = _expand_for_data_format(constant_values)
  712. video = np.pad(video, padding, mode=padding_map[mode], **pad_kwargs)
  713. video = to_channel_dimension_format(video, data_format, input_data_format) if data_format is not None else video
  714. return video
  715. def group_videos_by_shape(
  716. videos: list["torch.Tensor"],
  717. ) -> tuple[dict[tuple[int, int], "torch.Tensor"], dict[int, tuple[tuple[int, int], int]]]:
  718. """
  719. Groups videos by shape.
  720. Returns a dictionary with the shape as key and a list of videos with that shape as value,
  721. and a dictionary with the index of the video in the original list as key and the shape and index in the grouped list as value.
  722. """
  723. grouped_videos = {}
  724. grouped_videos_index = {}
  725. for i, video in enumerate(videos):
  726. shape = video.shape[-2::]
  727. num_frames = video.shape[-4] # video format BTCHW
  728. shape = (num_frames, *shape)
  729. if shape not in grouped_videos:
  730. grouped_videos[shape] = []
  731. grouped_videos[shape].append(video)
  732. grouped_videos_index[i] = (shape, len(grouped_videos[shape]) - 1)
  733. # stack videos with the same size and number of frames
  734. grouped_videos = {shape: torch.stack(videos, dim=0) for shape, videos in grouped_videos.items()}
  735. return grouped_videos, grouped_videos_index
  736. def reorder_videos(
  737. processed_videos: dict[tuple[int, int], "torch.Tensor"],
  738. grouped_videos_index: dict[int, tuple[tuple[int, int], int]],
  739. ) -> list["torch.Tensor"]:
  740. """
  741. Reconstructs a list of videos in the original order.
  742. """
  743. return [
  744. processed_videos[grouped_videos_index[i][0]][grouped_videos_index[i][1]]
  745. for i in range(len(grouped_videos_index))
  746. ]