| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604 |
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # This file was automatically generated from src/transformers/models/instructblipvideo/modular_instructblipvideo.py.
- # Do NOT edit this file manually as any edits will be overwritten by the generation of
- # the file from the modular. If any change should be done, please apply the change to the
- # modular_instructblipvideo.py file directly. One of our CI enforces this.
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # coding=utf-8
- # Copyright 2024 HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import math
- from dataclasses import dataclass
- from typing import Any, Callable, Optional, Union
- import torch
- from torch import nn
- from ...activations import ACT2FN
- from ...generation import GenerationMixin
- from ...modeling_flash_attention_utils import FlashAttentionKwargs
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import (
- BaseModelOutput,
- BaseModelOutputWithPastAndCrossAttentions,
- BaseModelOutputWithPooling,
- BaseModelOutputWithPoolingAndCrossAttentions,
- )
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
- from ...processing_utils import Unpack
- from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
- from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int
- from ...utils.generic import OutputRecorder, check_model_inputs
- from ..auto import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM
- from .configuration_instructblipvideo import (
- InstructBlipVideoConfig,
- InstructBlipVideoQFormerConfig,
- InstructBlipVideoVisionConfig,
- )
# Module-level logger obtained from the transformers logging utilities, named after this module.
logger = logging.get_logger(__name__)
class InstructBlipVideoVisionEmbeddings(nn.Module):
    """
    Turn 3-channel pixel values into a token sequence: one learned class token followed by one token per
    non-overlapping `patch_size` x `patch_size` patch (extracted with a strided convolution), plus learned
    position embeddings.
    """

    def __init__(self, config: InstructBlipVideoVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Parameter creation order is kept stable (class token, patch conv, position table) so that seeded
        # random initialization and state-dict ordering are unchanged.
        self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
        self.patch_embedding = nn.Conv2d(
            in_channels=3,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side**2
        # one extra position for the class token
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        Resample the pre-trained patch position embeddings to the patch grid implied by `height` x `width`,
        so the model can be used on resolutions other than the one it was trained on. The code path is also
        written to work under torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211

        Args:
            embeddings: token embeddings whose dim 1 is 1 (class token) + number of patches and whose last dim
                is the hidden size.
            height: input image height in pixels.
            width: input image width in pixels.

        Returns:
            The stored position table unchanged when no resampling is needed, otherwise a tensor of shape
            `(1, 1 + (height // patch_size) * (width // patch_size), hidden_size)`.
        """
        current_patch_count = embeddings.shape[1] - 1
        pretrained_patch_count = self.position_embedding.shape[1] - 1

        # While tracing we always resample so the exported graph supports dynamic input shapes.
        if not torch.jit.is_tracing() and current_patch_count == pretrained_patch_count and height == width:
            return self.position_embedding

        cls_position = self.position_embedding[:, :1]
        grid_positions = self.position_embedding[:, 1:]
        hidden_dim = embeddings.shape[-1]

        target_rows = height // self.patch_size
        target_cols = width // self.patch_size
        # the pre-trained patch grid is assumed to be square
        source_side = torch_int(pretrained_patch_count**0.5)

        # (1, N, D) -> (1, D, side, side) so interpolation runs over the spatial grid
        grid_positions = grid_positions.reshape(1, source_side, source_side, hidden_dim)
        grid_positions = grid_positions.permute(0, 3, 1, 2)
        grid_positions = nn.functional.interpolate(
            grid_positions,
            size=(target_rows, target_cols),
            mode="bicubic",
            align_corners=False,
        )
        # (1, D, rows, cols) -> (1, rows * cols, D)
        grid_positions = grid_positions.permute(0, 2, 3, 1).view(1, -1, hidden_dim)

        return torch.cat((cls_position, grid_positions), dim=1)

    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        """
        Args:
            pixel_values: images of shape `(batch, 3, height, width)`.
            interpolate_pos_encoding: resample the position table to the input resolution.

        Returns:
            Embeddings of shape `(batch, 1 + num_patches, hidden_size)` in the patch-conv weight dtype.
        """
        batch_size, _, height, width = pixel_values.shape
        weight_dtype = self.patch_embedding.weight.dtype

        # conv output is (batch, embed_dim, rows, cols); flatten the grid into a patch sequence
        patches = self.patch_embedding(pixel_values.to(dtype=weight_dtype))
        patches = patches.flatten(2).transpose(1, 2)

        cls_tokens = self.class_embedding.expand(batch_size, 1, -1).to(weight_dtype)
        embeddings = torch.cat([cls_tokens, patches], dim=1)

        positions = (
            self.interpolate_pos_encoding(embeddings, height, width)
            if interpolate_pos_encoding
            else self.position_embedding
        )
        sequence_length = embeddings.size(1)
        return embeddings + positions[:, :sequence_length, :].to(weight_dtype)
@auto_docstring
class InstructBlipVideoPreTrainedModel(PreTrainedModel):
    # Common base for the InstructBlipVideo models: declares the supported attention backends / features and
    # the weight-initialization scheme applied to freshly created modules.
    config: InstructBlipVideoConfig
    base_model_prefix = "blip"
    supports_gradient_checkpointing = True
    _supports_attention_backend = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True

    _can_compile_fullgraph = True

    _no_split_modules = [
        "InstructBlipVideoQFormerEmbeddings",
        "InstructBlipVideoAttention",
        "InstructBlipVideoQFormerMultiHeadAttention",
        "InstructBlipVideoQFormerSelfOutput",
    ]

    def _init_weights(self, module):
        """
        Initialize `module` in place according to its type.

        Linear/Conv2d/Embedding weights are drawn from N(0, initializer_range); biases are zeroed; LayerNorm is
        reset to identity (weight 1, bias 0); the vision class/position embeddings use a truncated normal with
        the same std; the top-level models' `query_tokens` are zeroed.
        """
        std = self.config.initializer_range

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            return

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        if isinstance(module, InstructBlipVideoVisionEmbeddings):
            nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=std)
            nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=std)
            return

        if isinstance(module, (InstructBlipVideoForConditionalGeneration, InstructBlipVideoModel)):
            module.query_tokens.data.zero_()
# Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward -> InstructBlipVideo doesn't cast attn weights to fp32
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """
    Reference (pure PyTorch) scaled dot-product attention.

    The softmax is computed in the dtype of the scores; there is no upcast to fp32.

    Args:
        module: owning module; only its `training` flag is read (to decide whether dropout is active).
        query, key, value: attention inputs; `key` is transposed on its last two dims for the score matmul.
        attention_mask: optional additive mask added to the raw scores.
        scaling: multiplier applied to the raw scores.
        dropout: dropout probability applied to the attention weights.
        **kwargs: accepted for interface compatibility with other backends; unused.

    Returns:
        `(attn_output, attn_weights)`: the weighted values with dims 1 and 2 swapped (made contiguous), and the
        post-dropout attention weights.
    """
    scores = query @ key.transpose(-1, -2)
    scores = scores * scaling
    if attention_mask is not None:
        scores = scores + attention_mask

    probs = nn.functional.softmax(scores, dim=-1)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = (probs @ value).transpose(1, 2).contiguous()
    return context, probs
class InstructBlipVideoAttention(nn.Module):
    """
    Multi-head self-attention (as in 'Attention Is All You Need') for the vision encoder, using one fused
    linear projection for query/key/value. The attention kernel is selected from
    `config._attn_implementation`.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.is_causal = False
        self.attention_dropout = config.attention_dropout

        # Unlike CLIP, the fused projection is created without bias; a bias is attached only when
        # config.qkv_bias is set.
        self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False)
        if config.qkv_bias:
            query_bias = nn.Parameter(torch.zeros(self.embed_dim))
            value_bias = nn.Parameter(torch.zeros(self.embed_dim))
            # Layout is [query | key | value]; the key segment starts as zeros. All three segments live in one
            # Parameter, so the key segment is trainable like the others.
            fused_bias = torch.cat((query_bias, torch.zeros_like(value_bias, requires_grad=False), value_bias))
            self.qkv.bias = nn.Parameter(fused_bias)

        self.projection = nn.Linear(self.embed_dim, self.embed_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # (bsz, seq_len, embed_dim) -> (bsz, num_heads, seq_len, head_dim)
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel

        Note: `head_mask` is accepted but not used by this implementation, and no attention mask is passed to the
        attention kernel.
        """
        batch, seq_len, channels = hidden_states.size()

        # (B, T, 3*C) -> (3, B, heads, T, head_dim), then split into query / key / value
        fused = self.qkv(hidden_states).reshape(batch, seq_len, 3, self.num_heads, channels // self.num_heads)
        fused = fused.permute(2, 0, 3, 1, 4)
        query_states = fused[0]
        key_states = fused[1]
        value_states = fused[2]

        if self.config._attn_implementation == "eager":
            attention_interface: Callable = eager_attention_forward
        else:
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask=None,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scale,
            **kwargs,
        )

        merged = attn_output.reshape(batch, seq_len, -1).contiguous()
        return self.projection(merged), attn_weights
class InstructBlipVideoMLP(nn.Module):
    """Feed-forward block: hidden_size -> intermediate_size -> hidden_size, with `config.hidden_act` in between."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expanded = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        return self.fc2(activated)
class InstructBlipVideoEncoderLayer(GradientCheckpointingLayer):
    # Pre-norm transformer block:
    #   x = x + self_attn(layer_norm1(x))
    #   x = x + mlp(layer_norm2(x))

    def __init__(self, config: InstructBlipVideoConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = InstructBlipVideoAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = InstructBlipVideoMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    @auto_docstring
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.FloatTensor:
        # Attention sub-block. `attention_mask` is forwarded as `head_mask`.
        # NOTE(review): InstructBlipVideoAttention.forward currently ignores `head_mask`, so this mask has no effect.
        attended, _ = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            head_mask=attention_mask,
            **kwargs,
        )
        hidden_states = attended + hidden_states

        # Feed-forward sub-block
        transformed = self.mlp(self.layer_norm2(hidden_states))
        return transformed + hidden_states
class InstructBlipVideoEncoder(nn.Module):
    """
    Vision transformer encoder: `config.num_hidden_layers` [`InstructBlipVideoEncoderLayer`] blocks applied one
    after another.

    Args:
        config (`InstructBlipVideoVisionConfig`):
            The vision configuration for the `InstructBlipVideoEncoder` (hidden size, number of layers, ...).
    """

    def __init__(self, config: InstructBlipVideoConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            InstructBlipVideoEncoderLayer(config) for _ in range(config.num_hidden_layers)
        )
        self.gradient_checkpointing = False

    @auto_docstring
    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, BaseModelOutput]:
        # Each layer consumes the previous layer's output; the same mask is handed to every layer.
        current = inputs_embeds
        for block in self.layers:
            current = block(current, attention_mask=attention_mask, **kwargs)
        return BaseModelOutput(last_hidden_state=current)
class InstructBlipVideoVisionModel(InstructBlipVideoPreTrainedModel):
    """
    Vision tower: patch/position embeddings -> transformer encoder -> final LayerNorm. Returns the normalized
    sequence plus a pooled vector taken from the class-token position.
    """

    main_input_name = "pixel_values"
    config: InstructBlipVideoVisionConfig
    # modules whose outputs are recorded when hidden states / attentions are requested
    _can_record_outputs = {
        "hidden_states": InstructBlipVideoEncoderLayer,
        "attentions": InstructBlipVideoAttention,
    }

    def __init__(self, config: InstructBlipVideoVisionConfig):
        super().__init__(config)
        self.config = config

        self.embeddings = InstructBlipVideoVisionEmbeddings(config)
        self.encoder = InstructBlipVideoEncoder(config)
        self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.post_init()

    @check_model_inputs(tie_last_hidden_states=False)
    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        interpolate_pos_encoding: bool = False,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, BaseModelOutputWithPooling]:
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedded = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
        encoded: BaseModelOutput = self.encoder(
            inputs_embeds=embedded,
            **kwargs,
        )

        normed_sequence = self.post_layernorm(encoded.last_hidden_state)
        # NOTE(review): the class-token vector is sliced from the already-normalized sequence and passed through
        # post_layernorm a second time. This looks deliberate (pretrained pooler outputs depend on it) —
        # confirm before changing.
        pooled = self.post_layernorm(normed_sequence[:, 0, :])

        return BaseModelOutputWithPooling(
            last_hidden_state=normed_sequence,
            pooler_output=pooled,
        )

    def get_input_embeddings(self):
        # the embedding module itself (patch conv + class token + position table)
        return self.embeddings
class InstructBlipVideoQFormerMultiHeadAttention(nn.Module):
    """
    BERT-style multi-head attention used inside the Q-Former.

    Runs as self-attention, or as cross-attention when built with `is_cross_attention=True`; in that case
    keys/values are projected from `config.encoder_hidden_size` features. When
    `config.position_embedding_type` is "relative_key" or "relative_key_query", learned relative-distance scores
    are added to the raw attention scores.
    """

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention heads (%d)"
                % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            self.key = nn.Linear(config.encoder_hidden_size, self.all_head_size)
            self.value = nn.Linear(config.encoder_hidden_size, self.all_head_size)
        else:
            self.key = nn.Linear(config.hidden_size, self.all_head_size)
            self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            # one embedding per signed distance in [-(max_pos - 1), max_pos - 1]
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
        # when True, cross-attention stores its attention map and registers a hook saving its gradient
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def transpose_for_scores(self, x):
        # (batch, seq, all_head_size) -> (batch, num_heads, seq, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """
        Args:
            hidden_states: query-side input; also the key/value source in self-attention mode.
            attention_mask: additive mask for self-attention. In cross-attention mode it is replaced by
                `encoder_attention_mask`.
            head_mask: optional multiplier applied to the attention probabilities after dropout.
            encoder_hidden_states: when given, switches to cross-attention (keys/values come from this tensor).
            encoder_attention_mask: additive mask used in cross-attention mode.

        Returns:
            `(context_layer, attention_probs)`: context of shape `(batch, query_len, all_head_size)` and the
            pre-dropout attention probabilities of shape `(batch, num_heads, query_len, key_len)`.
        """
        # In cross-attention mode keys and values come from the encoder, and the mask must hide the encoder's
        # padding tokens.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        mixed_query_layer = self.query(hidden_states)
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Raw attention scores: dot product between queries and keys.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            # Distances are (query position - key position). Query and key lengths are taken separately so the
            # score shapes also line up in cross-attention, where the key length can differ from the query
            # length. For self-attention both lengths equal the sequence length.
            query_length = query_layer.shape[2]
            key_length = key_layer.shape[2]
            position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_scores_dtype = attention_scores.dtype

        if attention_mask is not None:
            # additive mask (large negative values at positions that must not be attended to)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1).to(attention_scores_dtype)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # Dropout removes whole tokens from the attendable set; this unusual-looking choice comes from the
        # original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        # Mask heads if requested
        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        # (batch, heads, q_len, head_size) -> (batch, q_len, all_head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer, attention_probs
class InstructBlipVideoQFormerSelfOutput(nn.Module):
    """Post-attention block: dense projection, dropout, then LayerNorm applied to the residual sum."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
class InstructBlipVideoQFormerAttention(nn.Module):
    """
    Q-Former attention sub-block: multi-head (self- or cross-) attention followed by the residual output
    block. Supports pruning of attention heads.
    """

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.attention = InstructBlipVideoQFormerMultiHeadAttention(config, is_cross_attention)
        self.output = InstructBlipVideoQFormerSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads and shrink the q/k/v and output projections accordingly."""
        if len(heads) == 0:
            return

        attn = self.attention
        heads, index = find_pruneable_heads_and_indices(
            heads, attn.num_attention_heads, attn.attention_head_size, self.pruned_heads
        )

        # Keep only the surviving head indices: on the q/k/v projections (default dim) and on the output
        # projection's input features (dim=1).
        attn.query = prune_linear_layer(attn.query, index)
        attn.key = prune_linear_layer(attn.key, index)
        attn.value = prune_linear_layer(attn.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Bookkeeping: new head count / total width, and remember which heads are gone.
        attn.num_attention_heads = attn.num_attention_heads - len(heads)
        attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        context, _ = self.attention(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            **kwargs,
        )
        # residual connection uses the block's input
        return self.output(context, hidden_states)
class InstructBlipVideoQFormerIntermediate(nn.Module):
    """
    Expansion half of the Q-Former feed-forward network: hidden_size -> intermediate_size, then the activation.

    `config.hidden_act` may be a name looked up in `ACT2FN` or a callable used as-is.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        activation = config.hidden_act
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.intermediate_act_fn(self.dense(hidden_states))
class InstructBlipVideoQFormerOutput(nn.Module):
    """
    Contraction half of the Q-Former feed-forward network: intermediate_size -> hidden_size, dropout, then
    LayerNorm over the residual sum with `input_tensor`.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        contracted = self.dense(hidden_states)
        contracted = self.dropout(contracted)
        return self.LayerNorm(contracted + input_tensor)
class InstructBlipVideoQFormerLayer(GradientCheckpointingLayer):
    """
    One Q-Former block.

    Steps:
    1. Self-attention over the whole sequence.
    2. When `query_length > 0`, the first `query_length` positions (the query tokens) go through
       cross-attention (only in layers that have it) and the query feed-forward network.
    3. Any remaining text positions go through the text feed-forward network.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = InstructBlipVideoQFormerAttention(config)
        self.layer_idx = layer_idx

        # cross-attention is present every `cross_attention_frequency` layers, starting at layer 0
        self.has_cross_attention = layer_idx % config.cross_attention_frequency == 0
        if self.has_cross_attention:
            self.crossattention = InstructBlipVideoQFormerAttention(config, is_cross_attention=True)

        # separate feed-forward weights for text positions and for query positions
        self.intermediate = InstructBlipVideoQFormerIntermediate(config)
        self.output = InstructBlipVideoQFormerOutput(config)
        self.intermediate_query = InstructBlipVideoQFormerIntermediate(config)
        self.output_query = InstructBlipVideoQFormerOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        query_length=0,
        **kwargs: Unpack[TransformersKwargs],
    ):
        self_attended = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            **kwargs,
        )

        if query_length <= 0:
            # No query tokens: the whole sequence is text and uses the text feed-forward network.
            return apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                self_attended,
            )

        queries = self_attended[:, :query_length, :]
        if self.has_cross_attention:
            if encoder_hidden_states is None:
                raise ValueError("encoder_hidden_states must be given for cross-attention layers")
            queries = self.crossattention(
                queries,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                **kwargs,
            )

        query_out = apply_chunking_to_forward(
            self.feed_forward_chunk_query,
            self.chunk_size_feed_forward,
            self.seq_len_dim,
            queries,
        )

        if self_attended.shape[1] <= query_length:
            # only query tokens in this sequence
            return query_out

        text_out = apply_chunking_to_forward(
            self.feed_forward_chunk,
            self.chunk_size_feed_forward,
            self.seq_len_dim,
            self_attended[:, query_length:, :],
        ).to(query_out.device)
        return torch.cat([query_out, text_out], dim=1)

    def feed_forward_chunk(self, attention_output):
        # feed-forward network for text positions (with residual + LayerNorm inside `self.output`)
        return self.output(self.intermediate(attention_output), attention_output)

    def feed_forward_chunk_query(self, attention_output):
        # feed-forward network for query positions (with residual + LayerNorm inside `self.output_query`)
        return self.output_query(self.intermediate_query(attention_output), attention_output)
class InstructBlipVideoQFormerEncoder(nn.Module):
    """Sequential stack of `config.num_hidden_layers` [`InstructBlipVideoQFormerLayer`] blocks."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            InstructBlipVideoQFormerLayer(config, idx) for idx in range(config.num_hidden_layers)
        )
        self.gradient_checkpointing = False

    @can_return_tuple
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        query_length=0,
        **kwargs: Unpack[TransformersKwargs],
    ):
        for idx in range(self.config.num_hidden_layers):
            # head_mask, when given, is indexed per layer
            per_layer_head_mask = None if head_mask is None else head_mask[idx]
            hidden_states = self.layer[idx](
                hidden_states,
                attention_mask,
                per_layer_head_mask,
                encoder_hidden_states,  # passed positionally on purpose, for gradient checkpointing
                encoder_attention_mask=encoder_attention_mask,
                query_length=query_length,
                **kwargs,
            )

        return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=hidden_states)
class InstructBlipVideoQFormerEmbeddings(nn.Module):
    """
    Build the Q-Former input sequence.

    Text tokens are embedded as word embeddings plus absolute position embeddings (when
    `position_embedding_type == "absolute"`). If `query_embeds` are given, they are prepended to the text.
    The result is cast to the LayerNorm dtype, normalized, and passed through dropout.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # (1, max_position_embeddings) buffer of position indices. persistent=False keeps it out of the
        # state dict, so it is rebuilt at construction rather than loaded from checkpoints.
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

        self.config = config

    def forward(
        self,
        input_ids=None,
        position_ids=None,
        query_embeds=None,
        past_key_values_length=0,
    ):
        """
        Args:
            input_ids: text token ids of shape `(batch, text_len)`, or None for a query-only sequence.
            position_ids: explicit position ids for the text tokens. Defaults to
                `past_key_values_length .. past_key_values_length + text_len - 1`.
            query_embeds: query embeddings prepended along dim 1, or None for a text-only sequence.
            past_key_values_length: offset applied to the default position ids.

        Returns:
            Normalized embeddings of shape `(batch, query_len + text_len, hidden_size)`.

        Raises:
            ValueError: if both `input_ids` and `query_embeds` are None.
        """
        if input_ids is None and query_embeds is None:
            # Without this check the code below would fail on `None.to(...)` with an opaque AttributeError.
            raise ValueError("You have to specify at least one of input_ids or query_embeds")

        if input_ids is not None:
            seq_length = input_ids.size()[1]
            if position_ids is None:
                position_ids = self.position_ids[
                    :, past_key_values_length : seq_length + past_key_values_length
                ].clone()

            embeddings = self.word_embeddings(input_ids)
            if self.position_embedding_type == "absolute":
                position_embeddings = self.position_embeddings(position_ids.to(embeddings.device))
                embeddings = embeddings + position_embeddings

            if query_embeds is not None:
                embeddings = torch.cat((query_embeds, embeddings), dim=1)
        else:
            embeddings = query_embeds

        embeddings = embeddings.to(self.layernorm.weight.dtype)
        embeddings = self.layernorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
class InstructBlipVideoQFormerModel(InstructBlipVideoPreTrainedModel):
    """
    Querying Transformer (Q-Former), used in InstructBlipVideo. Slightly modified from BLIP-2 as it also takes the
    instruction as input.

    `forward` embeds the learned query embeddings (`query_embeds`) followed by the instruction tokens
    (`input_ids`), builds additive self- and cross-attention masks, and runs the result through the encoder stack,
    passing `query_length` so the layers can tell query positions from text positions.
    """

    _supports_attention_backend = False  # adds position on attn weights before last matmul
    _supports_flash_attn = False
    _supports_sdpa = False
    _supports_flex_attn = False

    # Which submodule outputs may be recorded as `hidden_states` / `attentions` / `cross_attentions`
    # (presumably consumed by the `check_model_inputs` decorator on `forward` — TODO confirm). Self- and
    # cross-attention use the same module class and are told apart by the submodule name suffix.
    _can_record_outputs = {
        "hidden_states": InstructBlipVideoQFormerLayer,
        "attentions": [
            OutputRecorder(InstructBlipVideoQFormerMultiHeadAttention, index=1, layer_name=".attention"),
        ],
        "cross_attentions": [
            OutputRecorder(InstructBlipVideoQFormerMultiHeadAttention, index=1, layer_name=".crossattention"),
        ],
    }

    def __init__(self, config: InstructBlipVideoQFormerConfig):
        super().__init__(config)
        self.config = config

        self.embeddings = InstructBlipVideoQFormerEmbeddings(config)

        self.encoder = InstructBlipVideoQFormerEncoder(config)

        self.post_init()

    def get_input_embeddings(self):
        # The Q-Former's own word-embedding table (not the language model's).
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def get_extended_attention_mask(
        self,
        attention_mask: torch.Tensor,
        input_shape: tuple[int],
        device: torch.device,
        has_query: bool = False,
    ) -> torch.Tensor:
        """
        Makes a broadcastable additive attention mask so that masked tokens are ignored.

        No causal mask is built here, and `device` and `has_query` are accepted but not used by this
        implementation.

        Arguments:
            attention_mask (`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore. Either 2D
                `[batch_size, seq_length]` or 3D `[batch_size, from_seq_length, to_seq_length]`.
            input_shape (`tuple[int]`):
                The shape of the input to the model (only used in the error message).
            device: (`torch.device`):
                The device of the input to the model (unused).
            has_query (`bool`, *optional*, defaults to `False`):
                Unused.

        Returns:
            `torch.Tensor` The extended attention mask, cast to `self.dtype`, holding 0.0 for positions to attend to
            and -10000.0 for masked positions.

        Raises:
            `ValueError`: if `attention_mask` is neither 2D nor 3D.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})",
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    @check_model_inputs()
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        query_embeds: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple[torch.FloatTensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        r"""
        query_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Hidden states to be used in the attention computation. If cross-attention,
            will be used for the query (i.e., key and value will use the encoder_hidden_states).
        """
        if input_ids is None and query_embeds is None:
            raise ValueError("You have to specify query_embeds when input_ids is None")

        # Number of leading query positions; the layers use it to split queries from text tokens.
        query_length = query_embeds.shape[1] if query_embeds is not None else 0

        # The embeddings prepend `query_embeds` to the text embeddings, so `seq_length` below includes the queries.
        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            query_embeds=query_embeds,
        )

        input_shape = embedding_output.size()[:-1]
        batch_size, seq_length = input_shape
        device = embedding_output.device

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length)), device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if encoder_hidden_states is not None:
            # A list of encoder states is accepted; its first entry determines the default mask shape.
            if isinstance(encoder_hidden_states, list):
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
            else:
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)

            if isinstance(encoder_attention_mask, list):
                encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
            elif encoder_attention_mask is None:
                # No mask given: attend to every encoder position.
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        encoder_outputs: BaseModelOutput = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            query_length=query_length,
            **kwargs,
        )
        sequence_output = encoder_outputs.last_hidden_state
        # "Pooled" output is just the hidden state at position 0 (the first query token when queries are given).
        pooled_output = sequence_output[:, 0, :]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
        )
@dataclass
@auto_docstring(
    custom_intro="""
    Class defining the outputs of [`InstructBlipVideoForConditionalGeneration`].
    """
)
class InstructBlipVideoForConditionalGenerationModelOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
        Language modeling loss from the language model.
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head of the language model.
    vision_outputs (`BaseModelOutputWithPooling`):
        Outputs of the vision encoder.
    qformer_outputs (`BaseModelOutputWithPoolingAndCrossAttentions`):
        Outputs of the Q-Former (Querying Transformer).
    language_model_outputs (`CausalLMOutputWithPast` or `Seq2SeqLMOutput`):
        Outputs of the language model.
    """

    # `loss` and `logits` are single tensors (see the docstring above), not tuples.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    vision_outputs: Optional[torch.FloatTensor] = None
    qformer_outputs: Optional[tuple[torch.FloatTensor]] = None
    language_model_outputs: Optional[tuple[torch.FloatTensor]] = None

    def to_tuple(self) -> tuple[Any]:
        """
        Convert to a tuple of the fields reported by `self.keys()`, in order.

        The nested `vision_outputs`, `qformer_outputs` and `language_model_outputs` are converted with their own
        `.to_tuple()`. If one of them is already a plain tuple (as happens when the calling model converts it with
        `return_dict=False`), it is kept as-is instead of failing with `AttributeError`.
        """
        nested_output_keys = ("vision_outputs", "qformer_outputs", "language_model_outputs")
        converted = []
        for key in self.keys():
            if key in nested_output_keys:
                value = getattr(self, key)
                if hasattr(value, "to_tuple"):
                    value = value.to_tuple()
            else:
                value = self[key]
            converted.append(value)
        return tuple(converted)
@auto_docstring(
    custom_intro="""
    InstructBlipVideo base Model consisting of language model, qformer and vision encoder.
    """
)
class InstructBlipVideoModel(InstructBlipVideoPreTrainedModel):
    main_input_name = "pixel_values"
    _keep_in_fp32_modules = ["query_tokens"]  # TODO @ArthurZucker I don't know why this is required for FP8

    def __init__(self, config: InstructBlipVideoConfig):
        super().__init__(config)

        self.vision_model = InstructBlipVideoVisionModel(config.vision_config)
        self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
        self.qformer = InstructBlipVideoQFormerModel(config.qformer_config)

        self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
        self.language_model = AutoModel.from_config(config.text_config)

        # `_keep_in_fp32_modules` is a class attribute (and `_no_split_modules` presumably one on the base class),
        # so calling `.extend()` on them would mutate the list shared by the class and leak this language model's
        # entries into every other instance. Assign fresh per-instance lists instead.
        if self.language_model._no_split_modules is not None:
            self._no_split_modules = list(self._no_split_modules or []) + list(self.language_model._no_split_modules)
        if self.language_model._keep_in_fp32_modules is not None:
            self._keep_in_fp32_modules = list(self._keep_in_fp32_modules or []) + list(
                self.language_model._keep_in_fp32_modules
            )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def _tie_weights(self):
        # Encoder-decoder language models share one embedding table between encoder and decoder.
        if not self.config.use_decoder_only_language_model:
            self.language_model.encoder.embed_tokens = self.language_model.shared
            self.language_model.decoder.embed_tokens = self.language_model.shared

    def _preprocess_accelerate(self):
        r"""
        Some pre-processing hacks to make the model `accelerate` compatible. Check
        https://github.com/huggingface/transformers/pull/21707 for more details.
        """
        hf_device_map = self.hf_device_map

        if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
            # warn users about unexpected behavior when using multi-GPU + InstructBlipVideo + `accelerate`.
            logger.warning(
                "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
                " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
                " Please pass a `device_map` that contains `language_model` to remove this warning."
                " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
                " more details on creating a `device_map` for large models.",
            )

        if hasattr(self.language_model, "_hf_hook"):
            self.language_model._hf_hook.io_same_device = True  # For `generate` compatibility

    def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor):
        """
        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`.

        Placeholders are the video token (`config.video_token_id`), the same token `forward` scatters the projected
        Q-Former outputs into (this previously read `config.image_token_id`, which the video config does not use).
        Returns a boolean mask expanded to `inputs_embeds.shape` and placed on `inputs_embeds.device`.
        """
        if input_ids is None:
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)
        else:
            special_image_mask = input_ids == self.config.video_token_id

        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        return special_image_mask

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        qformer_input_ids: torch.FloatTensor,
        qformer_attention_mask: Optional[torch.LongTensor] = None,
        input_ids: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        use_cache: Optional[bool] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
        r"""
        qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
            to serve as text prompt, which the Q-Former model will encode.

            Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for
            details.

            [What are input IDs?](../glossary#input-ids)
        qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.

            Only relevant in case an encoder-decoder language model (like T5) is used.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # step 1: forward the images through the vision encoder,
        # we process in a batched way, later unbatch it back (video has frames=4 always)
        batch_size, frames, channel, height, width = pixel_values.shape
        pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        image_embeds = vision_outputs[0]

        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)

        if qformer_attention_mask is None:
            qformer_attention_mask = torch.ones_like(qformer_input_ids)

        # Each frame is a separate Q-Former sample, so repeat the per-video prompt once per frame.
        qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
        qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
        qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)

        query_outputs = self.qformer(
            input_ids=qformer_input_ids,
            attention_mask=qformer_attention_mask,
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Keep only the query positions; the instruction-token positions follow them.
        query_output = query_outputs[0][:, : query_tokens.size(1), :]

        # step 3: use the language model, conditioned on the query outputs and the prompt
        language_model_inputs = self.language_projection(query_output)

        # unbatch inputs back, each video-frame gets `num_query_tokens` seq length
        language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)

        if inputs_embeds is None:
            inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
            special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
            if attention_mask is None:
                attention_mask = torch.ones_like(input_ids)
        else:
            # Caller supplied embeddings: locate placeholders by comparing embeddings, not ids.
            special_image_mask = self.get_placeholder_mask(None, inputs_embeds=inputs_embeds)

        language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)

        if self.config.use_decoder_only_language_model:
            outputs = self.language_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                use_cache=use_cache,
                **kwargs,
            )
        else:
            outputs = self.language_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                use_cache=use_cache,
                **kwargs,
            )

        return InstructBlipVideoForConditionalGenerationModelOutput(
            vision_outputs=vision_outputs,
            qformer_outputs=query_outputs,
            language_model_outputs=outputs,
        )
- @auto_docstring(
- custom_intro="""
- InstructBlipVideo Model for generating text given an image and an optional text prompt. The model consists of a vision
- encoder, Querying Transformer (Q-Former) and a language model.
- One can optionally pass `input_ids` to the model, which serve as a text prompt, to make the language model continue
- the prompt. Otherwise, the language model starts generating text from the [BOS] (beginning-of-sequence) token.
- """
- )
- class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel, GenerationMixin):
- config: InstructBlipVideoConfig
- main_input_name = "pixel_values"
- _can_compile_fullgraph = True
- _keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8
def __init__(self, config: InstructBlipVideoConfig):
    """Build the vision encoder, Q-Former, projection and task-specific language model.

    The language model is a causal LM when `config.use_decoder_only_language_model` is set, otherwise a
    seq2seq LM.
    """
    super().__init__(config)

    self.vision_model = InstructBlipVideoVisionModel._from_config(config.vision_config)

    self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
    self.qformer = InstructBlipVideoQFormerModel._from_config(config.qformer_config)

    self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)

    if config.use_decoder_only_language_model:
        language_model = AutoModelForCausalLM.from_config(config.text_config)
    else:
        language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)

    # Both lists are class-level attributes (`_keep_in_fp32_modules` on this class, `_no_split_modules`
    # presumably on the base class); `.extend()` would mutate them for every instance and keep appending on
    # each construction. Assign fresh per-instance lists instead.
    if language_model._no_split_modules is not None:
        self._no_split_modules = list(self._no_split_modules or []) + list(language_model._no_split_modules)
    if language_model._keep_in_fp32_modules is not None:
        self._keep_in_fp32_modules = list(self._keep_in_fp32_modules or []) + list(
            language_model._keep_in_fp32_modules
        )

    self.language_model = language_model

    # Initialize weights and apply final processing
    self.post_init()
def get_input_embeddings(self):
    """Return the input embedding module of the wrapped language model."""
    language_model = self.language_model
    return language_model.get_input_embeddings()
def set_input_embeddings(self, value):
    """Install `value` as the wrapped language model's input embedding module."""
    language_model = self.language_model
    language_model.set_input_embeddings(value)
def set_output_embeddings(self, new_embeddings):
    """Install `new_embeddings` as the wrapped language model's output embedding module."""
    language_model = self.language_model
    language_model.set_output_embeddings(new_embeddings)
def get_output_embeddings(self) -> nn.Module:
    """Return the output embedding module reported by the wrapped language model."""
    language_model = self.language_model
    return language_model.get_output_embeddings()
def get_encoder(self):
    """Return the encoder reported by the wrapped language model."""
    language_model = self.language_model
    return language_model.get_encoder()
def get_decoder(self):
    """Return the decoder reported by the wrapped language model."""
    language_model = self.language_model
    return language_model.get_decoder()
- def _tie_weights(self):
- if not self.config.use_decoder_only_language_model:
- self.language_model.encoder.embed_tokens = self.language_model.shared
- self.language_model.decoder.embed_tokens = self.language_model.shared
- def _preprocess_accelerate(self):
- r"""
- Some pre-processing hacks to make the model `accelerate` compatible. Check
- https://github.com/huggingface/transformers/pull/21707 for more details.
- """
- hf_device_map = self.hf_device_map
- if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
- # warn users about unexpected behavior when using multi-GPU + InstructBlipVideo + `accelerate`.
- logger.warning(
- "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
- " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
- " Please pass a `device_map` that contains `language_model` to remove this warning."
- " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
- " more details on creating a `device_map` for large models.",
- )
- if hasattr(self.language_model, "_hf_hook"):
- self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
def get_image_features(
    self,
    pixel_values: torch.FloatTensor,
    qformer_input_ids: torch.LongTensor,
    qformer_attention_mask: Optional[torch.LongTensor] = None,
    interpolate_pos_encoding: Optional[bool] = False,
    return_dict: Optional[bool] = False,
):
    """
    Intentionally a no-op for this video model: the arguments are ignored and `None` is returned.

    Visual inputs are encoded by `get_video_features`, which is what `forward` and `generate` call.

    Args:
        pixel_values (`torch.FloatTensor`):
            Ignored.
    """
    return None
def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor):
    """
    Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`.

    With `input_ids`, placeholders are positions equal to `config.video_token_id`. Without them, a position is a
    placeholder when its whole embedding vector equals the embedding of that token. The boolean result is expanded
    to `inputs_embeds.shape` and moved to `inputs_embeds.device`.
    """
    if input_ids is not None:
        is_placeholder = input_ids == self.config.video_token_id
    else:
        video_token = torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
        placeholder_embedding = self.get_input_embeddings()(video_token)
        is_placeholder = (inputs_embeds == placeholder_embedding).all(-1)

    return is_placeholder.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
@can_return_tuple
@auto_docstring
def forward(
    self,
    pixel_values: torch.FloatTensor,
    qformer_input_ids: torch.FloatTensor,
    qformer_attention_mask: Optional[torch.LongTensor] = None,
    input_ids: Optional[torch.FloatTensor] = None,
    attention_mask: Optional[torch.LongTensor] = None,
    decoder_input_ids: Optional[torch.LongTensor] = None,
    decoder_attention_mask: Optional[torch.LongTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    labels: Optional[torch.LongTensor] = None,
    return_dict: Optional[bool] = None,
    interpolate_pos_encoding: bool = False,
    use_cache: Optional[bool] = None,
    **kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
    r"""
    qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length)):
        The sequence used as a prompt to be fed to the Q-Former module.
    qformer_attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
        Mask to avoid performing attention on padding token indices.

    Examples:

    ```python
    >>> from transformers import InstructBlipVideoProcessor, InstructBlipVideoForConditionalGeneration
    >>> import torch
    >>> from huggingface_hub import hf_hub_download
    >>> import av
    >>> import numpy as np

    >>> def read_video_pyav(container, indices):
    ...     '''
    ...     Decode the video with PyAV decoder.
    ...     Args:
    ...         container (`av.container.input.InputContainer`): PyAV container.
    ...         indices (`list[int]`): List of frame indices to decode.
    ...     Returns:
    ...         result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
    ...     '''
    ...     frames = []
    ...     container.seek(0)
    ...     start_index = indices[0]
    ...     end_index = indices[-1]
    ...     for i, frame in enumerate(container.decode(video=0)):
    ...         if i > end_index:
    ...             break
    ...         if i >= start_index and i in indices:
    ...             frames.append(frame)
    ...     return np.stack([x.to_ndarray(format="rgb24") for x in frames])

    >>> model = InstructBlipVideoForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b", device_map="auto")
    >>> processor = InstructBlipVideoProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")

    >>> file_path = hf_hub_download(
    ...       repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
    ... )
    >>> container = av.open(file_path)

    >>> # sample uniformly 4 frames from the video
    >>> total_frames = container.streams.video[0].frames
    >>> indices = np.arange(0, total_frames, total_frames / 4).astype(int)
    >>> clip = read_video_pyav(container, indices)

    >>> prompt = "What is happening in the video?"
    >>> inputs = processor(text=prompt, images=clip, return_tensors="pt").to(model.device)

    >>> outputs = model.generate(
    ...     **inputs,
    ...     do_sample=False,
    ...     num_beams=5,
    ...     max_length=256,
    ...     repetition_penalty=1.5,
    ...     length_penalty=1.0,
    ... )
    >>> generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
    >>> print(generated_text)
    "A person is eating a bowl of pasta, and they are using a fork to eat it. The person is sitting at a table, and the plate of pasta is on the table in front"
    ```"""
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Encode the frames and project the Q-Former query outputs into the language model's embedding space.
    language_model_inputs, vision_outputs, query_outputs = self.get_video_features(
        pixel_values,
        qformer_input_ids=qformer_input_ids,
        qformer_attention_mask=qformer_attention_mask,
        interpolate_pos_encoding=interpolate_pos_encoding,
        return_dict=True,
    )
    vision_outputs = vision_outputs.to_tuple() if not return_dict else vision_outputs
    query_outputs = query_outputs.to_tuple() if not return_dict else query_outputs

    if inputs_embeds is None:
        inputs_embeds = self.get_input_embeddings()(input_ids)

    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)

    # Overwrite the video placeholder positions with the projected visual features.
    language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
    special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
    inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)

    if self.config.use_decoder_only_language_model:
        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=use_cache,
            **kwargs,
        )
        # Labels are not passed to the causal LM, so its tuple output always starts with the logits.
        logits = outputs.logits if return_dict else outputs[0]
        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
            )
    else:
        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            labels=labels,
            use_cache=use_cache,
            **kwargs,
        )
        if return_dict:
            loss = outputs.loss
            logits = outputs.logits
        elif labels is not None:
            # transformers seq2seq models prepend the loss to their tuple output only when it was computed.
            loss, logits = outputs[0], outputs[1]
        else:
            # No labels -> no loss entry: the tuple starts with the logits.
            loss, logits = None, outputs[0]

    return InstructBlipVideoForConditionalGenerationModelOutput(
        loss=loss,
        logits=logits,
        vision_outputs=vision_outputs,
        qformer_outputs=query_outputs,
        language_model_outputs=outputs,
    )
@torch.no_grad()
def generate(
    self,
    pixel_values: torch.FloatTensor,
    qformer_input_ids: Optional[torch.LongTensor] = None,
    qformer_attention_mask: Optional[torch.LongTensor] = None,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.LongTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    interpolate_pos_encoding: bool = False,
    **generate_kwargs,
) -> torch.LongTensor:
    r"""
    Overrides `generate` function to be able to use the model as a conditional generator.

    Args:
        pixel_values (`torch.FloatTensor` of shape (batch_size, num_frames, num_channels, height, width)):
            Input videos to be processed.
        qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
            The sequence used as a prompt to be fed to the Q-Former module.
        qformer_attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
            Mask to avoid performing attention on padding token indices.
        input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
            The sequence used as a prompt for the generation. When neither this nor `inputs_embeds` is given, a
            prompt of `num_query_tokens * num_frames` video placeholder tokens followed by BOS is built.
        attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
            Mask to avoid performing attention on padding token indices.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Embedded representation of the inputs. Should be float, not int tokens.
        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
            Whether to interpolate the positional encoding of the image embeddings.

    Returns:
        Whatever `self.language_model.generate` returns for the given `generate_kwargs` (typically a
        `torch.LongTensor` of generated token ids).
    """
    if hasattr(self, "hf_device_map"):
        # preprocess for `accelerate`
        self._preprocess_accelerate()

    batch_size = pixel_values.shape[0]
    # Frame count of the input video; previously hard-coded to 4, which broke any other frame count.
    num_frames = pixel_values.shape[1]
    language_model_inputs, _, _ = self.get_video_features(
        pixel_values,
        qformer_input_ids=qformer_input_ids,
        qformer_attention_mask=qformer_attention_mask,
        interpolate_pos_encoding=interpolate_pos_encoding,
        return_dict=True,
    )

    if inputs_embeds is None:
        if input_ids is None:
            # One placeholder per query token per frame, then BOS. The placeholder count has to match the
            # number of projected features scattered in below (presumably num_query_tokens * num_frames per
            # video — TODO confirm against get_video_features). Use `video_token_id`, the id the placeholder
            # mask looks for.
            video_tokens = [self.config.video_token_id] * self.config.num_query_tokens * num_frames
            start_tokens = video_tokens + [self.config.text_config.bos_token_id]
            input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
            input_ids = input_ids.repeat(batch_size, 1)
        inputs_embeds = self.get_input_embeddings()(input_ids)

    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)

    language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
    special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
    inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)

    inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
    if not self.language_model.config.is_encoder_decoder:
        inputs["input_ids"] = input_ids

    outputs = self.language_model.generate(**inputs, **generate_kwargs)

    return outputs
def get_video_features(
    self,
    pixel_values: torch.FloatTensor,
    qformer_input_ids: torch.LongTensor,
    qformer_attention_mask: Optional[torch.LongTensor] = None,
    interpolate_pos_encoding: Optional[bool] = False,
    return_dict: Optional[bool] = False,
):
    """
    Encodes video frames into continuous embeddings that can be forwarded to the language model.

    Every frame is encoded independently by the vision model and the Q-Former (which is also given the
    instruction prompt), projected with `language_projection`, and the per-frame query outputs of each
    video are concatenated along the sequence dimension.

    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
            The tensors corresponding to the input video frames.
        qformer_input_ids (`torch.LongTensor` of shape `(batch_size, qformer_sequence_length)`):
            Token ids of the instruction prompt fed to the Q-Former. They are repeated once per frame.
        qformer_attention_mask (`torch.LongTensor` of shape `(batch_size, qformer_sequence_length)`, *optional*):
            Mask for `qformer_input_ids`. If `None`, every prompt token is attended to.
        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
            Whether to interpolate the positional encoding of the image embeddings.
        return_dict (`bool`, *optional*, defaults to `False`):
            If `True`, return the plain tuple `(language_model_inputs, vision_outputs, query_outputs)`
            instead of `language_model_inputs` alone.

    Returns:
        `torch.FloatTensor` of shape `(batch_size, num_frames * num_query_tokens, projection_dim)`, where
        `projection_dim` is the output size of `language_projection`; or the tuple described under
        `return_dict`.

    Raises:
        ValueError: If `pixel_values` is not a 5-D tensor.
    """
    if pixel_values.dim() != 5:
        raise ValueError(
            "`pixel_values` must be a 5-D tensor of shape (batch_size, num_frames, num_channels, height, width), "
            f"but got shape {tuple(pixel_values.shape)}."
        )

    # step 1: fold the frames into the batch dimension so the vision encoder processes one image per row;
    # the result is unfolded back per video in step 3
    batch_size, frames, channel, height, width = pixel_values.shape
    pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
    vision_outputs = self.vision_model(
        pixel_values=pixel_values,
        interpolate_pos_encoding=interpolate_pos_encoding,
        return_dict=True,
    )
    image_embeds = vision_outputs[0]

    # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
    image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

    # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
    query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
    query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
    if qformer_attention_mask is None:
        qformer_attention_mask = torch.ones_like(qformer_input_ids)
    # the prompt is given once per video but the Q-Former runs once per frame, so repeat each video's
    # prompt for all of its (consecutive) frame rows
    qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
    qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
    # mask layout is [query tokens, prompt tokens], matching the query-token slice taken below
    qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
    query_outputs = self.qformer(
        input_ids=qformer_input_ids,
        attention_mask=qformer_attention_mask,
        query_embeds=query_tokens,
        encoder_hidden_states=image_embeds,
        encoder_attention_mask=image_attention_mask,
        return_dict=True,
    )
    # keep only the query-token positions; the prompt positions are not passed on
    query_output = query_outputs[0][:, : query_tokens.size(1), :]

    # step 3: project the query outputs into the language model's input space
    language_model_inputs = self.language_projection(query_output)
    # unbatch inputs back: each video gets `num_query_tokens` positions per frame, frames in order
    language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)

    if return_dict:
        return language_model_inputs, vision_outputs, query_outputs
    return language_model_inputs
# Public API of this module: the only names exported by `from <module> import *`.
# NOTE(review): the listed classes are presumably defined earlier in this file (outside this view).
__all__ = [
    "InstructBlipVideoVisionModel",
    "InstructBlipVideoPreTrainedModel",
    "InstructBlipVideoQFormerModel",
    "InstructBlipVideoModel",
    "InstructBlipVideoForConditionalGeneration",
]
|