modeling_instructblipvideo.py 69 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/instructblipvideo/modular_instructblipvideo.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_instructblipvideo.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2024 HuggingFace Inc. team. All rights reserved.
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. import math
  22. from dataclasses import dataclass
  23. from typing import Any, Callable, Optional, Union
  24. import torch
  25. from torch import nn
  26. from ...activations import ACT2FN
  27. from ...generation import GenerationMixin
  28. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  29. from ...modeling_layers import GradientCheckpointingLayer
  30. from ...modeling_outputs import (
  31. BaseModelOutput,
  32. BaseModelOutputWithPastAndCrossAttentions,
  33. BaseModelOutputWithPooling,
  34. BaseModelOutputWithPoolingAndCrossAttentions,
  35. )
  36. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  37. from ...processing_utils import Unpack
  38. from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
  39. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int
  40. from ...utils.generic import OutputRecorder, check_model_inputs
  41. from ..auto import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM
  42. from .configuration_instructblipvideo import (
  43. InstructBlipVideoConfig,
  44. InstructBlipVideoQFormerConfig,
  45. InstructBlipVideoVisionConfig,
  46. )
  47. logger = logging.get_logger(__name__)
  48. class InstructBlipVideoVisionEmbeddings(nn.Module):
  49. def __init__(self, config: InstructBlipVideoVisionConfig):
  50. super().__init__()
  51. self.config = config
  52. self.embed_dim = config.hidden_size
  53. self.image_size = config.image_size
  54. self.patch_size = config.patch_size
  55. self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
  56. self.patch_embedding = nn.Conv2d(
  57. in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
  58. )
  59. self.num_patches = (self.image_size // self.patch_size) ** 2
  60. self.num_positions = self.num_patches + 1
  61. self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
  62. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  63. """
  64. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  65. images. This method is also adapted to support torch.jit tracing.
  66. Adapted from:
  67. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  68. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  69. """
  70. num_patches = embeddings.shape[1] - 1
  71. num_positions = self.position_embedding.shape[1] - 1
  72. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  73. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  74. return self.position_embedding
  75. class_pos_embed = self.position_embedding[:, :1]
  76. patch_pos_embed = self.position_embedding[:, 1:]
  77. dim = embeddings.shape[-1]
  78. new_height = height // self.patch_size
  79. new_width = width // self.patch_size
  80. sqrt_num_positions = torch_int(num_positions**0.5)
  81. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  82. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  83. patch_pos_embed = nn.functional.interpolate(
  84. patch_pos_embed,
  85. size=(new_height, new_width),
  86. mode="bicubic",
  87. align_corners=False,
  88. )
  89. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  90. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  91. def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
  92. batch_size, _, height, width = pixel_values.shape
  93. target_dtype = self.patch_embedding.weight.dtype
  94. patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
  95. patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
  96. class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
  97. embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
  98. if interpolate_pos_encoding:
  99. position_embedding = self.interpolate_pos_encoding(embeddings, height, width)
  100. else:
  101. position_embedding = self.position_embedding
  102. embeddings = embeddings + position_embedding[:, : embeddings.size(1), :].to(target_dtype)
  103. return embeddings
  104. @auto_docstring
  105. class InstructBlipVideoPreTrainedModel(PreTrainedModel):
  106. config: InstructBlipVideoConfig
  107. base_model_prefix = "blip"
  108. supports_gradient_checkpointing = True
  109. _supports_attention_backend = True
  110. _supports_flash_attn = True
  111. _supports_sdpa = True
  112. _supports_flex_attn = True
  113. _can_compile_fullgraph = True
  114. _no_split_modules = [
  115. "InstructBlipVideoQFormerEmbeddings",
  116. "InstructBlipVideoAttention",
  117. "InstructBlipVideoQFormerMultiHeadAttention",
  118. "InstructBlipVideoQFormerSelfOutput",
  119. ]
  120. def _init_weights(self, module):
  121. """Initialize the weights"""
  122. factor = self.config.initializer_range
  123. if isinstance(module, (nn.Linear, nn.Conv2d)):
  124. module.weight.data.normal_(mean=0.0, std=factor)
  125. if module.bias is not None:
  126. module.bias.data.zero_()
  127. elif isinstance(module, nn.Embedding):
  128. module.weight.data.normal_(mean=0.0, std=factor)
  129. elif isinstance(module, nn.LayerNorm):
  130. module.bias.data.zero_()
  131. module.weight.data.fill_(1.0)
  132. elif isinstance(module, InstructBlipVideoVisionEmbeddings):
  133. nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
  134. nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
  135. elif isinstance(module, (InstructBlipVideoForConditionalGeneration, InstructBlipVideoModel)):
  136. module.query_tokens.data.zero_()
  137. # Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward -> InstructBlipVideo doesn't cast attn weights to fp32
  138. def eager_attention_forward(
  139. module: nn.Module,
  140. query: torch.Tensor,
  141. key: torch.Tensor,
  142. value: torch.Tensor,
  143. attention_mask: Optional[torch.Tensor],
  144. scaling: float,
  145. dropout: float = 0.0,
  146. **kwargs,
  147. ):
  148. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  149. if attention_mask is not None:
  150. attn_weights = attn_weights + attention_mask
  151. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  152. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  153. attn_output = torch.matmul(attn_weights, value)
  154. attn_output = attn_output.transpose(1, 2).contiguous()
  155. return attn_output, attn_weights
  156. class InstructBlipVideoAttention(nn.Module):
  157. """Multi-headed attention from 'Attention Is All You Need' paper"""
  158. def __init__(self, config):
  159. super().__init__()
  160. self.config = config
  161. self.embed_dim = config.hidden_size
  162. self.num_heads = config.num_attention_heads
  163. self.head_dim = self.embed_dim // self.num_heads
  164. if self.head_dim * self.num_heads != self.embed_dim:
  165. raise ValueError(
  166. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  167. f" {self.num_heads})."
  168. )
  169. self.scale = self.head_dim**-0.5
  170. self.is_causal = False
  171. self.attention_dropout = config.attention_dropout
  172. # small tweak here compared to CLIP, no bias here
  173. self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False)
  174. if config.qkv_bias:
  175. q_bias = nn.Parameter(torch.zeros(self.embed_dim))
  176. v_bias = nn.Parameter(torch.zeros(self.embed_dim))
  177. else:
  178. q_bias = None
  179. v_bias = None
  180. if q_bias is not None:
  181. qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias))
  182. self.qkv.bias = nn.Parameter(qkv_bias)
  183. self.projection = nn.Linear(self.embed_dim, self.embed_dim)
  184. def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
  185. return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
  186. def forward(
  187. self,
  188. hidden_states: torch.Tensor,
  189. head_mask: Optional[torch.Tensor] = None,
  190. **kwargs,
  191. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  192. """Input shape: Batch x Time x Channel"""
  193. bsz, tgt_len, embed_dim = hidden_states.size()
  194. mixed_qkv = self.qkv(hidden_states)
  195. mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads).permute(
  196. 2, 0, 3, 1, 4
  197. )
  198. query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
  199. attention_interface: Callable = eager_attention_forward
  200. if self.config._attn_implementation != "eager":
  201. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  202. attn_output, attn_weights = attention_interface(
  203. self,
  204. query_states,
  205. key_states,
  206. value_states,
  207. attention_mask=None,
  208. dropout=0.0 if not self.training else self.attention_dropout,
  209. scaling=self.scale,
  210. **kwargs,
  211. )
  212. attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
  213. attn_output = self.projection(attn_output)
  214. return attn_output, attn_weights
  215. class InstructBlipVideoMLP(nn.Module):
  216. def __init__(self, config):
  217. super().__init__()
  218. self.config = config
  219. self.activation_fn = ACT2FN[config.hidden_act]
  220. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  221. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  222. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  223. hidden_states = self.fc1(hidden_states)
  224. hidden_states = self.activation_fn(hidden_states)
  225. hidden_states = self.fc2(hidden_states)
  226. return hidden_states
  227. class InstructBlipVideoEncoderLayer(GradientCheckpointingLayer):
  228. def __init__(self, config: InstructBlipVideoConfig):
  229. super().__init__()
  230. self.embed_dim = config.hidden_size
  231. self.self_attn = InstructBlipVideoAttention(config)
  232. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  233. self.mlp = InstructBlipVideoMLP(config)
  234. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  235. @auto_docstring
  236. def forward(
  237. self,
  238. hidden_states: torch.Tensor,
  239. attention_mask: torch.Tensor,
  240. **kwargs: Unpack[TransformersKwargs],
  241. ) -> torch.FloatTensor:
  242. residual = hidden_states
  243. hidden_states = self.layer_norm1(hidden_states)
  244. hidden_states, _ = self.self_attn(
  245. hidden_states=hidden_states,
  246. head_mask=attention_mask,
  247. **kwargs,
  248. )
  249. hidden_states = hidden_states + residual
  250. residual = hidden_states
  251. hidden_states = self.layer_norm2(hidden_states)
  252. hidden_states = self.mlp(hidden_states)
  253. hidden_states = hidden_states + residual
  254. return hidden_states
  255. class InstructBlipVideoEncoder(nn.Module):
  256. """
  257. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  258. [`InstructBlipVideoEncoderLayer`].
  259. Args:
  260. config (`InstructBlipVideoConfig`):
  261. The corresponding vision configuration for the `InstructBlipVideoEncoder`.
  262. """
  263. def __init__(self, config: InstructBlipVideoConfig):
  264. super().__init__()
  265. self.config = config
  266. self.layers = nn.ModuleList([InstructBlipVideoEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  267. self.gradient_checkpointing = False
  268. @auto_docstring
  269. def forward(
  270. self,
  271. inputs_embeds,
  272. attention_mask: Optional[torch.Tensor] = None,
  273. **kwargs: Unpack[TransformersKwargs],
  274. ) -> Union[tuple, BaseModelOutput]:
  275. hidden_states = inputs_embeds
  276. for encoder_layer in self.layers:
  277. hidden_states = encoder_layer(
  278. hidden_states,
  279. attention_mask=attention_mask,
  280. **kwargs,
  281. )
  282. return BaseModelOutput(last_hidden_state=hidden_states)
  283. class InstructBlipVideoVisionModel(InstructBlipVideoPreTrainedModel):
  284. main_input_name = "pixel_values"
  285. config: InstructBlipVideoVisionConfig
  286. _can_record_outputs = {
  287. "hidden_states": InstructBlipVideoEncoderLayer,
  288. "attentions": InstructBlipVideoAttention,
  289. }
  290. def __init__(self, config: InstructBlipVideoVisionConfig):
  291. super().__init__(config)
  292. self.config = config
  293. embed_dim = config.hidden_size
  294. self.embeddings = InstructBlipVideoVisionEmbeddings(config)
  295. self.encoder = InstructBlipVideoEncoder(config)
  296. self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  297. self.post_init()
  298. @check_model_inputs(tie_last_hidden_states=False)
  299. @auto_docstring
  300. def forward(
  301. self,
  302. pixel_values: Optional[torch.FloatTensor] = None,
  303. interpolate_pos_encoding: bool = False,
  304. **kwargs: Unpack[TransformersKwargs],
  305. ) -> Union[tuple, BaseModelOutputWithPooling]:
  306. if pixel_values is None:
  307. raise ValueError("You have to specify pixel_values")
  308. hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  309. encoder_outputs: BaseModelOutput = self.encoder(
  310. inputs_embeds=hidden_states,
  311. **kwargs,
  312. )
  313. last_hidden_state = encoder_outputs.last_hidden_state
  314. last_hidden_state = self.post_layernorm(last_hidden_state)
  315. pooled_output = last_hidden_state[:, 0, :]
  316. pooled_output = self.post_layernorm(pooled_output)
  317. return BaseModelOutputWithPooling(
  318. last_hidden_state=last_hidden_state,
  319. pooler_output=pooled_output,
  320. )
  321. def get_input_embeddings(self):
  322. return self.embeddings
  323. class InstructBlipVideoQFormerMultiHeadAttention(nn.Module):
  324. def __init__(self, config, is_cross_attention=False):
  325. super().__init__()
  326. self.config = config
  327. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  328. raise ValueError(
  329. "The hidden size (%d) is not a multiple of the number of attention heads (%d)"
  330. % (config.hidden_size, config.num_attention_heads)
  331. )
  332. self.num_attention_heads = config.num_attention_heads
  333. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  334. self.all_head_size = self.num_attention_heads * self.attention_head_size
  335. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  336. if is_cross_attention:
  337. self.key = nn.Linear(config.encoder_hidden_size, self.all_head_size)
  338. self.value = nn.Linear(config.encoder_hidden_size, self.all_head_size)
  339. else:
  340. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  341. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  342. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  343. self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
  344. if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
  345. self.max_position_embeddings = config.max_position_embeddings
  346. self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
  347. self.save_attention = False
  348. def save_attn_gradients(self, attn_gradients):
  349. self.attn_gradients = attn_gradients
  350. def get_attn_gradients(self):
  351. return self.attn_gradients
  352. def save_attention_map(self, attention_map):
  353. self.attention_map = attention_map
  354. def get_attention_map(self):
  355. return self.attention_map
  356. def transpose_for_scores(self, x):
  357. new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
  358. x = x.view(*new_x_shape)
  359. return x.permute(0, 2, 1, 3)
  360. def forward(
  361. self,
  362. hidden_states,
  363. attention_mask=None,
  364. head_mask=None,
  365. encoder_hidden_states=None,
  366. encoder_attention_mask=None,
  367. **kwargs: Unpack[TransformersKwargs],
  368. ):
  369. # If this is instantiated as a cross-attention module, the keys
  370. # and values come from an encoder; the attention mask needs to be
  371. # such that the encoder's padding tokens are not attended to.
  372. is_cross_attention = encoder_hidden_states is not None
  373. if is_cross_attention:
  374. key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
  375. value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
  376. attention_mask = encoder_attention_mask
  377. else:
  378. key_layer = self.transpose_for_scores(self.key(hidden_states))
  379. value_layer = self.transpose_for_scores(self.value(hidden_states))
  380. mixed_query_layer = self.query(hidden_states)
  381. query_layer = self.transpose_for_scores(mixed_query_layer)
  382. # Take the dot product between "query" and "key" to get the raw attention scores.
  383. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  384. if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
  385. seq_length = hidden_states.size()[1]
  386. position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
  387. position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
  388. distance = position_ids_l - position_ids_r
  389. positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
  390. positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
  391. if self.position_embedding_type == "relative_key":
  392. relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
  393. attention_scores = attention_scores + relative_position_scores
  394. elif self.position_embedding_type == "relative_key_query":
  395. relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
  396. relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
  397. attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
  398. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  399. attention_scores_dtype = attention_scores.dtype
  400. if attention_mask is not None:
  401. # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
  402. attention_scores = attention_scores + attention_mask
  403. # Normalize the attention scores to probabilities.
  404. attention_probs = nn.Softmax(dim=-1)(attention_scores).to(attention_scores_dtype)
  405. if is_cross_attention and self.save_attention:
  406. self.save_attention_map(attention_probs)
  407. attention_probs.register_hook(self.save_attn_gradients)
  408. # This is actually dropping out entire tokens to attend to, which might
  409. # seem a bit unusual, but is taken from the original Transformer paper.
  410. attention_probs_dropped = self.dropout(attention_probs)
  411. # Mask heads if we want to
  412. if head_mask is not None:
  413. attention_probs_dropped = attention_probs_dropped * head_mask
  414. context_layer = torch.matmul(attention_probs_dropped, value_layer)
  415. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  416. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  417. context_layer = context_layer.view(*new_context_layer_shape)
  418. return context_layer, attention_probs
  419. class InstructBlipVideoQFormerSelfOutput(nn.Module):
  420. def __init__(self, config):
  421. super().__init__()
  422. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  423. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  424. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  425. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  426. hidden_states = self.dense(hidden_states)
  427. hidden_states = self.dropout(hidden_states)
  428. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  429. return hidden_states
  430. class InstructBlipVideoQFormerAttention(nn.Module):
  431. def __init__(self, config, is_cross_attention=False):
  432. super().__init__()
  433. self.attention = InstructBlipVideoQFormerMultiHeadAttention(config, is_cross_attention)
  434. self.output = InstructBlipVideoQFormerSelfOutput(config)
  435. self.pruned_heads = set()
  436. def prune_heads(self, heads):
  437. if len(heads) == 0:
  438. return
  439. heads, index = find_pruneable_heads_and_indices(
  440. heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
  441. )
  442. # Prune linear layers
  443. self.attention.query = prune_linear_layer(self.attention.query, index)
  444. self.attention.key = prune_linear_layer(self.attention.key, index)
  445. self.attention.value = prune_linear_layer(self.attention.value, index)
  446. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  447. # Update hyper params and store pruned heads
  448. self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
  449. self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
  450. self.pruned_heads = self.pruned_heads.union(heads)
  451. def forward(
  452. self,
  453. hidden_states: torch.Tensor,
  454. attention_mask: Optional[torch.FloatTensor] = None,
  455. head_mask: Optional[torch.FloatTensor] = None,
  456. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  457. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  458. **kwargs: Unpack[TransformersKwargs],
  459. ) -> torch.Tensor:
  460. attn_output, _ = self.attention(
  461. hidden_states=hidden_states,
  462. attention_mask=attention_mask,
  463. head_mask=head_mask,
  464. encoder_hidden_states=encoder_hidden_states,
  465. encoder_attention_mask=encoder_attention_mask,
  466. **kwargs,
  467. )
  468. attention_output = self.output(attn_output, hidden_states)
  469. return attention_output
  470. class InstructBlipVideoQFormerIntermediate(nn.Module):
  471. def __init__(self, config):
  472. super().__init__()
  473. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  474. if isinstance(config.hidden_act, str):
  475. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  476. else:
  477. self.intermediate_act_fn = config.hidden_act
  478. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  479. hidden_states = self.dense(hidden_states)
  480. hidden_states = self.intermediate_act_fn(hidden_states)
  481. return hidden_states
  482. class InstructBlipVideoQFormerOutput(nn.Module):
  483. def __init__(self, config):
  484. super().__init__()
  485. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  486. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  487. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  488. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  489. hidden_states = self.dense(hidden_states)
  490. hidden_states = self.dropout(hidden_states)
  491. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  492. return hidden_states
  493. class InstructBlipVideoQFormerLayer(GradientCheckpointingLayer):
  494. def __init__(self, config, layer_idx):
  495. super().__init__()
  496. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  497. self.seq_len_dim = 1
  498. self.attention = InstructBlipVideoQFormerAttention(config)
  499. self.layer_idx = layer_idx
  500. if layer_idx % config.cross_attention_frequency == 0:
  501. self.crossattention = InstructBlipVideoQFormerAttention(config, is_cross_attention=True)
  502. self.has_cross_attention = True
  503. else:
  504. self.has_cross_attention = False
  505. self.intermediate = InstructBlipVideoQFormerIntermediate(config)
  506. self.output = InstructBlipVideoQFormerOutput(config)
  507. self.intermediate_query = InstructBlipVideoQFormerIntermediate(config)
  508. self.output_query = InstructBlipVideoQFormerOutput(config)
  509. def forward(
  510. self,
  511. hidden_states,
  512. attention_mask=None,
  513. head_mask=None,
  514. encoder_hidden_states=None,
  515. encoder_attention_mask=None,
  516. query_length=0,
  517. **kwargs: Unpack[TransformersKwargs],
  518. ):
  519. attention_output = self.attention(
  520. hidden_states,
  521. attention_mask=attention_mask,
  522. head_mask=head_mask,
  523. **kwargs,
  524. )
  525. if query_length > 0:
  526. query_attention_output = attention_output[:, :query_length, :]
  527. if self.has_cross_attention:
  528. if encoder_hidden_states is None:
  529. raise ValueError("encoder_hidden_states must be given for cross-attention layers")
  530. query_attention_output = self.crossattention(
  531. query_attention_output,
  532. attention_mask=attention_mask,
  533. head_mask=head_mask,
  534. encoder_hidden_states=encoder_hidden_states,
  535. encoder_attention_mask=encoder_attention_mask,
  536. **kwargs,
  537. )
  538. layer_output = apply_chunking_to_forward(
  539. self.feed_forward_chunk_query,
  540. self.chunk_size_feed_forward,
  541. self.seq_len_dim,
  542. query_attention_output,
  543. )
  544. if attention_output.shape[1] > query_length:
  545. layer_output_text = apply_chunking_to_forward(
  546. self.feed_forward_chunk,
  547. self.chunk_size_feed_forward,
  548. self.seq_len_dim,
  549. attention_output[:, query_length:, :],
  550. ).to(layer_output.device)
  551. layer_output = torch.cat([layer_output, layer_output_text], dim=1)
  552. else:
  553. layer_output = apply_chunking_to_forward(
  554. self.feed_forward_chunk,
  555. self.chunk_size_feed_forward,
  556. self.seq_len_dim,
  557. attention_output,
  558. )
  559. return layer_output
  560. def feed_forward_chunk(self, attention_output):
  561. intermediate_output = self.intermediate(attention_output)
  562. layer_output = self.output(intermediate_output, attention_output)
  563. return layer_output
  564. def feed_forward_chunk_query(self, attention_output):
  565. intermediate_output = self.intermediate_query(attention_output)
  566. layer_output = self.output_query(intermediate_output, attention_output)
  567. return layer_output
  568. class InstructBlipVideoQFormerEncoder(nn.Module):
  569. def __init__(self, config):
  570. super().__init__()
  571. self.config = config
  572. self.layer = nn.ModuleList(
  573. [InstructBlipVideoQFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  574. )
  575. self.gradient_checkpointing = False
  576. @can_return_tuple
  577. def forward(
  578. self,
  579. hidden_states,
  580. attention_mask=None,
  581. head_mask=None,
  582. encoder_hidden_states=None,
  583. encoder_attention_mask=None,
  584. query_length=0,
  585. **kwargs: Unpack[TransformersKwargs],
  586. ):
  587. for i in range(self.config.num_hidden_layers):
  588. layer_module = self.layer[i]
  589. layer_head_mask = head_mask[i] if head_mask is not None else None
  590. hidden_states = layer_module(
  591. hidden_states,
  592. attention_mask,
  593. layer_head_mask,
  594. encoder_hidden_states, # as a positional argument for gradient checkpointing
  595. encoder_attention_mask=encoder_attention_mask,
  596. query_length=query_length,
  597. **kwargs,
  598. )
  599. return BaseModelOutputWithPastAndCrossAttentions(
  600. last_hidden_state=hidden_states,
  601. )
  602. class InstructBlipVideoQFormerEmbeddings(nn.Module):
  603. """Construct the embeddings from word and position embeddings."""
  604. def __init__(self, config):
  605. super().__init__()
  606. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  607. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  608. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  609. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  610. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  611. self.register_buffer(
  612. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  613. )
  614. self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
  615. self.config = config
  616. def forward(
  617. self,
  618. input_ids=None,
  619. position_ids=None,
  620. query_embeds=None,
  621. past_key_values_length=0,
  622. ):
  623. if input_ids is not None:
  624. seq_length = input_ids.size()[1]
  625. else:
  626. seq_length = 0
  627. if position_ids is None:
  628. position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone()
  629. if input_ids is not None:
  630. embeddings = self.word_embeddings(input_ids)
  631. if self.position_embedding_type == "absolute":
  632. position_embeddings = self.position_embeddings(position_ids.to(embeddings.device))
  633. embeddings = embeddings + position_embeddings
  634. if query_embeds is not None:
  635. embeddings = torch.cat((query_embeds, embeddings), dim=1)
  636. else:
  637. embeddings = query_embeds
  638. embeddings = embeddings.to(self.layernorm.weight.dtype)
  639. embeddings = self.layernorm(embeddings)
  640. embeddings = self.dropout(embeddings)
  641. return embeddings
  642. class InstructBlipVideoQFormerModel(InstructBlipVideoPreTrainedModel):
  643. """
  644. Querying Transformer (Q-Former), used in InstructBlipVideo. Slightly modified from BLIP-2 as it also takes the
  645. instruction as input.
  646. """
  647. _supports_attention_backend = False # adds position on attn weights before last matmul
  648. _supports_flash_attn = False
  649. _supports_sdpa = False
  650. _supports_flex_attn = False
  651. _can_record_outputs = {
  652. "hidden_states": InstructBlipVideoQFormerLayer,
  653. "attentions": [
  654. OutputRecorder(InstructBlipVideoQFormerMultiHeadAttention, index=1, layer_name=".attention"),
  655. ],
  656. "cross_attentions": [
  657. OutputRecorder(InstructBlipVideoQFormerMultiHeadAttention, index=1, layer_name=".crossattention"),
  658. ],
  659. }
  660. def __init__(self, config: InstructBlipVideoQFormerConfig):
  661. super().__init__(config)
  662. self.config = config
  663. self.embeddings = InstructBlipVideoQFormerEmbeddings(config)
  664. self.encoder = InstructBlipVideoQFormerEncoder(config)
  665. self.post_init()
  666. def get_input_embeddings(self):
  667. return self.embeddings.word_embeddings
  668. def set_input_embeddings(self, value):
  669. self.embeddings.word_embeddings = value
  670. def _prune_heads(self, heads_to_prune):
  671. """
  672. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  673. class PreTrainedModel
  674. """
  675. for layer, heads in heads_to_prune.items():
  676. self.encoder.layer[layer].attention.prune_heads(heads)
  677. def get_extended_attention_mask(
  678. self,
  679. attention_mask: torch.Tensor,
  680. input_shape: tuple[int],
  681. device: torch.device,
  682. has_query: bool = False,
  683. ) -> torch.Tensor:
  684. """
  685. Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
  686. Arguments:
  687. attention_mask (`torch.Tensor`):
  688. Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
  689. input_shape (`tuple[int]`):
  690. The shape of the input to the model.
  691. device: (`torch.device`):
  692. The device of the input to the model.
  693. Returns:
  694. `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
  695. """
  696. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  697. # ourselves in which case we just need to make it broadcastable to all heads.
  698. if attention_mask.dim() == 3:
  699. extended_attention_mask = attention_mask[:, None, :, :]
  700. elif attention_mask.dim() == 2:
  701. # Provided a padding mask of dimensions [batch_size, seq_length]
  702. # - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
  703. extended_attention_mask = attention_mask[:, None, None, :]
  704. else:
  705. raise ValueError(
  706. f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})",
  707. )
  708. # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
  709. # masked positions, this operation will create a tensor which is 0.0 for
  710. # positions we want to attend and -10000.0 for masked positions.
  711. # Since we are adding it to the raw scores before the softmax, this is
  712. # effectively the same as removing these entirely.
  713. extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
  714. extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
  715. return extended_attention_mask
  716. @check_model_inputs()
  717. @auto_docstring
  718. def forward(
  719. self,
  720. input_ids: torch.LongTensor,
  721. attention_mask: Optional[torch.FloatTensor] = None,
  722. position_ids: Optional[torch.LongTensor] = None,
  723. query_embeds: Optional[torch.Tensor] = None,
  724. head_mask: Optional[torch.FloatTensor] = None,
  725. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  726. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  727. **kwargs: Unpack[TransformersKwargs],
  728. ) -> Union[tuple[torch.FloatTensor], BaseModelOutputWithPoolingAndCrossAttentions]:
  729. r"""
  730. query_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  731. Hidden states to be used in the attention computation. If cross-attention,
  732. will be used for the query (i.e., key and value will use the encoder_hidden_states).
  733. """
  734. if input_ids is None and query_embeds is None:
  735. raise ValueError("You have to specify query_embeds when input_ids is None")
  736. query_length = query_embeds.shape[1] if query_embeds is not None else 0
  737. embedding_output = self.embeddings(
  738. input_ids=input_ids,
  739. position_ids=position_ids,
  740. query_embeds=query_embeds,
  741. )
  742. input_shape = embedding_output.size()[:-1]
  743. batch_size, seq_length = input_shape
  744. device = embedding_output.device
  745. if attention_mask is None:
  746. attention_mask = torch.ones(((batch_size, seq_length)), device=device)
  747. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  748. # ourselves in which case we just need to make it broadcastable to all heads.
  749. extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
  750. # If a 2D or 3D attention mask is provided for the cross-attention
  751. # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
  752. if encoder_hidden_states is not None:
  753. if isinstance(encoder_hidden_states, list):
  754. encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
  755. else:
  756. encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
  757. encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
  758. if isinstance(encoder_attention_mask, list):
  759. encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
  760. elif encoder_attention_mask is None:
  761. encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
  762. encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
  763. else:
  764. encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
  765. else:
  766. encoder_extended_attention_mask = None
  767. # Prepare head mask if needed
  768. # 1.0 in head_mask indicate we keep the head
  769. # attention_probs has shape bsz x n_heads x N x N
  770. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  771. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  772. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  773. encoder_outputs: BaseModelOutput = self.encoder(
  774. embedding_output,
  775. attention_mask=extended_attention_mask,
  776. head_mask=head_mask,
  777. encoder_hidden_states=encoder_hidden_states,
  778. encoder_attention_mask=encoder_extended_attention_mask,
  779. query_length=query_length,
  780. **kwargs,
  781. )
  782. sequence_output = encoder_outputs.last_hidden_state
  783. pooled_output = sequence_output[:, 0, :]
  784. return BaseModelOutputWithPoolingAndCrossAttentions(
  785. last_hidden_state=sequence_output,
  786. pooler_output=pooled_output,
  787. )
  788. @dataclass
  789. @auto_docstring(
  790. custom_intro="""
  791. Class defining the outputs of [`InstructBlipVideoForConditionalGeneration`].
  792. """
  793. )
  794. class InstructBlipVideoForConditionalGenerationModelOutput(ModelOutput):
  795. r"""
  796. loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
  797. Language modeling loss from the language model.
  798. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  799. Prediction scores of the language modeling head of the language model.
  800. vision_outputs (`BaseModelOutputWithPooling`):
  801. Outputs of the vision encoder.
  802. qformer_outputs (`BaseModelOutputWithPoolingAndCrossAttentions`):
  803. Outputs of the Q-Former (Querying Transformer).
  804. language_model_outputs (`CausalLMOutputWithPast` or `Seq2SeqLMOutput`):
  805. Outputs of the language model.
  806. """
  807. loss: Optional[tuple[torch.FloatTensor]] = None
  808. logits: Optional[tuple[torch.FloatTensor]] = None
  809. vision_outputs: Optional[torch.FloatTensor] = None
  810. qformer_outputs: Optional[tuple[torch.FloatTensor]] = None
  811. language_model_outputs: Optional[tuple[torch.FloatTensor]] = None
  812. def to_tuple(self) -> tuple[Any]:
  813. return tuple(
  814. self[k]
  815. if k not in ["vision_outputs", "qformer_outputs", "language_model_outputs"]
  816. else getattr(self, k).to_tuple()
  817. for k in self.keys()
  818. )
  819. @auto_docstring(
  820. custom_intro="""
  821. InstructBlipVideo base Model consisting of language model, qformer and vision encoder.
  822. """
  823. )
  824. class InstructBlipVideoModel(InstructBlipVideoPreTrainedModel):
  825. main_input_name = "pixel_values"
  826. _keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8
  827. def __init__(self, config: InstructBlipVideoConfig):
  828. super().__init__(config)
  829. self.vision_model = InstructBlipVideoVisionModel(config.vision_config)
  830. self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
  831. self.qformer = InstructBlipVideoQFormerModel(config.qformer_config)
  832. self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
  833. self.language_model = AutoModel.from_config(config.text_config)
  834. if self.language_model._no_split_modules is not None:
  835. self._no_split_modules.extend(self.language_model._no_split_modules)
  836. if self.language_model._keep_in_fp32_modules is not None:
  837. self._keep_in_fp32_modules.extend(self.language_model._keep_in_fp32_modules)
  838. # Initialize weights and apply final processing
  839. self.post_init()
  840. def get_input_embeddings(self):
  841. return self.language_model.get_input_embeddings()
  842. def set_input_embeddings(self, value):
  843. self.language_model.set_input_embeddings(value)
  844. def _tie_weights(self):
  845. if not self.config.use_decoder_only_language_model:
  846. self.language_model.encoder.embed_tokens = self.language_model.shared
  847. self.language_model.decoder.embed_tokens = self.language_model.shared
  848. def _preprocess_accelerate(self):
  849. r"""
  850. Some pre-processing hacks to make the model `accelerate` compatible. Check
  851. https://github.com/huggingface/transformers/pull/21707 for more details.
  852. """
  853. hf_device_map = self.hf_device_map
  854. if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
  855. # warn users about unexpected behavior when using multi-GPU + InstructBlipVideo + `accelerate`.
  856. logger.warning(
  857. "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
  858. " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
  859. " Please pass a `device_map` that contains `language_model` to remove this warning."
  860. " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
  861. " more details on creating a `device_map` for large models.",
  862. )
  863. if hasattr(self.language_model, "_hf_hook"):
  864. self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
  865. def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor):
  866. """
  867. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`.
  868. """
  869. if input_ids is None:
  870. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  871. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  872. )
  873. special_image_mask = special_image_mask.all(-1)
  874. else:
  875. special_image_mask = input_ids == self.config.image_token_id
  876. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  877. return special_image_mask
  878. @can_return_tuple
  879. @auto_docstring
  880. def forward(
  881. self,
  882. pixel_values: torch.FloatTensor,
  883. qformer_input_ids: torch.FloatTensor,
  884. qformer_attention_mask: Optional[torch.LongTensor] = None,
  885. input_ids: Optional[torch.FloatTensor] = None,
  886. attention_mask: Optional[torch.LongTensor] = None,
  887. decoder_input_ids: Optional[torch.LongTensor] = None,
  888. decoder_attention_mask: Optional[torch.LongTensor] = None,
  889. inputs_embeds: Optional[torch.Tensor] = None,
  890. output_attentions: Optional[bool] = None,
  891. output_hidden_states: Optional[bool] = None,
  892. return_dict: Optional[bool] = None,
  893. interpolate_pos_encoding: bool = False,
  894. use_cache: Optional[bool] = None,
  895. **kwargs: Unpack[FlashAttentionKwargs],
  896. ) -> Union[tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
  897. r"""
  898. qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  899. Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
  900. to serve as text prompt, which the Q-Former model will encode.
  901. Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for
  902. details.
  903. [What are input IDs?](../glossary#input-ids)
  904. qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  905. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  906. - 1 for tokens that are **not masked**,
  907. - 0 for tokens that are **masked**.
  908. [What are attention masks?](../glossary#attention-mask)
  909. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  910. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  911. be used by default.
  912. Only relevant in case an encoder-decoder language model (like T5) is used.
  913. """
  914. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  915. # step 1: forward the images through the vision encoder,
  916. # we process in a batched way, later unbatch it back (video has frames=4 always)
  917. batch_size, frames, channel, height, width = pixel_values.shape
  918. pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
  919. vision_outputs = self.vision_model(
  920. pixel_values=pixel_values,
  921. output_attentions=output_attentions,
  922. output_hidden_states=output_hidden_states,
  923. return_dict=return_dict,
  924. interpolate_pos_encoding=interpolate_pos_encoding,
  925. )
  926. image_embeds = vision_outputs[0]
  927. # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
  928. image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
  929. # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
  930. query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
  931. query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
  932. if qformer_attention_mask is None:
  933. qformer_attention_mask = torch.ones_like(qformer_input_ids)
  934. qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
  935. qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
  936. qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
  937. query_outputs = self.qformer(
  938. input_ids=qformer_input_ids,
  939. attention_mask=qformer_attention_mask,
  940. query_embeds=query_tokens,
  941. encoder_hidden_states=image_embeds,
  942. encoder_attention_mask=image_attention_mask,
  943. output_attentions=output_attentions,
  944. output_hidden_states=output_hidden_states,
  945. return_dict=return_dict,
  946. )
  947. query_output = query_outputs[0][:, : query_tokens.size(1), :]
  948. # step 3: use the language model, conditioned on the query outputs and the prompt
  949. language_model_inputs = self.language_projection(query_output)
  950. # unbatch inputs back, each video-frame gets `num_query_tokens` seq length
  951. language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
  952. if inputs_embeds is None:
  953. inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
  954. special_image_mask = input_ids == self.config.video_token_id
  955. if attention_mask is None:
  956. attention_mask = torch.ones_like(input_ids)
  957. else:
  958. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  959. torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
  960. )
  961. special_image_mask = special_image_mask.all(-1)
  962. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  963. language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
  964. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
  965. if self.config.use_decoder_only_language_model:
  966. outputs = self.language_model(
  967. inputs_embeds=inputs_embeds,
  968. attention_mask=attention_mask,
  969. output_attentions=output_attentions,
  970. output_hidden_states=output_hidden_states,
  971. return_dict=return_dict,
  972. use_cache=use_cache,
  973. **kwargs,
  974. )
  975. else:
  976. outputs = self.language_model(
  977. inputs_embeds=inputs_embeds,
  978. attention_mask=attention_mask,
  979. decoder_input_ids=decoder_input_ids,
  980. decoder_attention_mask=decoder_attention_mask,
  981. output_attentions=output_attentions,
  982. output_hidden_states=output_hidden_states,
  983. return_dict=return_dict,
  984. use_cache=use_cache,
  985. **kwargs,
  986. )
  987. return InstructBlipVideoForConditionalGenerationModelOutput(
  988. vision_outputs=vision_outputs,
  989. qformer_outputs=query_outputs,
  990. language_model_outputs=outputs,
  991. )
  992. @auto_docstring(
  993. custom_intro="""
  994. InstructBlipVideo Model for generating text given an image and an optional text prompt. The model consists of a vision
  995. encoder, Querying Transformer (Q-Former) and a language model.
  996. One can optionally pass `input_ids` to the model, which serve as a text prompt, to make the language model continue
  997. the prompt. Otherwise, the language model starts generating text from the [BOS] (beginning-of-sequence) token.
  998. """
  999. )
  1000. class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel, GenerationMixin):
  1001. config: InstructBlipVideoConfig
  1002. main_input_name = "pixel_values"
  1003. _can_compile_fullgraph = True
  1004. _keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8
  1005. def __init__(self, config: InstructBlipVideoConfig):
  1006. super().__init__(config)
  1007. self.vision_model = InstructBlipVideoVisionModel._from_config(config.vision_config)
  1008. self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
  1009. self.qformer = InstructBlipVideoQFormerModel._from_config(config.qformer_config)
  1010. self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
  1011. if config.use_decoder_only_language_model:
  1012. language_model = AutoModelForCausalLM.from_config(config.text_config)
  1013. else:
  1014. language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
  1015. if language_model._no_split_modules is not None:
  1016. self._no_split_modules.extend(language_model._no_split_modules)
  1017. if language_model._keep_in_fp32_modules is not None:
  1018. self._keep_in_fp32_modules.extend(language_model._keep_in_fp32_modules)
  1019. self.language_model = language_model
  1020. # Initialize weights and apply final processing
  1021. self.post_init()
  1022. def get_input_embeddings(self):
  1023. return self.language_model.get_input_embeddings()
  1024. def set_input_embeddings(self, value):
  1025. self.language_model.set_input_embeddings(value)
  1026. def set_output_embeddings(self, new_embeddings):
  1027. self.language_model.set_output_embeddings(new_embeddings)
  1028. def get_output_embeddings(self) -> nn.Module:
  1029. return self.language_model.get_output_embeddings()
  1030. def get_encoder(self):
  1031. return self.language_model.get_encoder()
  1032. def get_decoder(self):
  1033. return self.language_model.get_decoder()
  1034. def _tie_weights(self):
  1035. if not self.config.use_decoder_only_language_model:
  1036. self.language_model.encoder.embed_tokens = self.language_model.shared
  1037. self.language_model.decoder.embed_tokens = self.language_model.shared
  1038. def _preprocess_accelerate(self):
  1039. r"""
  1040. Some pre-processing hacks to make the model `accelerate` compatible. Check
  1041. https://github.com/huggingface/transformers/pull/21707 for more details.
  1042. """
  1043. hf_device_map = self.hf_device_map
  1044. if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
  1045. # warn users about unexpected behavior when using multi-GPU + InstructBlipVideo + `accelerate`.
  1046. logger.warning(
  1047. "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
  1048. " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
  1049. " Please pass a `device_map` that contains `language_model` to remove this warning."
  1050. " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
  1051. " more details on creating a `device_map` for large models.",
  1052. )
  1053. if hasattr(self.language_model, "_hf_hook"):
  1054. self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
  1055. def get_image_features(
  1056. self,
  1057. pixel_values: torch.FloatTensor,
  1058. qformer_input_ids: torch.LongTensor,
  1059. qformer_attention_mask: Optional[torch.LongTensor] = None,
  1060. interpolate_pos_encoding: Optional[bool] = False,
  1061. return_dict: Optional[bool] = False,
  1062. ):
  1063. """
  1064. Encodes images into continuous embeddings that can be forwarded to the language model.
  1065. Args:
  1066. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  1067. The tensors corresponding to the input images.
  1068. """
  1069. pass
  1070. def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor):
  1071. """
  1072. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`.
  1073. """
  1074. if input_ids is None:
  1075. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  1076. torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
  1077. )
  1078. special_image_mask = special_image_mask.all(-1)
  1079. else:
  1080. special_image_mask = input_ids == self.config.video_token_id
  1081. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  1082. return special_image_mask
  1083. @can_return_tuple
  1084. @auto_docstring
  1085. def forward(
  1086. self,
  1087. pixel_values: torch.FloatTensor,
  1088. qformer_input_ids: torch.FloatTensor,
  1089. qformer_attention_mask: Optional[torch.LongTensor] = None,
  1090. input_ids: Optional[torch.FloatTensor] = None,
  1091. attention_mask: Optional[torch.LongTensor] = None,
  1092. decoder_input_ids: Optional[torch.LongTensor] = None,
  1093. decoder_attention_mask: Optional[torch.LongTensor] = None,
  1094. inputs_embeds: Optional[torch.FloatTensor] = None,
  1095. output_attentions: Optional[bool] = None,
  1096. output_hidden_states: Optional[bool] = None,
  1097. labels: Optional[torch.LongTensor] = None,
  1098. return_dict: Optional[bool] = None,
  1099. interpolate_pos_encoding: bool = False,
  1100. use_cache: Optional[bool] = None,
  1101. **kwargs: Unpack[TransformersKwargs],
  1102. ) -> Union[tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
  1103. r"""
  1104. qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length)):
  1105. The sequence used as a prompt to be fed to the Q-Former module.
  1106. qformer_attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
  1107. Mask to avoid performing attention on padding token indices.
  1108. Examples:
  1109. ```python
  1110. >>> from transformers import InstructBlipVideoProcessor, InstructBlipVideoForConditionalGeneration
  1111. >>> import torch
  1112. >>> from huggingface_hub import hf_hub_download
  1113. >>> import av
  1114. >>> import numpy as np
  1115. >>> def read_video_pyav(container, indices):
  1116. ... '''
  1117. ... Decode the video with PyAV decoder.
  1118. ... Args:
  1119. ... container (`av.container.input.InputContainer`): PyAV container.
  1120. ... indices (`list[int]`): List of frame indices to decode.
  1121. ... Returns:
  1122. ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
  1123. ... '''
  1124. ... frames = []
  1125. ... container.seek(0)
  1126. ... start_index = indices[0]
  1127. ... end_index = indices[-1]
  1128. ... for i, frame in enumerate(container.decode(video=0)):
  1129. ... if i > end_index:
  1130. ... break
  1131. ... if i >= start_index and i in indices:
  1132. ... frames.append(frame)
  1133. ... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
  1134. >>> model = InstructBlipVideoForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b", device_map="auto")
  1135. >>> processor = InstructBlipVideoProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
  1136. >>> file_path = hf_hub_download(
  1137. ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
  1138. ... )
  1139. >>> container = av.open(file_path)
  1140. >>> # sample uniformly 4 frames from the video
  1141. >>> total_frames = container.streams.video[0].frames
  1142. >>> indices = np.arange(0, total_frames, total_frames / 4).astype(int)
  1143. >>> clip = read_video_pyav(container, indices)
  1144. >>> prompt = "What is happening in the video?"
  1145. >>> inputs = processor(text=prompt, images=clip, return_tensors="pt").to(model.device)
  1146. >>> outputs = model.generate(
  1147. ... **inputs,
  1148. ... do_sample=False,
  1149. ... num_beams=5,
  1150. ... max_length=256,
  1151. ... repetition_penalty=1.5,
  1152. ... length_penalty=1.0,
  1153. ... )
  1154. >>> generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
  1155. >>> print(generated_text)
  1156. "A person is eating a bowl of pasta, and they are using a fork to eat it. The person is sitting at a table, and the plate of pasta is on the table in front"
  1157. ```"""
  1158. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1159. language_model_inputs, vision_outputs, query_outputs = self.get_video_features(
  1160. pixel_values,
  1161. qformer_input_ids=qformer_input_ids,
  1162. qformer_attention_mask=qformer_attention_mask,
  1163. interpolate_pos_encoding=interpolate_pos_encoding,
  1164. return_dict=True,
  1165. )
  1166. vision_outputs = vision_outputs.to_tuple() if not return_dict else vision_outputs
  1167. query_outputs = query_outputs.to_tuple() if not return_dict else query_outputs
  1168. if inputs_embeds is None:
  1169. inputs_embeds = self.get_input_embeddings()(input_ids)
  1170. if attention_mask is None:
  1171. attention_mask = torch.ones_like(input_ids)
  1172. language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
  1173. special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
  1174. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
  1175. if self.config.use_decoder_only_language_model:
  1176. outputs = self.language_model(
  1177. inputs_embeds=inputs_embeds,
  1178. attention_mask=attention_mask,
  1179. output_attentions=output_attentions,
  1180. output_hidden_states=output_hidden_states,
  1181. return_dict=return_dict,
  1182. use_cache=use_cache,
  1183. **kwargs,
  1184. )
  1185. logits = outputs.logits if return_dict else outputs[0]
  1186. loss = None
  1187. if labels is not None:
  1188. loss = self.loss_function(
  1189. logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
  1190. )
  1191. else:
  1192. outputs = self.language_model(
  1193. inputs_embeds=inputs_embeds,
  1194. attention_mask=attention_mask,
  1195. decoder_input_ids=decoder_input_ids,
  1196. decoder_attention_mask=decoder_attention_mask,
  1197. output_attentions=output_attentions,
  1198. output_hidden_states=output_hidden_states,
  1199. return_dict=return_dict,
  1200. labels=labels,
  1201. use_cache=use_cache,
  1202. **kwargs,
  1203. )
  1204. loss = outputs.loss if return_dict else outputs[0]
  1205. logits = outputs.logits if return_dict else outputs[1]
  1206. return InstructBlipVideoForConditionalGenerationModelOutput(
  1207. loss=loss,
  1208. logits=logits,
  1209. vision_outputs=vision_outputs,
  1210. qformer_outputs=query_outputs,
  1211. language_model_outputs=outputs,
  1212. )
  1213. @torch.no_grad()
  1214. def generate(
  1215. self,
  1216. pixel_values: torch.FloatTensor,
  1217. qformer_input_ids: Optional[torch.LongTensor] = None,
  1218. qformer_attention_mask: Optional[torch.LongTensor] = None,
  1219. input_ids: Optional[torch.LongTensor] = None,
  1220. attention_mask: Optional[torch.LongTensor] = None,
  1221. inputs_embeds: Optional[torch.FloatTensor] = None,
  1222. interpolate_pos_encoding: bool = False,
  1223. **generate_kwargs,
  1224. ) -> torch.LongTensor:
  1225. r"""
  1226. Overrides `generate` function to be able to use the model as a conditional generator.
  1227. Args:
  1228. pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width) or
  1229. (batch_size, num_frames, num_channels, height, width)): Input images or videos to be processed.
  1230. qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
  1231. The sequence used as a prompt to be fed to the Q-Former module.
  1232. qformer_attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
  1233. Mask to avoid performing attention on padding token indices.
  1234. input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
  1235. The sequence used as a prompt for the generation.
  1236. attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
  1237. Mask to avoid performing attention on padding token indices.
  1238. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  1239. Embedded representation of the inputs. Should be float, not int tokens.
  1240. interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
  1241. Whether to interpolate the positional encoding of the image embeddings.
  1242. Returns:
  1243. captions (list): A list of strings of length batch_size * num_captions.
  1244. """
  1245. if hasattr(self, "hf_device_map"):
  1246. # preprocess for `accelerate`
  1247. self._preprocess_accelerate()
  1248. batch_size = pixel_values.shape[0]
  1249. language_model_inputs, vision_outputs, query_outputs = self.get_video_features(
  1250. pixel_values,
  1251. qformer_input_ids=qformer_input_ids,
  1252. qformer_attention_mask=qformer_attention_mask,
  1253. interpolate_pos_encoding=interpolate_pos_encoding,
  1254. return_dict=True,
  1255. )
  1256. if inputs_embeds is None:
  1257. if input_ids is None:
  1258. video_tokens = [self.config.video_token_index] * self.config.num_query_tokens * 4
  1259. start_tokens = video_tokens + [self.config.text_config.bos_token_id]
  1260. input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
  1261. input_ids = input_ids.repeat(batch_size, 1)
  1262. inputs_embeds = self.get_input_embeddings()(input_ids)
  1263. if attention_mask is None:
  1264. attention_mask = torch.ones_like(input_ids)
  1265. language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
  1266. special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
  1267. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
  1268. inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
  1269. if not self.language_model.config.is_encoder_decoder:
  1270. inputs["input_ids"] = input_ids
  1271. outputs = self.language_model.generate(**inputs, **generate_kwargs)
  1272. return outputs
  1273. def get_video_features(
  1274. self,
  1275. pixel_values: torch.FloatTensor,
  1276. qformer_input_ids: torch.LongTensor,
  1277. qformer_attention_mask: Optional[torch.LongTensor] = None,
  1278. interpolate_pos_encoding: Optional[bool] = False,
  1279. return_dict: Optional[bool] = False,
  1280. ):
  1281. """
  1282. Encodes images into continuous embeddings that can be forwarded to the language model.
  1283. Args:
  1284. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  1285. The tensors corresponding to the input images.
  1286. """
  1287. # step 1: forward the images through the vision encoder,
  1288. # we process in a batched way, later unbatch it back (video has frames=4 always)
  1289. batch_size, frames, channel, height, width = pixel_values.shape
  1290. pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
  1291. vision_outputs = self.vision_model(
  1292. pixel_values=pixel_values,
  1293. interpolate_pos_encoding=interpolate_pos_encoding,
  1294. return_dict=True,
  1295. )
  1296. image_embeds = vision_outputs[0]
  1297. # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
  1298. image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
  1299. # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
  1300. query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
  1301. query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
  1302. if qformer_attention_mask is None:
  1303. qformer_attention_mask = torch.ones_like(qformer_input_ids)
  1304. qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
  1305. qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
  1306. qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
  1307. query_outputs = self.qformer(
  1308. input_ids=qformer_input_ids,
  1309. attention_mask=qformer_attention_mask,
  1310. query_embeds=query_tokens,
  1311. encoder_hidden_states=image_embeds,
  1312. encoder_attention_mask=image_attention_mask,
  1313. return_dict=True,
  1314. )
  1315. query_output = query_outputs[0][:, : query_tokens.size(1), :]
  1316. # step 3: use the language model, conditioned on the query outputs and the prompt
  1317. language_model_inputs = self.language_projection(query_output)
  1318. # unbatch inputs back, each video-frame gets `num_query_tokens` seq length
  1319. language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
  1320. if return_dict:
  1321. return language_model_inputs, vision_outputs, query_outputs
  1322. return language_model_inputs
  1323. __all__ = [
  1324. "InstructBlipVideoVisionModel",
  1325. "InstructBlipVideoPreTrainedModel",
  1326. "InstructBlipVideoQFormerModel",
  1327. "InstructBlipVideoModel",
  1328. "InstructBlipVideoForConditionalGeneration",
  1329. ]