# videocomposer_pipeline.py (14 KB)
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import random
  4. import subprocess
  5. import tempfile
  6. import time
  7. from functools import partial
  8. from typing import Any, Dict
  9. import cv2
  10. import imageio
  11. import numpy as np
  12. import torch
  13. import torchvision.transforms as T
  14. from mvextractor.videocap import VideoCap
  15. from PIL import Image
  16. import modelscope.models.multi_modal.videocomposer.data as data
  17. from modelscope.metainfo import Pipelines
  18. from modelscope.models.multi_modal.videocomposer.data.transforms import (
  19. CenterCropV3, random_resize)
  20. from modelscope.models.multi_modal.videocomposer.ops.random_mask import (
  21. make_irregular_mask, make_rectangle_mask, make_uncrop)
  22. from modelscope.models.multi_modal.videocomposer.utils.utils import rand_name
  23. from modelscope.outputs import OutputKeys
  24. from modelscope.pipelines.base import Input, Pipeline
  25. from modelscope.pipelines.builder import PIPELINES
  26. from modelscope.utils.constant import Tasks
  27. from modelscope.utils.device import device_placement
  28. from modelscope.utils.logger import get_logger
  29. logger = get_logger()
  30. @PIPELINES.register_module(
  31. Tasks.text_to_video_synthesis, module_name=Pipelines.videocomposer)
  32. class VideoComposerPipeline(Pipeline):
  33. r""" Video Composer Pipeline.
  34. Examples:
  35. >>> from modelscope.pipelines import pipeline
  36. >>> from modelscope.utils.constant import Tasks
  37. >>> pipe = pipeline(
  38. task=Tasks.text_to_video_synthesis,
  39. model='buptwq/videocomposer',
  40. model_revision='v1.0.1')
  41. >>> inputs = {'Video:FILE': 'path/input_video.mp4',
  42. 'Image:FILE': 'path/input_image.png',
  43. 'text': 'the text description'}
  44. >>> output = pipe(inputs)
  45. """
  46. def __init__(self, model: str, **kwargs):
  47. """
  48. use `model` to create a videocomposer pipeline for prediction
  49. Args:
  50. model: model id on modelscope hub.
  51. """
  52. super().__init__(model=model)
  53. self.log_dir = kwargs.pop('log_dir', './video_outputs')
  54. if not os.path.exists(self.log_dir):
  55. os.makedirs(self.log_dir)
  56. self.feature_framerate = kwargs.pop('feature_framerate', 4)
  57. self.frame_lens = kwargs.pop('frame_lens', [
  58. 16,
  59. 16,
  60. 16,
  61. 16,
  62. ])
  63. self.feature_framerates = kwargs.pop('feature_framerates', [
  64. 4,
  65. ])
  66. self.batch_sizes = kwargs.pop('batch_sizes', {
  67. '1': 1,
  68. '4': 1,
  69. '8': 1,
  70. '16': 1,
  71. })
  72. l1 = len(self.frame_lens)
  73. l2 = len(self.feature_framerates)
  74. self.max_frames = self.frame_lens[0 % (l1 * l2) // l2]
  75. self.batch_size = self.batch_sizes[str(self.max_frames)]
  76. self.resolution = kwargs.pop('resolution', 256)
  77. self.image_resolution = kwargs.pop('image_resolution', 256)
  78. self.mean = kwargs.pop('mean', [0.5, 0.5, 0.5])
  79. self.std = kwargs.pop('std', [0.5, 0.5, 0.5])
  80. self.vit_image_size = kwargs.pop('vit_image_size', 224)
  81. self.vit_mean = kwargs.pop('vit_mean',
  82. [0.48145466, 0.4578275, 0.40821073])
  83. self.vit_std = kwargs.pop('vit_std',
  84. [0.26862954, 0.26130258, 0.27577711])
  85. self.misc_size = kwargs.pop('kwargs.pop', 384)
  86. self.visual_mv = kwargs.pop('visual_mv', False)
  87. self.max_words = kwargs.pop('max_words', 1000)
  88. self.mvs_visual = kwargs.pop('mvs_visual', False)
  89. self.infer_trans = data.Compose([
  90. data.CenterCropV2(size=self.resolution),
  91. data.ToTensor(),
  92. data.Normalize(mean=self.mean, std=self.std)
  93. ])
  94. self.misc_transforms = data.Compose([
  95. T.Lambda(partial(random_resize, size=self.misc_size)),
  96. data.CenterCropV2(self.misc_size),
  97. data.ToTensor()
  98. ])
  99. self.mv_transforms = data.Compose(
  100. [T.Resize(size=self.resolution),
  101. T.CenterCrop(self.resolution)])
  102. self.vit_transforms = T.Compose([
  103. CenterCropV3(self.vit_image_size),
  104. T.ToTensor(),
  105. T.Normalize(mean=self.vit_mean, std=self.vit_std)
  106. ])
  107. def preprocess(self, input: Input) -> Dict[str, Any]:
  108. video_key = input['Video:FILE']
  109. cap_txt = input['text']
  110. style_image = input['Image:FILE']
  111. total_frames = None
  112. feature_framerate = self.feature_framerate
  113. if os.path.exists(video_key):
  114. try:
  115. ref_frame, vit_image, video_data, misc_data, mv_data = self.video_data_preprocess(
  116. video_key, self.feature_framerate, total_frames,
  117. self.mvs_visual)
  118. except Exception as e:
  119. logger.info(
  120. '{} get frames failed... with error: {}'.format(
  121. video_key, e),
  122. flush=True)
  123. ref_frame = torch.zeros(3, self.vit_image_size,
  124. self.vit_image_size)
  125. video_data = torch.zeros(self.max_frames, 3,
  126. self.image_resolution,
  127. self.image_resolution)
  128. misc_data = torch.zeros(self.max_frames, 3, self.misc_size,
  129. self.misc_size)
  130. mv_data = torch.zeros(self.max_frames, 2,
  131. self.image_resolution,
  132. self.image_resolution)
  133. else:
  134. logger.info(
  135. 'The video path does not exist or no video dir provided!')
  136. ref_frame = torch.zeros(3, self.vit_image_size,
  137. self.vit_image_size)
  138. _ = torch.zeros(3, self.vit_image_size, self.vit_image_size)
  139. video_data = torch.zeros(self.max_frames, 3, self.image_resolution,
  140. self.image_resolution)
  141. misc_data = torch.zeros(self.max_frames, 3, self.misc_size,
  142. self.misc_size)
  143. mv_data = torch.zeros(self.max_frames, 2, self.image_resolution,
  144. self.image_resolution)
  145. # inpainting mask
  146. p = random.random()
  147. if p < 0.7:
  148. mask = make_irregular_mask(512, 512)
  149. elif p < 0.9:
  150. mask = make_rectangle_mask(512, 512)
  151. else:
  152. mask = make_uncrop(512, 512)
  153. mask = torch.from_numpy(
  154. cv2.resize(
  155. mask, (self.misc_size, self.misc_size),
  156. interpolation=cv2.INTER_NEAREST)).unsqueeze(0).float()
  157. mask = mask.unsqueeze(0).repeat_interleave(
  158. repeats=self.max_frames, dim=0)
  159. video_input = {
  160. 'ref_frame': ref_frame.unsqueeze(0),
  161. 'cap_txt': cap_txt,
  162. 'video_data': video_data.unsqueeze(0),
  163. 'misc_data': misc_data.unsqueeze(0),
  164. 'feature_framerate': feature_framerate,
  165. 'mask': mask.unsqueeze(0),
  166. 'mv_data': mv_data.unsqueeze(0),
  167. 'style_image': style_image
  168. }
  169. return video_input
  170. def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
  171. return self.model(input)
  172. def postprocess(self, inputs: Dict[str, Any],
  173. **post_params) -> Dict[str, Any]:
  174. output_video_path = post_params.get('output_video', None)
  175. temp_video_file = False
  176. if output_video_path is not None:
  177. output_video_path = tempfile.NamedTemporaryFile(suffix='.gif').name
  178. temp_video_file = True
  179. if temp_video_file:
  180. return {OutputKeys.OUTPUT_VIDEO: inputs['video_path']}
  181. else:
  182. return {OutputKeys.OUTPUT_VIDEO: inputs['video']}
  183. def video_data_preprocess(self, video_key, feature_framerate, total_frames,
  184. visual_mv):
  185. filename = video_key
  186. for _ in range(5):
  187. try:
  188. frame_types, frames, mvs, mvs_visual = self.extract_motion_vectors(
  189. input_video=filename,
  190. fps=feature_framerate,
  191. visual_mv=visual_mv)
  192. break
  193. except Exception as e:
  194. logger.error(
  195. '{} read video frames and motion vectors failed with error: {}'
  196. .format(video_key, e),
  197. flush=True)
  198. total_frames = len(frame_types)
  199. start_indexs = np.where((np.array(frame_types) == 'I') & (
  200. total_frames - np.arange(total_frames) >= self.max_frames))[0]
  201. start_index = np.random.choice(start_indexs)
  202. indices = np.arange(start_index, start_index + self.max_frames)
  203. # note frames are in BGR mode, need to trans to RGB mode
  204. frames = [Image.fromarray(frames[i][:, :, ::-1]) for i in indices]
  205. mvs = [torch.from_numpy(mvs[i].transpose((2, 0, 1))) for i in indices]
  206. mvs = torch.stack(mvs)
  207. if visual_mv:
  208. images = [(mvs_visual[i][:, :, ::-1]).astype('uint8')
  209. for i in indices]
  210. path = self.log_dir + '/visual_mv/' + video_key.split(
  211. '/')[-1] + '.gif'
  212. if not os.path.exists(self.log_dir + '/visual_mv/'):
  213. os.makedirs(self.log_dir + '/visual_mv/', exist_ok=True)
  214. logger.info('save motion vectors visualization to :', path)
  215. imageio.mimwrite(path, images, fps=8)
  216. have_frames = len(frames) > 0
  217. middle_indix = int(len(frames) / 2)
  218. if have_frames:
  219. ref_frame = frames[middle_indix]
  220. vit_image = self.vit_transforms(ref_frame)
  221. misc_imgs_np = self.misc_transforms[:2](frames)
  222. misc_imgs = self.misc_transforms[2:](misc_imgs_np)
  223. frames = self.infer_trans(frames)
  224. mvs = self.mv_transforms(mvs)
  225. else:
  226. vit_image = torch.zeros(3, self.vit_image_size,
  227. self.vit_image_size)
  228. video_data = torch.zeros(self.max_frames, 3, self.image_resolution,
  229. self.image_resolution)
  230. mv_data = torch.zeros(self.max_frames, 2, self.image_resolution,
  231. self.image_resolution)
  232. misc_data = torch.zeros(self.max_frames, 3, self.misc_size,
  233. self.misc_size)
  234. if have_frames:
  235. video_data[:len(frames), ...] = frames
  236. misc_data[:len(frames), ...] = misc_imgs
  237. mv_data[:len(frames), ...] = mvs
  238. ref_frame = vit_image
  239. del frames
  240. del misc_imgs
  241. del mvs
  242. return ref_frame, vit_image, video_data, misc_data, mv_data
  243. def extract_motion_vectors(self,
  244. input_video,
  245. fps=4,
  246. dump=False,
  247. verbose=False,
  248. visual_mv=False):
  249. if dump:
  250. now = datetime.now().strftime('%Y-%m-%dT%H:%M:%S')
  251. for child in ['frames', 'motion_vectors']:
  252. os.makedirs(os.path.join(f'out-{now}', child), exist_ok=True)
  253. temp = rand_name()
  254. tmp_video = os.path.join(
  255. input_video.split('/')[0], f'{temp}' + input_video.split('/')[-1])
  256. videocapture = cv2.VideoCapture(input_video)
  257. frames_num = videocapture.get(cv2.CAP_PROP_FRAME_COUNT)
  258. fps_video = videocapture.get(cv2.CAP_PROP_FPS)
  259. # check if enough frames
  260. if frames_num / fps_video * fps > 16:
  261. fps = max(fps, 1)
  262. else:
  263. fps = int(16 / (frames_num / fps_video)) + 1
  264. ffmpeg_cmd = f'ffmpeg -threads 8 -loglevel error -i {input_video} -filter:v \
  265. fps={fps} -c:v mpeg4 -f rawvideo {tmp_video}'
  266. if os.path.exists(tmp_video):
  267. os.remove(tmp_video)
  268. subprocess.run(args=ffmpeg_cmd, shell=True, timeout=120)
  269. cap = VideoCap()
  270. # open the video file
  271. ret = cap.open(tmp_video)
  272. if not ret:
  273. raise RuntimeError(f'Could not open {tmp_video}')
  274. step = 0
  275. times = []
  276. frame_types = []
  277. frames = []
  278. mvs = []
  279. mvs_visual = []
  280. # continuously read and display video frames and motion vectors
  281. while True:
  282. if verbose:
  283. logger.info('Frame: ', step, end=' ')
  284. tstart = time.perf_counter()
  285. # read next video frame and corresponding motion vectors
  286. ret, frame, motion_vectors, frame_type, timestamp = cap.read()
  287. tend = time.perf_counter()
  288. telapsed = tend - tstart
  289. times.append(telapsed)
  290. # if there is an error reading the frame
  291. if not ret:
  292. if verbose:
  293. logger.warning('No frame read. Stopping.')
  294. break
  295. frame_save = np.zeros(frame.copy().shape, dtype=np.uint8)
  296. if visual_mv:
  297. frame_save = draw_motion_vectors(frame_save, motion_vectors)
  298. # store motion vectors, frames, etc. in output directory
  299. dump = False
  300. if frame.shape[1] >= frame.shape[0]:
  301. w_half = (frame.shape[1] - frame.shape[0]) // 2
  302. if dump:
  303. cv2.imwrite(
  304. os.path.join('./mv_visual/', f'frame-{step}.jpg'),
  305. frame_save[:, w_half:-w_half])
  306. mvs_visual.append(frame_save[:, w_half:-w_half])
  307. else:
  308. h_half = (frame.shape[0] - frame.shape[1]) // 2
  309. if dump:
  310. cv2.imwrite(
  311. os.path.join('./mv_visual/', f'frame-{step}.jpg'),
  312. frame_save[h_half:-h_half, :])
  313. mvs_visual.append(frame_save[h_half:-h_half, :])
  314. h, w = frame.shape[:2]
  315. mv = np.zeros((h, w, 2))
  316. position = motion_vectors[:, 5:7].clip((0, 0), (w - 1, h - 1))
  317. mv[position[:, 1],
  318. position[:,
  319. 0]] = motion_vectors[:, 0:
  320. 1] * motion_vectors[:, 7:
  321. 9] / motion_vectors[:,
  322. 9:]
  323. step += 1
  324. frame_types.append(frame_type)
  325. frames.append(frame)
  326. mvs.append(mv)
  327. if verbose:
  328. logger.info('average dt: ', np.mean(times))
  329. cap.release()
  330. if os.path.exists(tmp_video):
  331. os.remove(tmp_video)
  332. return frame_types, frames, mvs, mvs_visual