modular_instructblipvideo.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613
  1. # coding=utf-8
  2. # Copyright 2024 HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from typing import Optional, Union
  16. import torch
  17. from transformers.models.instructblip.configuration_instructblip import (
  18. InstructBlipQFormerConfig,
  19. InstructBlipVisionConfig,
  20. )
  21. from transformers.models.instructblip.modeling_instructblip import (
  22. InstructBlipForConditionalGeneration,
  23. InstructBlipForConditionalGenerationModelOutput,
  24. InstructBlipModel,
  25. InstructBlipPreTrainedModel,
  26. InstructBlipQFormerModel,
  27. InstructBlipVisionModel,
  28. TransformersKwargs,
  29. )
  30. from ...configuration_utils import PretrainedConfig
  31. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  32. from ...models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
  33. from ...processing_utils import Unpack
  34. from ...utils import logging
  35. from ..auto import CONFIG_MAPPING, AutoConfig
# Module-level logger; `InstructBlipVideoConfig.__init__` uses it to report sub-configs that fall back to defaults.
logger = logging.get_logger(__name__)
class InstructBlipVideoVisionConfig(InstructBlipVisionConfig):
    # Empty subclass: all fields and behavior come unchanged from `InstructBlipVisionConfig`.
    # It is the class `InstructBlipVideoConfig` instantiates for its `vision_config` sub-config.
    pass
class InstructBlipVideoQFormerConfig(InstructBlipQFormerConfig):
    # Empty subclass: all fields and behavior come unchanged from `InstructBlipQFormerConfig`.
    # It is the class `InstructBlipVideoConfig` instantiates for its `qformer_config` sub-config.
    pass
  41. class InstructBlipVideoConfig(PretrainedConfig):
  42. r"""
  43. [`InstructBlipVideoConfig`] is the configuration class to store the configuration of a
  44. [`InstructBlipVideoForConditionalGeneration`]. It is used to instantiate a Instructblipvideo model according to the specified
  45. arguments, defining the vision model, Q-Former model and language model configs. Instantiating a configuration with
  46. the defaults will yield a similar configuration to that of the Instructblipvideo
  47. [Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5) architecture.
  48. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  49. documentation from [`PretrainedConfig`] for more information.
  50. Args:
  51. vision_config (`dict`, *optional*):
  52. Dictionary of configuration options used to initialize [`InstructBlipVideoVisionConfig`].
  53. qformer_config (`dict`, *optional*):
  54. Dictionary of configuration options used to initialize [`InstructBlipVideoQFormerConfig`].
  55. text_config (`dict`, *optional*):
  56. Dictionary of configuration options used to initialize any [`PretrainedConfig`].
  57. num_query_tokens (`int`, *optional*, defaults to 32):
  58. The number of query tokens passed through the Transformer.
  59. video_token_index (`int`, *optional*):
  60. Token index of special video token.
  61. kwargs (*optional*):
  62. Dictionary of keyword arguments.
  63. Example:
  64. ```python
  65. >>> from transformers import (
  66. ... InstructBlipVideoVisionConfig,
  67. ... InstructBlipVideoQFormerConfig,
  68. ... OPTConfig,
  69. ... InstructBlipVideoConfig,
  70. ... InstructBlipVideoForConditionalGeneration,
  71. ... )
  72. >>> # Initializing a InstructBlipVideoConfig with Salesforce/instruct-blip-flan-t5 style configuration
  73. >>> configuration = InstructBlipVideoConfig()
  74. >>> # Initializing a InstructBlipVideoForConditionalGeneration (with random weights) from the Salesforce/instruct-blip-flan-t5 style configuration
  75. >>> model = InstructBlipVideoForConditionalGeneration(configuration)
  76. >>> # Accessing the model configuration
  77. >>> configuration = model.config
  78. >>> # We can also initialize a InstructBlipVideoConfig from a InstructBlipVideoVisionConfig, InstructBlipVideoQFormerConfig and any PretrainedConfig
  79. >>> # Initializing Instructblipvideo vision, Instructblipvideo Q-Former and language model configurations
  80. >>> vision_config = InstructBlipVideoVisionConfig()
  81. >>> qformer_config = InstructBlipVideoQFormerConfig()
  82. >>> text_config = OPTConfig()
  83. >>> config = InstructBlipVideoConfig.from_text_vision_configs(vision_config, qformer_config, text_config)
  84. ```"""
  85. model_type = "instructblipvideo"
  86. attribute_map = {
  87. "video_token_id": "video_token_index",
  88. }
  89. sub_configs = {
  90. "text_config": AutoConfig,
  91. "qformer_config": InstructBlipVideoQFormerConfig,
  92. "vision_config": InstructBlipVideoVisionConfig,
  93. }
  94. def __init__(
  95. self,
  96. vision_config=None,
  97. qformer_config=None,
  98. text_config=None,
  99. num_query_tokens=32,
  100. video_token_index=None,
  101. **kwargs,
  102. ):
  103. super().__init__(**kwargs)
  104. if vision_config is None:
  105. vision_config = {}
  106. logger.info("vision_config is None. initializing the InstructBlipVideoVisionConfig with default values.")
  107. if qformer_config is None:
  108. qformer_config = {}
  109. logger.info("qformer_config is None. Initializing the InstructBlipVideoQFormerConfig with default values.")
  110. if text_config is None:
  111. text_config = {}
  112. logger.info("text_config is None. Initializing the text config with default values (`OPTConfig`).")
  113. self.vision_config = InstructBlipVideoVisionConfig(**vision_config)
  114. self.qformer_config = InstructBlipVideoQFormerConfig(**qformer_config)
  115. text_model_type = text_config.get("model_type", "opt")
  116. self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
  117. self.num_query_tokens = num_query_tokens
  118. self.video_token_index = video_token_index
  119. self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
  120. self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
  121. self.initializer_factor = 1.0
  122. self.initializer_range = 0.02
  123. @classmethod
  124. def from_vision_qformer_text_configs(
  125. cls,
  126. vision_config: InstructBlipVideoVisionConfig,
  127. qformer_config: InstructBlipVideoQFormerConfig,
  128. text_config: PretrainedConfig,
  129. **kwargs,
  130. ):
  131. r"""
  132. Instantiate a [`InstructBlipVideoConfig`] (or a derived class) from a InstructBlipVideo vision model, Q-Former and
  133. language model configurations.
  134. Returns:
  135. [`InstructBlipVideoConfig`]: An instance of a configuration object
  136. """
  137. return cls(
  138. vision_config=vision_config.to_dict(),
  139. qformer_config=qformer_config.to_dict(),
  140. text_config=text_config.to_dict(),
  141. **kwargs,
  142. )
class InstructBlipVideoPreTrainedModel(InstructBlipPreTrainedModel):
    # Empty subclass: inherits everything from `InstructBlipPreTrainedModel` unchanged.
    pass
class InstructBlipVideoVisionModel(InstructBlipVisionModel):
    # Empty subclass: inherits everything from `InstructBlipVisionModel` unchanged.
    pass
class InstructBlipVideoQFormerModel(InstructBlipQFormerModel):
    # Empty subclass: inherits everything from `InstructBlipQFormerModel` unchanged.
    pass
class InstructBlipVideoForConditionalGenerationModelOutput(InstructBlipForConditionalGenerationModelOutput):
    # Empty subclass: same fields as `InstructBlipForConditionalGenerationModelOutput`. The forward methods
    # in this file fill `vision_outputs`, `qformer_outputs`, `language_model_outputs` (and `loss`/`logits`).
    pass
  151. class InstructBlipVideoModel(InstructBlipModel):
  152. def forward(
  153. self,
  154. pixel_values: torch.FloatTensor,
  155. qformer_input_ids: torch.FloatTensor,
  156. qformer_attention_mask: Optional[torch.LongTensor] = None,
  157. input_ids: Optional[torch.FloatTensor] = None,
  158. attention_mask: Optional[torch.LongTensor] = None,
  159. decoder_input_ids: Optional[torch.LongTensor] = None,
  160. decoder_attention_mask: Optional[torch.LongTensor] = None,
  161. inputs_embeds: Optional[torch.Tensor] = None,
  162. output_attentions: Optional[bool] = None,
  163. output_hidden_states: Optional[bool] = None,
  164. return_dict: Optional[bool] = None,
  165. interpolate_pos_encoding: bool = False,
  166. use_cache: Optional[bool] = None,
  167. **kwargs: Unpack[FlashAttentionKwargs],
  168. ) -> Union[tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
  169. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  170. # step 1: forward the images through the vision encoder,
  171. # we process in a batched way, later unbatch it back (video has frames=4 always)
  172. batch_size, frames, channel, height, width = pixel_values.shape
  173. pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
  174. vision_outputs = self.vision_model(
  175. pixel_values=pixel_values,
  176. output_attentions=output_attentions,
  177. output_hidden_states=output_hidden_states,
  178. return_dict=return_dict,
  179. interpolate_pos_encoding=interpolate_pos_encoding,
  180. )
  181. image_embeds = vision_outputs[0]
  182. # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
  183. image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
  184. # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
  185. query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
  186. query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
  187. if qformer_attention_mask is None:
  188. qformer_attention_mask = torch.ones_like(qformer_input_ids)
  189. qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
  190. qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
  191. qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
  192. query_outputs = self.qformer(
  193. input_ids=qformer_input_ids,
  194. attention_mask=qformer_attention_mask,
  195. query_embeds=query_tokens,
  196. encoder_hidden_states=image_embeds,
  197. encoder_attention_mask=image_attention_mask,
  198. output_attentions=output_attentions,
  199. output_hidden_states=output_hidden_states,
  200. return_dict=return_dict,
  201. )
  202. query_output = query_outputs[0][:, : query_tokens.size(1), :]
  203. # step 3: use the language model, conditioned on the query outputs and the prompt
  204. language_model_inputs = self.language_projection(query_output)
  205. # unbatch inputs back, each video-frame gets `num_query_tokens` seq length
  206. language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
  207. if inputs_embeds is None:
  208. inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
  209. special_image_mask = input_ids == self.config.video_token_id
  210. if attention_mask is None:
  211. attention_mask = torch.ones_like(input_ids)
  212. else:
  213. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  214. torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
  215. )
  216. special_image_mask = special_image_mask.all(-1)
  217. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  218. language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
  219. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
  220. if self.config.use_decoder_only_language_model:
  221. outputs = self.language_model(
  222. inputs_embeds=inputs_embeds,
  223. attention_mask=attention_mask,
  224. output_attentions=output_attentions,
  225. output_hidden_states=output_hidden_states,
  226. return_dict=return_dict,
  227. use_cache=use_cache,
  228. **kwargs,
  229. )
  230. else:
  231. outputs = self.language_model(
  232. inputs_embeds=inputs_embeds,
  233. attention_mask=attention_mask,
  234. decoder_input_ids=decoder_input_ids,
  235. decoder_attention_mask=decoder_attention_mask,
  236. output_attentions=output_attentions,
  237. output_hidden_states=output_hidden_states,
  238. return_dict=return_dict,
  239. use_cache=use_cache,
  240. **kwargs,
  241. )
  242. return InstructBlipVideoForConditionalGenerationModelOutput(
  243. vision_outputs=vision_outputs,
  244. qformer_outputs=query_outputs,
  245. language_model_outputs=outputs,
  246. )
  247. class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration):
  248. def get_video_features(
  249. self,
  250. pixel_values: torch.FloatTensor,
  251. qformer_input_ids: torch.LongTensor,
  252. qformer_attention_mask: Optional[torch.LongTensor] = None,
  253. interpolate_pos_encoding: Optional[bool] = False,
  254. return_dict: Optional[bool] = False,
  255. ):
  256. """
  257. Encodes images into continuous embeddings that can be forwarded to the language model.
  258. Args:
  259. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  260. The tensors corresponding to the input images.
  261. """
  262. # step 1: forward the images through the vision encoder,
  263. # we process in a batched way, later unbatch it back (video has frames=4 always)
  264. batch_size, frames, channel, height, width = pixel_values.shape
  265. pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
  266. vision_outputs = self.vision_model(
  267. pixel_values=pixel_values,
  268. interpolate_pos_encoding=interpolate_pos_encoding,
  269. return_dict=True,
  270. )
  271. image_embeds = vision_outputs[0]
  272. # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
  273. image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
  274. # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
  275. query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
  276. query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
  277. if qformer_attention_mask is None:
  278. qformer_attention_mask = torch.ones_like(qformer_input_ids)
  279. qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
  280. qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
  281. qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
  282. query_outputs = self.qformer(
  283. input_ids=qformer_input_ids,
  284. attention_mask=qformer_attention_mask,
  285. query_embeds=query_tokens,
  286. encoder_hidden_states=image_embeds,
  287. encoder_attention_mask=image_attention_mask,
  288. return_dict=True,
  289. )
  290. query_output = query_outputs[0][:, : query_tokens.size(1), :]
  291. # step 3: use the language model, conditioned on the query outputs and the prompt
  292. language_model_inputs = self.language_projection(query_output)
  293. # unbatch inputs back, each video-frame gets `num_query_tokens` seq length
  294. language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
  295. if return_dict:
  296. return language_model_inputs, vision_outputs, query_outputs
  297. return language_model_inputs
  298. # Model supports only videos
  299. def get_image_features(
  300. self,
  301. pixel_values: torch.FloatTensor,
  302. qformer_input_ids: torch.LongTensor,
  303. qformer_attention_mask: Optional[torch.LongTensor] = None,
  304. interpolate_pos_encoding: Optional[bool] = False,
  305. return_dict: Optional[bool] = False,
  306. ):
  307. pass
  308. def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor):
  309. """
  310. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`.
  311. """
  312. if input_ids is None:
  313. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  314. torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
  315. )
  316. special_image_mask = special_image_mask.all(-1)
  317. else:
  318. special_image_mask = input_ids == self.config.video_token_id
  319. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  320. return special_image_mask
  321. def forward(
  322. self,
  323. pixel_values: torch.FloatTensor,
  324. qformer_input_ids: torch.FloatTensor,
  325. qformer_attention_mask: Optional[torch.LongTensor] = None,
  326. input_ids: Optional[torch.FloatTensor] = None,
  327. attention_mask: Optional[torch.LongTensor] = None,
  328. decoder_input_ids: Optional[torch.LongTensor] = None,
  329. decoder_attention_mask: Optional[torch.LongTensor] = None,
  330. inputs_embeds: Optional[torch.FloatTensor] = None,
  331. output_attentions: Optional[bool] = None,
  332. output_hidden_states: Optional[bool] = None,
  333. labels: Optional[torch.LongTensor] = None,
  334. return_dict: Optional[bool] = None,
  335. interpolate_pos_encoding: bool = False,
  336. use_cache: Optional[bool] = None,
  337. **kwargs: Unpack[TransformersKwargs],
  338. ) -> Union[tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
  339. r"""
  340. qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length)):
  341. The sequence used as a prompt to be fed to the Q-Former module.
  342. qformer_attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
  343. Mask to avoid performing attention on padding token indices.
  344. Examples:
  345. ```python
  346. >>> from transformers import InstructBlipVideoProcessor, InstructBlipVideoForConditionalGeneration
  347. >>> import torch
  348. >>> from huggingface_hub import hf_hub_download
  349. >>> import av
  350. >>> import numpy as np
  351. >>> def read_video_pyav(container, indices):
  352. ... '''
  353. ... Decode the video with PyAV decoder.
  354. ... Args:
  355. ... container (`av.container.input.InputContainer`): PyAV container.
  356. ... indices (`list[int]`): List of frame indices to decode.
  357. ... Returns:
  358. ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
  359. ... '''
  360. ... frames = []
  361. ... container.seek(0)
  362. ... start_index = indices[0]
  363. ... end_index = indices[-1]
  364. ... for i, frame in enumerate(container.decode(video=0)):
  365. ... if i > end_index:
  366. ... break
  367. ... if i >= start_index and i in indices:
  368. ... frames.append(frame)
  369. ... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
  370. >>> model = InstructBlipVideoForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b", device_map="auto")
  371. >>> processor = InstructBlipVideoProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
  372. >>> file_path = hf_hub_download(
  373. ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
  374. ... )
  375. >>> container = av.open(file_path)
  376. >>> # sample uniformly 4 frames from the videWhy is this video funny?o
  377. >>> total_frames = container.streams.video[0].frames
  378. >>> indices = np.arange(0, total_frames, total_frames / 4).astype(int)
  379. >>> clip = read_video_pyav(container, indices)
  380. >>> prompt = "What is happening in the video?"
  381. >>> inputs = processor(text=prompt, images=clip, return_tensors="pt").to(model.device)
  382. >>> outputs = model.generate(
  383. ... **inputs,
  384. ... do_sample=False,
  385. ... num_beams=5,
  386. ... max_length=256,
  387. ... repetition_penalty=1.5,
  388. ... length_penalty=1.0,
  389. ... )
  390. >>> generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
  391. >>> print(generated_text)
  392. "A person is eating a bowl of pasta, and they are using a fork to eat it. The person is sitting at a table, and the plate of pasta is on the table in front"
  393. ```"""
  394. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  395. language_model_inputs, vision_outputs, query_outputs = self.get_video_features(
  396. pixel_values,
  397. qformer_input_ids=qformer_input_ids,
  398. qformer_attention_mask=qformer_attention_mask,
  399. interpolate_pos_encoding=interpolate_pos_encoding,
  400. return_dict=True,
  401. )
  402. vision_outputs = vision_outputs.to_tuple() if not return_dict else vision_outputs
  403. query_outputs = query_outputs.to_tuple() if not return_dict else query_outputs
  404. if inputs_embeds is None:
  405. inputs_embeds = self.get_input_embeddings()(input_ids)
  406. if attention_mask is None:
  407. attention_mask = torch.ones_like(input_ids)
  408. language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
  409. special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
  410. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
  411. if self.config.use_decoder_only_language_model:
  412. outputs = self.language_model(
  413. inputs_embeds=inputs_embeds,
  414. attention_mask=attention_mask,
  415. output_attentions=output_attentions,
  416. output_hidden_states=output_hidden_states,
  417. return_dict=return_dict,
  418. use_cache=use_cache,
  419. **kwargs,
  420. )
  421. logits = outputs.logits if return_dict else outputs[0]
  422. loss = None
  423. if labels is not None:
  424. loss = self.loss_function(
  425. logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
  426. )
  427. else:
  428. outputs = self.language_model(
  429. inputs_embeds=inputs_embeds,
  430. attention_mask=attention_mask,
  431. decoder_input_ids=decoder_input_ids,
  432. decoder_attention_mask=decoder_attention_mask,
  433. output_attentions=output_attentions,
  434. output_hidden_states=output_hidden_states,
  435. return_dict=return_dict,
  436. labels=labels,
  437. use_cache=use_cache,
  438. **kwargs,
  439. )
  440. loss = outputs.loss if return_dict else outputs[0]
  441. logits = outputs.logits if return_dict else outputs[1]
  442. return InstructBlipVideoForConditionalGenerationModelOutput(
  443. loss=loss,
  444. logits=logits,
  445. vision_outputs=vision_outputs,
  446. qformer_outputs=query_outputs,
  447. language_model_outputs=outputs,
  448. )
  449. @torch.no_grad()
  450. def generate(
  451. self,
  452. pixel_values: torch.FloatTensor,
  453. qformer_input_ids: Optional[torch.LongTensor] = None,
  454. qformer_attention_mask: Optional[torch.LongTensor] = None,
  455. input_ids: Optional[torch.LongTensor] = None,
  456. attention_mask: Optional[torch.LongTensor] = None,
  457. inputs_embeds: Optional[torch.FloatTensor] = None,
  458. interpolate_pos_encoding: bool = False,
  459. **generate_kwargs,
  460. ) -> torch.LongTensor:
  461. r"""
  462. Overrides `generate` function to be able to use the model as a conditional generator.
  463. Args:
  464. pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width) or
  465. (batch_size, num_frames, num_channels, height, width)): Input images or videos to be processed.
  466. qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
  467. The sequence used as a prompt to be fed to the Q-Former module.
  468. qformer_attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
  469. Mask to avoid performing attention on padding token indices.
  470. input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
  471. The sequence used as a prompt for the generation.
  472. attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
  473. Mask to avoid performing attention on padding token indices.
  474. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  475. Embedded representation of the inputs. Should be float, not int tokens.
  476. interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
  477. Whether to interpolate the positional encoding of the image embeddings.
  478. Returns:
  479. captions (list): A list of strings of length batch_size * num_captions.
  480. """
  481. if hasattr(self, "hf_device_map"):
  482. # preprocess for `accelerate`
  483. self._preprocess_accelerate()
  484. batch_size = pixel_values.shape[0]
  485. language_model_inputs, vision_outputs, query_outputs = self.get_video_features(
  486. pixel_values,
  487. qformer_input_ids=qformer_input_ids,
  488. qformer_attention_mask=qformer_attention_mask,
  489. interpolate_pos_encoding=interpolate_pos_encoding,
  490. return_dict=True,
  491. )
  492. if inputs_embeds is None:
  493. if input_ids is None:
  494. video_tokens = [self.config.video_token_index] * self.config.num_query_tokens * 4
  495. start_tokens = video_tokens + [self.config.text_config.bos_token_id]
  496. input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
  497. input_ids = input_ids.repeat(batch_size, 1)
  498. inputs_embeds = self.get_input_embeddings()(input_ids)
  499. if attention_mask is None:
  500. attention_mask = torch.ones_like(input_ids)
  501. language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
  502. special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
  503. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
  504. inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
  505. if not self.language_model.config.is_encoder_decoder:
  506. inputs["input_ids"] = input_ids
  507. outputs = self.language_model.generate(**inputs, **generate_kwargs)
  508. return outputs
# Public names exported by this module. `InstructBlipVideoForConditionalGenerationModelOutput` is defined in this
# file but deliberately not listed here.
__all__ = [
    "InstructBlipVideoConfig",
    "InstructBlipVideoQFormerConfig",
    "InstructBlipVideoVisionConfig",
    "InstructBlipVideoVisionModel",
    "InstructBlipVideoPreTrainedModel",
    "InstructBlipVideoQFormerModel",
    "InstructBlipVideoModel",
    "InstructBlipVideoForConditionalGeneration",
]