_video_opt.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521
  1. import math
  2. import warnings
  3. from fractions import Fraction
  4. from typing import Optional, Union
  5. import torch
  6. from ..extension import _load_library
  7. from ._video_deprecation_warning import _raise_video_deprecation_warning
  8. try:
  9. _load_library("video_reader")
  10. _HAS_CPU_VIDEO_DECODER = True
  11. except (ImportError, OSError):
  12. _HAS_CPU_VIDEO_DECODER = False
  13. _HAS_VIDEO_OPT = _HAS_CPU_VIDEO_DECODER # For BC
  14. default_timebase = Fraction(0, 1)
  15. # simple class for torch scripting
  16. # the complex Fraction class from fractions module is not scriptable
  17. class Timebase:
  18. __annotations__ = {"numerator": int, "denominator": int}
  19. __slots__ = ["numerator", "denominator"]
  20. def __init__(
  21. self,
  22. numerator: int,
  23. denominator: int,
  24. ) -> None:
  25. self.numerator = numerator
  26. self.denominator = denominator
  27. class VideoMetaData:
  28. __annotations__ = {
  29. "has_video": bool,
  30. "video_timebase": Timebase,
  31. "video_duration": float,
  32. "video_fps": float,
  33. "has_audio": bool,
  34. "audio_timebase": Timebase,
  35. "audio_duration": float,
  36. "audio_sample_rate": float,
  37. }
  38. __slots__ = [
  39. "has_video",
  40. "video_timebase",
  41. "video_duration",
  42. "video_fps",
  43. "has_audio",
  44. "audio_timebase",
  45. "audio_duration",
  46. "audio_sample_rate",
  47. ]
  48. def __init__(self) -> None:
  49. self.has_video = False
  50. self.video_timebase = Timebase(0, 1)
  51. self.video_duration = 0.0
  52. self.video_fps = 0.0
  53. self.has_audio = False
  54. self.audio_timebase = Timebase(0, 1)
  55. self.audio_duration = 0.0
  56. self.audio_sample_rate = 0.0
  57. def _validate_pts(pts_range: tuple[int, int]) -> None:
  58. if pts_range[0] > pts_range[1] > 0:
  59. raise ValueError(
  60. f"Start pts should not be smaller than end pts, got start pts: {pts_range[0]} and end pts: {pts_range[1]}"
  61. )
  62. def _fill_info(
  63. vtimebase: torch.Tensor,
  64. vfps: torch.Tensor,
  65. vduration: torch.Tensor,
  66. atimebase: torch.Tensor,
  67. asample_rate: torch.Tensor,
  68. aduration: torch.Tensor,
  69. ) -> VideoMetaData:
  70. """
  71. Build update VideoMetaData struct with info about the video
  72. """
  73. meta = VideoMetaData()
  74. if vtimebase.numel() > 0:
  75. meta.video_timebase = Timebase(int(vtimebase[0].item()), int(vtimebase[1].item()))
  76. timebase = vtimebase[0].item() / float(vtimebase[1].item())
  77. if vduration.numel() > 0:
  78. meta.has_video = True
  79. meta.video_duration = float(vduration.item()) * timebase
  80. if vfps.numel() > 0:
  81. meta.video_fps = float(vfps.item())
  82. if atimebase.numel() > 0:
  83. meta.audio_timebase = Timebase(int(atimebase[0].item()), int(atimebase[1].item()))
  84. timebase = atimebase[0].item() / float(atimebase[1].item())
  85. if aduration.numel() > 0:
  86. meta.has_audio = True
  87. meta.audio_duration = float(aduration.item()) * timebase
  88. if asample_rate.numel() > 0:
  89. meta.audio_sample_rate = float(asample_rate.item())
  90. return meta
  91. def _align_audio_frames(
  92. aframes: torch.Tensor, aframe_pts: torch.Tensor, audio_pts_range: tuple[int, int]
  93. ) -> torch.Tensor:
  94. start, end = aframe_pts[0], aframe_pts[-1]
  95. num_samples = aframes.size(0)
  96. step_per_aframe = float(end - start + 1) / float(num_samples)
  97. s_idx = 0
  98. e_idx = num_samples
  99. if start < audio_pts_range[0]:
  100. s_idx = int((audio_pts_range[0] - start) / step_per_aframe)
  101. if audio_pts_range[1] != -1 and end > audio_pts_range[1]:
  102. e_idx = int((audio_pts_range[1] - end) / step_per_aframe)
  103. return aframes[s_idx:e_idx, :]
  104. def _read_video_from_file(
  105. filename: str,
  106. seek_frame_margin: float = 0.25,
  107. read_video_stream: bool = True,
  108. video_width: int = 0,
  109. video_height: int = 0,
  110. video_min_dimension: int = 0,
  111. video_max_dimension: int = 0,
  112. video_pts_range: tuple[int, int] = (0, -1),
  113. video_timebase: Fraction = default_timebase,
  114. read_audio_stream: bool = True,
  115. audio_samples: int = 0,
  116. audio_channels: int = 0,
  117. audio_pts_range: tuple[int, int] = (0, -1),
  118. audio_timebase: Fraction = default_timebase,
  119. ) -> tuple[torch.Tensor, torch.Tensor, VideoMetaData]:
  120. """
  121. Reads a video from a file, returning both the video frames and the audio frames
  122. Args:
  123. filename (str): path to the video file
  124. seek_frame_margin (double, optional): seeking frame in the stream is imprecise. Thus,
  125. when video_start_pts is specified, we seek the pts earlier by seek_frame_margin seconds
  126. read_video_stream (int, optional): whether read video stream. If yes, set to 1. Otherwise, 0
  127. video_width/video_height/video_min_dimension/video_max_dimension (int): together decide
  128. the size of decoded frames:
  129. - When video_width = 0, video_height = 0, video_min_dimension = 0,
  130. and video_max_dimension = 0, keep the original frame resolution
  131. - When video_width = 0, video_height = 0, video_min_dimension != 0,
  132. and video_max_dimension = 0, keep the aspect ratio and resize the
  133. frame so that shorter edge size is video_min_dimension
  134. - When video_width = 0, video_height = 0, video_min_dimension = 0,
  135. and video_max_dimension != 0, keep the aspect ratio and resize
  136. the frame so that longer edge size is video_max_dimension
  137. - When video_width = 0, video_height = 0, video_min_dimension != 0,
  138. and video_max_dimension != 0, resize the frame so that shorter
  139. edge size is video_min_dimension, and longer edge size is
  140. video_max_dimension. The aspect ratio may not be preserved
  141. - When video_width = 0, video_height != 0, video_min_dimension = 0,
  142. and video_max_dimension = 0, keep the aspect ratio and resize
  143. the frame so that frame video_height is $video_height
  144. - When video_width != 0, video_height == 0, video_min_dimension = 0,
  145. and video_max_dimension = 0, keep the aspect ratio and resize
  146. the frame so that frame video_width is $video_width
  147. - When video_width != 0, video_height != 0, video_min_dimension = 0,
  148. and video_max_dimension = 0, resize the frame so that frame
  149. video_width and video_height are set to $video_width and
  150. $video_height, respectively
  151. video_pts_range (list(int), optional): the start and end presentation timestamp of video stream
  152. video_timebase (Fraction, optional): a Fraction rational number which denotes timebase in video stream
  153. read_audio_stream (int, optional): whether read audio stream. If yes, set to 1. Otherwise, 0
  154. audio_samples (int, optional): audio sampling rate
  155. audio_channels (int optional): audio channels
  156. audio_pts_range (list(int), optional): the start and end presentation timestamp of audio stream
  157. audio_timebase (Fraction, optional): a Fraction rational number which denotes time base in audio stream
  158. Returns
  159. vframes (Tensor[T, H, W, C]): the `T` video frames
  160. aframes (Tensor[L, K]): the audio frames, where `L` is the number of points and
  161. `K` is the number of audio_channels
  162. info (Dict): metadata for the video and audio. Can contain the fields video_fps (float)
  163. and audio_fps (int)
  164. """
  165. _raise_video_deprecation_warning()
  166. _validate_pts(video_pts_range)
  167. _validate_pts(audio_pts_range)
  168. result = torch.ops.video_reader.read_video_from_file(
  169. filename,
  170. seek_frame_margin,
  171. 0, # getPtsOnly
  172. read_video_stream,
  173. video_width,
  174. video_height,
  175. video_min_dimension,
  176. video_max_dimension,
  177. video_pts_range[0],
  178. video_pts_range[1],
  179. video_timebase.numerator,
  180. video_timebase.denominator,
  181. read_audio_stream,
  182. audio_samples,
  183. audio_channels,
  184. audio_pts_range[0],
  185. audio_pts_range[1],
  186. audio_timebase.numerator,
  187. audio_timebase.denominator,
  188. )
  189. vframes, _vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, atimebase, asample_rate, aduration = result
  190. info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
  191. if aframes.numel() > 0:
  192. # when audio stream is found
  193. aframes = _align_audio_frames(aframes, aframe_pts, audio_pts_range)
  194. return vframes, aframes, info
  195. def _read_video_timestamps_from_file(filename: str) -> tuple[list[int], list[int], VideoMetaData]:
  196. """
  197. Decode all video- and audio frames in the video. Only pts
  198. (presentation timestamp) is returned. The actual frame pixel data is not
  199. copied. Thus, it is much faster than read_video(...)
  200. """
  201. result = torch.ops.video_reader.read_video_from_file(
  202. filename,
  203. 0, # seek_frame_margin
  204. 1, # getPtsOnly
  205. 1, # read_video_stream
  206. 0, # video_width
  207. 0, # video_height
  208. 0, # video_min_dimension
  209. 0, # video_max_dimension
  210. 0, # video_start_pts
  211. -1, # video_end_pts
  212. 0, # video_timebase_num
  213. 1, # video_timebase_den
  214. 1, # read_audio_stream
  215. 0, # audio_samples
  216. 0, # audio_channels
  217. 0, # audio_start_pts
  218. -1, # audio_end_pts
  219. 0, # audio_timebase_num
  220. 1, # audio_timebase_den
  221. )
  222. _vframes, vframe_pts, vtimebase, vfps, vduration, _aframes, aframe_pts, atimebase, asample_rate, aduration = result
  223. info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
  224. vframe_pts = vframe_pts.numpy().tolist()
  225. aframe_pts = aframe_pts.numpy().tolist()
  226. return vframe_pts, aframe_pts, info
  227. def _probe_video_from_file(filename: str) -> VideoMetaData:
  228. """
  229. Probe a video file and return VideoMetaData with info about the video
  230. """
  231. _raise_video_deprecation_warning()
  232. result = torch.ops.video_reader.probe_video_from_file(filename)
  233. vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
  234. info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
  235. return info
  236. def _read_video_from_memory(
  237. video_data: torch.Tensor,
  238. seek_frame_margin: float = 0.25,
  239. read_video_stream: int = 1,
  240. video_width: int = 0,
  241. video_height: int = 0,
  242. video_min_dimension: int = 0,
  243. video_max_dimension: int = 0,
  244. video_pts_range: tuple[int, int] = (0, -1),
  245. video_timebase_numerator: int = 0,
  246. video_timebase_denominator: int = 1,
  247. read_audio_stream: int = 1,
  248. audio_samples: int = 0,
  249. audio_channels: int = 0,
  250. audio_pts_range: tuple[int, int] = (0, -1),
  251. audio_timebase_numerator: int = 0,
  252. audio_timebase_denominator: int = 1,
  253. ) -> tuple[torch.Tensor, torch.Tensor]:
  254. """
  255. Reads a video from memory, returning both the video frames as the audio frames
  256. This function is torchscriptable.
  257. Args:
  258. video_data (data type could be 1) torch.Tensor, dtype=torch.int8 or 2) python bytes):
  259. compressed video content stored in either 1) torch.Tensor 2) python bytes
  260. seek_frame_margin (double, optional): seeking frame in the stream is imprecise.
  261. Thus, when video_start_pts is specified, we seek the pts earlier by seek_frame_margin seconds
  262. read_video_stream (int, optional): whether read video stream. If yes, set to 1. Otherwise, 0
  263. video_width/video_height/video_min_dimension/video_max_dimension (int): together decide
  264. the size of decoded frames:
  265. - When video_width = 0, video_height = 0, video_min_dimension = 0,
  266. and video_max_dimension = 0, keep the original frame resolution
  267. - When video_width = 0, video_height = 0, video_min_dimension != 0,
  268. and video_max_dimension = 0, keep the aspect ratio and resize the
  269. frame so that shorter edge size is video_min_dimension
  270. - When video_width = 0, video_height = 0, video_min_dimension = 0,
  271. and video_max_dimension != 0, keep the aspect ratio and resize
  272. the frame so that longer edge size is video_max_dimension
  273. - When video_width = 0, video_height = 0, video_min_dimension != 0,
  274. and video_max_dimension != 0, resize the frame so that shorter
  275. edge size is video_min_dimension, and longer edge size is
  276. video_max_dimension. The aspect ratio may not be preserved
  277. - When video_width = 0, video_height != 0, video_min_dimension = 0,
  278. and video_max_dimension = 0, keep the aspect ratio and resize
  279. the frame so that frame video_height is $video_height
  280. - When video_width != 0, video_height == 0, video_min_dimension = 0,
  281. and video_max_dimension = 0, keep the aspect ratio and resize
  282. the frame so that frame video_width is $video_width
  283. - When video_width != 0, video_height != 0, video_min_dimension = 0,
  284. and video_max_dimension = 0, resize the frame so that frame
  285. video_width and video_height are set to $video_width and
  286. $video_height, respectively
  287. video_pts_range (list(int), optional): the start and end presentation timestamp of video stream
  288. video_timebase_numerator / video_timebase_denominator (float, optional): a rational
  289. number which denotes timebase in video stream
  290. read_audio_stream (int, optional): whether read audio stream. If yes, set to 1. Otherwise, 0
  291. audio_samples (int, optional): audio sampling rate
  292. audio_channels (int optional): audio audio_channels
  293. audio_pts_range (list(int), optional): the start and end presentation timestamp of audio stream
  294. audio_timebase_numerator / audio_timebase_denominator (float, optional):
  295. a rational number which denotes time base in audio stream
  296. Returns:
  297. vframes (Tensor[T, H, W, C]): the `T` video frames
  298. aframes (Tensor[L, K]): the audio frames, where `L` is the number of points and
  299. `K` is the number of channels
  300. """
  301. _raise_video_deprecation_warning()
  302. _validate_pts(video_pts_range)
  303. _validate_pts(audio_pts_range)
  304. if not isinstance(video_data, torch.Tensor):
  305. with warnings.catch_warnings():
  306. # Ignore the warning because we actually don't modify the buffer in this function
  307. warnings.filterwarnings("ignore", message="The given buffer is not writable")
  308. video_data = torch.frombuffer(video_data, dtype=torch.uint8)
  309. result = torch.ops.video_reader.read_video_from_memory(
  310. video_data,
  311. seek_frame_margin,
  312. 0, # getPtsOnly
  313. read_video_stream,
  314. video_width,
  315. video_height,
  316. video_min_dimension,
  317. video_max_dimension,
  318. video_pts_range[0],
  319. video_pts_range[1],
  320. video_timebase_numerator,
  321. video_timebase_denominator,
  322. read_audio_stream,
  323. audio_samples,
  324. audio_channels,
  325. audio_pts_range[0],
  326. audio_pts_range[1],
  327. audio_timebase_numerator,
  328. audio_timebase_denominator,
  329. )
  330. vframes, _vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, atimebase, asample_rate, aduration = result
  331. if aframes.numel() > 0:
  332. # when audio stream is found
  333. aframes = _align_audio_frames(aframes, aframe_pts, audio_pts_range)
  334. return vframes, aframes
  335. def _read_video_timestamps_from_memory(
  336. video_data: torch.Tensor,
  337. ) -> tuple[list[int], list[int], VideoMetaData]:
  338. """
  339. Decode all frames in the video. Only pts (presentation timestamp) is returned.
  340. The actual frame pixel data is not copied. Thus, read_video_timestamps(...)
  341. is much faster than read_video(...)
  342. """
  343. if not isinstance(video_data, torch.Tensor):
  344. with warnings.catch_warnings():
  345. # Ignore the warning because we actually don't modify the buffer in this function
  346. warnings.filterwarnings("ignore", message="The given buffer is not writable")
  347. video_data = torch.frombuffer(video_data, dtype=torch.uint8)
  348. result = torch.ops.video_reader.read_video_from_memory(
  349. video_data,
  350. 0, # seek_frame_margin
  351. 1, # getPtsOnly
  352. 1, # read_video_stream
  353. 0, # video_width
  354. 0, # video_height
  355. 0, # video_min_dimension
  356. 0, # video_max_dimension
  357. 0, # video_start_pts
  358. -1, # video_end_pts
  359. 0, # video_timebase_num
  360. 1, # video_timebase_den
  361. 1, # read_audio_stream
  362. 0, # audio_samples
  363. 0, # audio_channels
  364. 0, # audio_start_pts
  365. -1, # audio_end_pts
  366. 0, # audio_timebase_num
  367. 1, # audio_timebase_den
  368. )
  369. _raise_video_deprecation_warning()
  370. _vframes, vframe_pts, vtimebase, vfps, vduration, _aframes, aframe_pts, atimebase, asample_rate, aduration = result
  371. info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
  372. vframe_pts = vframe_pts.numpy().tolist()
  373. aframe_pts = aframe_pts.numpy().tolist()
  374. return vframe_pts, aframe_pts, info
  375. def _probe_video_from_memory(
  376. video_data: torch.Tensor,
  377. ) -> VideoMetaData:
  378. """
  379. Probe a video in memory and return VideoMetaData with info about the video
  380. This function is torchscriptable
  381. """
  382. _raise_video_deprecation_warning()
  383. if not isinstance(video_data, torch.Tensor):
  384. with warnings.catch_warnings():
  385. # Ignore the warning because we actually don't modify the buffer in this function
  386. warnings.filterwarnings("ignore", message="The given buffer is not writable")
  387. video_data = torch.frombuffer(video_data, dtype=torch.uint8)
  388. result = torch.ops.video_reader.probe_video_from_memory(video_data)
  389. vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
  390. info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
  391. return info
  392. def _read_video(
  393. filename: str,
  394. start_pts: Union[float, Fraction] = 0,
  395. end_pts: Optional[Union[float, Fraction]] = None,
  396. pts_unit: str = "pts",
  397. ) -> tuple[torch.Tensor, torch.Tensor, dict[str, float]]:
  398. _raise_video_deprecation_warning()
  399. if end_pts is None:
  400. end_pts = float("inf")
  401. if pts_unit == "pts":
  402. warnings.warn(
  403. "The pts_unit 'pts' gives wrong results and will be removed in a "
  404. + "follow-up version. Please use pts_unit 'sec'."
  405. )
  406. info = _probe_video_from_file(filename)
  407. has_video = info.has_video
  408. has_audio = info.has_audio
  409. def get_pts(time_base):
  410. start_offset = start_pts
  411. end_offset = end_pts
  412. if pts_unit == "sec":
  413. start_offset = int(math.floor(start_pts * (1 / time_base)))
  414. if end_offset != float("inf"):
  415. end_offset = int(math.ceil(end_pts * (1 / time_base)))
  416. if end_offset == float("inf"):
  417. end_offset = -1
  418. return start_offset, end_offset
  419. video_pts_range = (0, -1)
  420. video_timebase = default_timebase
  421. if has_video:
  422. video_timebase = Fraction(info.video_timebase.numerator, info.video_timebase.denominator)
  423. video_pts_range = get_pts(video_timebase)
  424. audio_pts_range = (0, -1)
  425. audio_timebase = default_timebase
  426. if has_audio:
  427. audio_timebase = Fraction(info.audio_timebase.numerator, info.audio_timebase.denominator)
  428. audio_pts_range = get_pts(audio_timebase)
  429. vframes, aframes, info = _read_video_from_file(
  430. filename,
  431. read_video_stream=True,
  432. video_pts_range=video_pts_range,
  433. video_timebase=video_timebase,
  434. read_audio_stream=True,
  435. audio_pts_range=audio_pts_range,
  436. audio_timebase=audio_timebase,
  437. )
  438. _info = {}
  439. if has_video:
  440. _info["video_fps"] = info.video_fps
  441. if has_audio:
  442. _info["audio_fps"] = info.audio_sample_rate
  443. return vframes, aframes, _info
  444. def _read_video_timestamps(
  445. filename: str, pts_unit: str = "pts"
  446. ) -> tuple[Union[list[int], list[Fraction]], Optional[float]]:
  447. _raise_video_deprecation_warning()
  448. if pts_unit == "pts":
  449. warnings.warn(
  450. "The pts_unit 'pts' gives wrong results and will be removed in a "
  451. + "follow-up version. Please use pts_unit 'sec'."
  452. )
  453. pts: Union[list[int], list[Fraction]]
  454. pts, _, info = _read_video_timestamps_from_file(filename)
  455. if pts_unit == "sec":
  456. video_time_base = Fraction(info.video_timebase.numerator, info.video_timebase.denominator)
  457. pts = [x * video_time_base for x in pts]
  458. video_fps = info.video_fps if info.has_video else None
  459. return pts, video_fps