modeling_instructblip.py 66 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532
  1. # coding=utf-8
  2. # Copyright 2023 The Salesforce Authors and The HuggingFace Team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch InstructBLIP model."""
  16. import math
  17. from dataclasses import dataclass
  18. from typing import Any, Callable, Optional, Union
  19. import torch
  20. from torch import nn
  21. from ...activations import ACT2FN
  22. from ...generation import GenerationMixin
  23. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  24. from ...modeling_layers import GradientCheckpointingLayer
  25. from ...modeling_outputs import (
  26. BaseModelOutput,
  27. BaseModelOutputWithPastAndCrossAttentions,
  28. BaseModelOutputWithPooling,
  29. BaseModelOutputWithPoolingAndCrossAttentions,
  30. )
  31. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  32. from ...processing_utils import Unpack
  33. from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
  34. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int
  35. from ...utils.generic import OutputRecorder, check_model_inputs
  36. from ..auto import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM
  37. from .configuration_instructblip import InstructBlipConfig, InstructBlipQFormerConfig, InstructBlipVisionConfig
# Module-level logger for this file, obtained through the library's `logging.get_logger` helper.
logger = logging.get_logger(__name__)
  39. @dataclass
  40. @auto_docstring(
  41. custom_intro="""
  42. Class defining the outputs of [`InstructBlipForConditionalGeneration`].
  43. """
  44. )
  45. # Copied from transformers.models.blip_2.modeling_blip_2.Blip2ForConditionalGenerationModelOutput with Blip2->InstructBlip
  46. class InstructBlipForConditionalGenerationModelOutput(ModelOutput):
  47. r"""
  48. loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
  49. Language modeling loss from the language model.
  50. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  51. Prediction scores of the language modeling head of the language model.
  52. vision_outputs (`BaseModelOutputWithPooling`):
  53. Outputs of the vision encoder.
  54. qformer_outputs (`BaseModelOutputWithPoolingAndCrossAttentions`):
  55. Outputs of the Q-Former (Querying Transformer).
  56. language_model_outputs (`CausalLMOutputWithPast` or `Seq2SeqLMOutput`):
  57. Outputs of the language model.
  58. """
  59. loss: Optional[tuple[torch.FloatTensor]] = None
  60. logits: Optional[tuple[torch.FloatTensor]] = None
  61. vision_outputs: Optional[torch.FloatTensor] = None
  62. qformer_outputs: Optional[tuple[torch.FloatTensor]] = None
  63. language_model_outputs: Optional[tuple[torch.FloatTensor]] = None
  64. def to_tuple(self) -> tuple[Any]:
  65. return tuple(
  66. self[k]
  67. if k not in ["vision_outputs", "qformer_outputs", "language_model_outputs"]
  68. else getattr(self, k).to_tuple()
  69. for k in self.keys()
  70. )
  71. # Copied from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->InstructBlip
  72. class InstructBlipVisionEmbeddings(nn.Module):
  73. def __init__(self, config: InstructBlipVisionConfig):
  74. super().__init__()
  75. self.config = config
  76. self.embed_dim = config.hidden_size
  77. self.image_size = config.image_size
  78. self.patch_size = config.patch_size
  79. self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
  80. self.patch_embedding = nn.Conv2d(
  81. in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
  82. )
  83. self.num_patches = (self.image_size // self.patch_size) ** 2
  84. self.num_positions = self.num_patches + 1
  85. self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
  86. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  87. """
  88. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  89. images. This method is also adapted to support torch.jit tracing.
  90. Adapted from:
  91. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  92. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  93. """
  94. num_patches = embeddings.shape[1] - 1
  95. num_positions = self.position_embedding.shape[1] - 1
  96. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  97. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  98. return self.position_embedding
  99. class_pos_embed = self.position_embedding[:, :1]
  100. patch_pos_embed = self.position_embedding[:, 1:]
  101. dim = embeddings.shape[-1]
  102. new_height = height // self.patch_size
  103. new_width = width // self.patch_size
  104. sqrt_num_positions = torch_int(num_positions**0.5)
  105. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  106. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  107. patch_pos_embed = nn.functional.interpolate(
  108. patch_pos_embed,
  109. size=(new_height, new_width),
  110. mode="bicubic",
  111. align_corners=False,
  112. )
  113. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  114. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  115. def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
  116. batch_size, _, height, width = pixel_values.shape
  117. target_dtype = self.patch_embedding.weight.dtype
  118. patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
  119. patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
  120. class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
  121. embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
  122. if interpolate_pos_encoding:
  123. position_embedding = self.interpolate_pos_encoding(embeddings, height, width)
  124. else:
  125. position_embedding = self.position_embedding
  126. embeddings = embeddings + position_embedding[:, : embeddings.size(1), :].to(target_dtype)
  127. return embeddings
  128. # Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward -> InstructBLIP doesn't cast attn weights to fp32
  129. def eager_attention_forward(
  130. module: nn.Module,
  131. query: torch.Tensor,
  132. key: torch.Tensor,
  133. value: torch.Tensor,
  134. attention_mask: Optional[torch.Tensor],
  135. scaling: float,
  136. dropout: float = 0.0,
  137. **kwargs,
  138. ):
  139. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  140. if attention_mask is not None:
  141. attn_weights = attn_weights + attention_mask
  142. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  143. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  144. attn_output = torch.matmul(attn_weights, value)
  145. attn_output = attn_output.transpose(1, 2).contiguous()
  146. return attn_output, attn_weights
  147. # Copied from transformers.models.blip_2.modeling_blip_2.Blip2Attention with Blip2->InstructBlip
  148. class InstructBlipAttention(nn.Module):
  149. """Multi-headed attention from 'Attention Is All You Need' paper"""
  150. def __init__(self, config):
  151. super().__init__()
  152. self.config = config
  153. self.embed_dim = config.hidden_size
  154. self.num_heads = config.num_attention_heads
  155. self.head_dim = self.embed_dim // self.num_heads
  156. if self.head_dim * self.num_heads != self.embed_dim:
  157. raise ValueError(
  158. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  159. f" {self.num_heads})."
  160. )
  161. self.scale = self.head_dim**-0.5
  162. self.is_causal = False
  163. self.attention_dropout = config.attention_dropout
  164. # small tweak here compared to CLIP, no bias here
  165. self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False)
  166. if config.qkv_bias:
  167. q_bias = nn.Parameter(torch.zeros(self.embed_dim))
  168. v_bias = nn.Parameter(torch.zeros(self.embed_dim))
  169. else:
  170. q_bias = None
  171. v_bias = None
  172. if q_bias is not None:
  173. qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias))
  174. self.qkv.bias = nn.Parameter(qkv_bias)
  175. self.projection = nn.Linear(self.embed_dim, self.embed_dim)
  176. def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
  177. return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
  178. def forward(
  179. self,
  180. hidden_states: torch.Tensor,
  181. head_mask: Optional[torch.Tensor] = None,
  182. **kwargs,
  183. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  184. """Input shape: Batch x Time x Channel"""
  185. bsz, tgt_len, embed_dim = hidden_states.size()
  186. mixed_qkv = self.qkv(hidden_states)
  187. mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads).permute(
  188. 2, 0, 3, 1, 4
  189. )
  190. query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
  191. attention_interface: Callable = eager_attention_forward
  192. if self.config._attn_implementation != "eager":
  193. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  194. attn_output, attn_weights = attention_interface(
  195. self,
  196. query_states,
  197. key_states,
  198. value_states,
  199. attention_mask=None,
  200. dropout=0.0 if not self.training else self.attention_dropout,
  201. scaling=self.scale,
  202. **kwargs,
  203. )
  204. attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
  205. attn_output = self.projection(attn_output)
  206. return attn_output, attn_weights
  207. # Copied from transformers.models.blip.modeling_blip.BlipMLP
  208. class InstructBlipMLP(nn.Module):
  209. def __init__(self, config):
  210. super().__init__()
  211. self.config = config
  212. self.activation_fn = ACT2FN[config.hidden_act]
  213. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  214. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  215. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  216. hidden_states = self.fc1(hidden_states)
  217. hidden_states = self.activation_fn(hidden_states)
  218. hidden_states = self.fc2(hidden_states)
  219. return hidden_states
  220. # Copied from transformers.models.blip.modeling_blip.BlipEncoderLayer with Blip->InstructBlip
  221. class InstructBlipEncoderLayer(GradientCheckpointingLayer):
  222. def __init__(self, config: InstructBlipConfig):
  223. super().__init__()
  224. self.embed_dim = config.hidden_size
  225. self.self_attn = InstructBlipAttention(config)
  226. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  227. self.mlp = InstructBlipMLP(config)
  228. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  229. @auto_docstring
  230. def forward(
  231. self,
  232. hidden_states: torch.Tensor,
  233. attention_mask: torch.Tensor,
  234. **kwargs: Unpack[TransformersKwargs],
  235. ) -> torch.FloatTensor:
  236. residual = hidden_states
  237. hidden_states = self.layer_norm1(hidden_states)
  238. hidden_states, _ = self.self_attn(
  239. hidden_states=hidden_states,
  240. head_mask=attention_mask,
  241. **kwargs,
  242. )
  243. hidden_states = hidden_states + residual
  244. residual = hidden_states
  245. hidden_states = self.layer_norm2(hidden_states)
  246. hidden_states = self.mlp(hidden_states)
  247. hidden_states = hidden_states + residual
  248. return hidden_states
  249. @auto_docstring
  250. class InstructBlipPreTrainedModel(PreTrainedModel):
  251. config: InstructBlipConfig
  252. base_model_prefix = "blip"
  253. supports_gradient_checkpointing = True
  254. _supports_attention_backend = True
  255. _supports_flash_attn = True
  256. _supports_sdpa = True
  257. _supports_flex_attn = True
  258. _can_compile_fullgraph = True
  259. _no_split_modules = [
  260. "InstructBlipQFormerEmbeddings",
  261. "InstructBlipAttention",
  262. "InstructBlipQFormerMultiHeadAttention",
  263. "InstructBlipQFormerSelfOutput",
  264. ]
  265. def _init_weights(self, module):
  266. """Initialize the weights"""
  267. factor = self.config.initializer_range
  268. if isinstance(module, (nn.Linear, nn.Conv2d)):
  269. module.weight.data.normal_(mean=0.0, std=factor)
  270. if module.bias is not None:
  271. module.bias.data.zero_()
  272. elif isinstance(module, nn.Embedding):
  273. module.weight.data.normal_(mean=0.0, std=factor)
  274. elif isinstance(module, nn.LayerNorm):
  275. module.bias.data.zero_()
  276. module.weight.data.fill_(1.0)
  277. elif isinstance(module, InstructBlipVisionEmbeddings):
  278. nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
  279. nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
  280. elif isinstance(module, (InstructBlipForConditionalGeneration, InstructBlipModel)):
  281. module.query_tokens.data.zero_()
  282. # Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->InstructBlip
  283. class InstructBlipEncoder(nn.Module):
  284. """
  285. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  286. [`InstructBlipEncoderLayer`].
  287. Args:
  288. config (`InstructBlipConfig`):
  289. The corresponding vision configuration for the `InstructBlipEncoder`.
  290. """
  291. def __init__(self, config: InstructBlipConfig):
  292. super().__init__()
  293. self.config = config
  294. self.layers = nn.ModuleList([InstructBlipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  295. self.gradient_checkpointing = False
  296. @auto_docstring
  297. def forward(
  298. self,
  299. inputs_embeds,
  300. attention_mask: Optional[torch.Tensor] = None,
  301. **kwargs: Unpack[TransformersKwargs],
  302. ) -> Union[tuple, BaseModelOutput]:
  303. hidden_states = inputs_embeds
  304. for encoder_layer in self.layers:
  305. hidden_states = encoder_layer(
  306. hidden_states,
  307. attention_mask=attention_mask,
  308. **kwargs,
  309. )
  310. return BaseModelOutput(last_hidden_state=hidden_states)
  311. # Copied from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->InstructBlip, BLIP->INSTRUCTBLIP
  312. class InstructBlipVisionModel(InstructBlipPreTrainedModel):
  313. main_input_name = "pixel_values"
  314. config: InstructBlipVisionConfig
  315. _can_record_outputs = {
  316. "hidden_states": InstructBlipEncoderLayer,
  317. "attentions": InstructBlipAttention,
  318. }
  319. def __init__(self, config: InstructBlipVisionConfig):
  320. super().__init__(config)
  321. self.config = config
  322. embed_dim = config.hidden_size
  323. self.embeddings = InstructBlipVisionEmbeddings(config)
  324. self.encoder = InstructBlipEncoder(config)
  325. self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  326. self.post_init()
  327. @check_model_inputs(tie_last_hidden_states=False)
  328. @auto_docstring
  329. def forward(
  330. self,
  331. pixel_values: Optional[torch.FloatTensor] = None,
  332. interpolate_pos_encoding: bool = False,
  333. **kwargs: Unpack[TransformersKwargs],
  334. ) -> Union[tuple, BaseModelOutputWithPooling]:
  335. if pixel_values is None:
  336. raise ValueError("You have to specify pixel_values")
  337. hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  338. encoder_outputs: BaseModelOutput = self.encoder(
  339. inputs_embeds=hidden_states,
  340. **kwargs,
  341. )
  342. last_hidden_state = encoder_outputs.last_hidden_state
  343. last_hidden_state = self.post_layernorm(last_hidden_state)
  344. pooled_output = last_hidden_state[:, 0, :]
  345. pooled_output = self.post_layernorm(pooled_output)
  346. return BaseModelOutputWithPooling(
  347. last_hidden_state=last_hidden_state,
  348. pooler_output=pooled_output,
  349. )
  350. def get_input_embeddings(self):
  351. return self.embeddings
  352. class InstructBlipQFormerMultiHeadAttention(nn.Module):
  353. def __init__(self, config, is_cross_attention=False):
  354. super().__init__()
  355. self.config = config
  356. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  357. raise ValueError(
  358. "The hidden size (%d) is not a multiple of the number of attention heads (%d)"
  359. % (config.hidden_size, config.num_attention_heads)
  360. )
  361. self.num_attention_heads = config.num_attention_heads
  362. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  363. self.all_head_size = self.num_attention_heads * self.attention_head_size
  364. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  365. if is_cross_attention:
  366. self.key = nn.Linear(config.encoder_hidden_size, self.all_head_size)
  367. self.value = nn.Linear(config.encoder_hidden_size, self.all_head_size)
  368. else:
  369. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  370. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  371. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  372. self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
  373. if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
  374. self.max_position_embeddings = config.max_position_embeddings
  375. self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
  376. self.save_attention = False
  377. def save_attn_gradients(self, attn_gradients):
  378. self.attn_gradients = attn_gradients
  379. def get_attn_gradients(self):
  380. return self.attn_gradients
  381. def save_attention_map(self, attention_map):
  382. self.attention_map = attention_map
  383. def get_attention_map(self):
  384. return self.attention_map
  385. def transpose_for_scores(self, x):
  386. new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
  387. x = x.view(*new_x_shape)
  388. return x.permute(0, 2, 1, 3)
  389. def forward(
  390. self,
  391. hidden_states,
  392. attention_mask=None,
  393. head_mask=None,
  394. encoder_hidden_states=None,
  395. encoder_attention_mask=None,
  396. **kwargs: Unpack[TransformersKwargs],
  397. ):
  398. # If this is instantiated as a cross-attention module, the keys
  399. # and values come from an encoder; the attention mask needs to be
  400. # such that the encoder's padding tokens are not attended to.
  401. is_cross_attention = encoder_hidden_states is not None
  402. if is_cross_attention:
  403. key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
  404. value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
  405. attention_mask = encoder_attention_mask
  406. else:
  407. key_layer = self.transpose_for_scores(self.key(hidden_states))
  408. value_layer = self.transpose_for_scores(self.value(hidden_states))
  409. mixed_query_layer = self.query(hidden_states)
  410. query_layer = self.transpose_for_scores(mixed_query_layer)
  411. # Take the dot product between "query" and "key" to get the raw attention scores.
  412. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  413. if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
  414. seq_length = hidden_states.size()[1]
  415. position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
  416. position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
  417. distance = position_ids_l - position_ids_r
  418. positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
  419. positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
  420. if self.position_embedding_type == "relative_key":
  421. relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
  422. attention_scores = attention_scores + relative_position_scores
  423. elif self.position_embedding_type == "relative_key_query":
  424. relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
  425. relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
  426. attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
  427. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  428. attention_scores_dtype = attention_scores.dtype
  429. if attention_mask is not None:
  430. # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
  431. attention_scores = attention_scores + attention_mask
  432. # Normalize the attention scores to probabilities.
  433. attention_probs = nn.Softmax(dim=-1)(attention_scores).to(attention_scores_dtype)
  434. if is_cross_attention and self.save_attention:
  435. self.save_attention_map(attention_probs)
  436. attention_probs.register_hook(self.save_attn_gradients)
  437. # This is actually dropping out entire tokens to attend to, which might
  438. # seem a bit unusual, but is taken from the original Transformer paper.
  439. attention_probs_dropped = self.dropout(attention_probs)
  440. # Mask heads if we want to
  441. if head_mask is not None:
  442. attention_probs_dropped = attention_probs_dropped * head_mask
  443. context_layer = torch.matmul(attention_probs_dropped, value_layer)
  444. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  445. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  446. context_layer = context_layer.view(*new_context_layer_shape)
  447. return context_layer, attention_probs
  448. # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->InstructBlipQFormer
  449. class InstructBlipQFormerSelfOutput(nn.Module):
  450. def __init__(self, config):
  451. super().__init__()
  452. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  453. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  454. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  455. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  456. hidden_states = self.dense(hidden_states)
  457. hidden_states = self.dropout(hidden_states)
  458. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  459. return hidden_states
  460. # Copied from transformers.models.blip_2.modeling_blip_2.Blip2QFormerAttention with Blip2->InstructBlip
  461. class InstructBlipQFormerAttention(nn.Module):
  462. def __init__(self, config, is_cross_attention=False):
  463. super().__init__()
  464. self.attention = InstructBlipQFormerMultiHeadAttention(config, is_cross_attention)
  465. self.output = InstructBlipQFormerSelfOutput(config)
  466. self.pruned_heads = set()
  467. def prune_heads(self, heads):
  468. if len(heads) == 0:
  469. return
  470. heads, index = find_pruneable_heads_and_indices(
  471. heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
  472. )
  473. # Prune linear layers
  474. self.attention.query = prune_linear_layer(self.attention.query, index)
  475. self.attention.key = prune_linear_layer(self.attention.key, index)
  476. self.attention.value = prune_linear_layer(self.attention.value, index)
  477. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  478. # Update hyper params and store pruned heads
  479. self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
  480. self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
  481. self.pruned_heads = self.pruned_heads.union(heads)
  482. def forward(
  483. self,
  484. hidden_states: torch.Tensor,
  485. attention_mask: Optional[torch.FloatTensor] = None,
  486. head_mask: Optional[torch.FloatTensor] = None,
  487. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  488. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  489. **kwargs: Unpack[TransformersKwargs],
  490. ) -> torch.Tensor:
  491. attn_output, _ = self.attention(
  492. hidden_states=hidden_states,
  493. attention_mask=attention_mask,
  494. head_mask=head_mask,
  495. encoder_hidden_states=encoder_hidden_states,
  496. encoder_attention_mask=encoder_attention_mask,
  497. **kwargs,
  498. )
  499. attention_output = self.output(attn_output, hidden_states)
  500. return attention_output
  501. # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->InstructBlipQFormer
  502. class InstructBlipQFormerIntermediate(nn.Module):
  503. def __init__(self, config):
  504. super().__init__()
  505. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  506. if isinstance(config.hidden_act, str):
  507. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  508. else:
  509. self.intermediate_act_fn = config.hidden_act
  510. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  511. hidden_states = self.dense(hidden_states)
  512. hidden_states = self.intermediate_act_fn(hidden_states)
  513. return hidden_states
  514. # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->InstructBlipQFormer
  515. class InstructBlipQFormerOutput(nn.Module):
  516. def __init__(self, config):
  517. super().__init__()
  518. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  519. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  520. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  521. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  522. hidden_states = self.dense(hidden_states)
  523. hidden_states = self.dropout(hidden_states)
  524. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  525. return hidden_states
  526. class InstructBlipQFormerLayer(GradientCheckpointingLayer):
  527. def __init__(self, config, layer_idx):
  528. super().__init__()
  529. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  530. self.seq_len_dim = 1
  531. self.attention = InstructBlipQFormerAttention(config)
  532. self.layer_idx = layer_idx
  533. if layer_idx % config.cross_attention_frequency == 0:
  534. self.crossattention = InstructBlipQFormerAttention(config, is_cross_attention=True)
  535. self.has_cross_attention = True
  536. else:
  537. self.has_cross_attention = False
  538. self.intermediate = InstructBlipQFormerIntermediate(config)
  539. self.output = InstructBlipQFormerOutput(config)
  540. self.intermediate_query = InstructBlipQFormerIntermediate(config)
  541. self.output_query = InstructBlipQFormerOutput(config)
  542. def forward(
  543. self,
  544. hidden_states,
  545. attention_mask=None,
  546. head_mask=None,
  547. encoder_hidden_states=None,
  548. encoder_attention_mask=None,
  549. query_length=0,
  550. **kwargs: Unpack[TransformersKwargs],
  551. ):
  552. attention_output = self.attention(
  553. hidden_states,
  554. attention_mask=attention_mask,
  555. head_mask=head_mask,
  556. **kwargs,
  557. )
  558. if query_length > 0:
  559. query_attention_output = attention_output[:, :query_length, :]
  560. if self.has_cross_attention:
  561. if encoder_hidden_states is None:
  562. raise ValueError("encoder_hidden_states must be given for cross-attention layers")
  563. query_attention_output = self.crossattention(
  564. query_attention_output,
  565. attention_mask=attention_mask,
  566. head_mask=head_mask,
  567. encoder_hidden_states=encoder_hidden_states,
  568. encoder_attention_mask=encoder_attention_mask,
  569. **kwargs,
  570. )
  571. layer_output = apply_chunking_to_forward(
  572. self.feed_forward_chunk_query,
  573. self.chunk_size_feed_forward,
  574. self.seq_len_dim,
  575. query_attention_output,
  576. )
  577. if attention_output.shape[1] > query_length:
  578. layer_output_text = apply_chunking_to_forward(
  579. self.feed_forward_chunk,
  580. self.chunk_size_feed_forward,
  581. self.seq_len_dim,
  582. attention_output[:, query_length:, :],
  583. ).to(layer_output.device)
  584. layer_output = torch.cat([layer_output, layer_output_text], dim=1)
  585. else:
  586. layer_output = apply_chunking_to_forward(
  587. self.feed_forward_chunk,
  588. self.chunk_size_feed_forward,
  589. self.seq_len_dim,
  590. attention_output,
  591. )
  592. return layer_output
  593. def feed_forward_chunk(self, attention_output):
  594. intermediate_output = self.intermediate(attention_output)
  595. layer_output = self.output(intermediate_output, attention_output)
  596. return layer_output
  597. def feed_forward_chunk_query(self, attention_output):
  598. intermediate_output = self.intermediate_query(attention_output)
  599. layer_output = self.output_query(intermediate_output, attention_output)
  600. return layer_output
  601. # Copied from transformers.models.blip_2.modeling_blip_2.Blip2QFormerEncoder with Blip2->InstructBlip
  602. class InstructBlipQFormerEncoder(nn.Module):
  603. def __init__(self, config):
  604. super().__init__()
  605. self.config = config
  606. self.layer = nn.ModuleList(
  607. [InstructBlipQFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  608. )
  609. self.gradient_checkpointing = False
  610. @can_return_tuple
  611. def forward(
  612. self,
  613. hidden_states,
  614. attention_mask=None,
  615. head_mask=None,
  616. encoder_hidden_states=None,
  617. encoder_attention_mask=None,
  618. query_length=0,
  619. **kwargs: Unpack[TransformersKwargs],
  620. ):
  621. for i in range(self.config.num_hidden_layers):
  622. layer_module = self.layer[i]
  623. layer_head_mask = head_mask[i] if head_mask is not None else None
  624. hidden_states = layer_module(
  625. hidden_states,
  626. attention_mask,
  627. layer_head_mask,
  628. encoder_hidden_states, # as a positional argument for gradient checkpointing
  629. encoder_attention_mask=encoder_attention_mask,
  630. query_length=query_length,
  631. **kwargs,
  632. )
  633. return BaseModelOutputWithPastAndCrossAttentions(
  634. last_hidden_state=hidden_states,
  635. )
  636. class InstructBlipQFormerEmbeddings(nn.Module):
  637. """Construct the embeddings from word and position embeddings."""
  638. def __init__(self, config):
  639. super().__init__()
  640. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  641. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  642. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  643. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  644. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  645. self.register_buffer(
  646. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  647. )
  648. self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
  649. self.config = config
  650. def forward(
  651. self,
  652. input_ids=None,
  653. position_ids=None,
  654. query_embeds=None,
  655. past_key_values_length=0,
  656. ):
  657. if input_ids is not None:
  658. seq_length = input_ids.size()[1]
  659. else:
  660. seq_length = 0
  661. if position_ids is None:
  662. position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone()
  663. if input_ids is not None:
  664. embeddings = self.word_embeddings(input_ids)
  665. if self.position_embedding_type == "absolute":
  666. position_embeddings = self.position_embeddings(position_ids.to(embeddings.device))
  667. embeddings = embeddings + position_embeddings
  668. if query_embeds is not None:
  669. embeddings = torch.cat((query_embeds, embeddings), dim=1)
  670. else:
  671. embeddings = query_embeds
  672. embeddings = embeddings.to(self.layernorm.weight.dtype)
  673. embeddings = self.layernorm(embeddings)
  674. embeddings = self.dropout(embeddings)
  675. return embeddings
  676. class InstructBlipQFormerModel(InstructBlipPreTrainedModel):
  677. """
  678. Querying Transformer (Q-Former), used in InstructBLIP. Slightly modified from BLIP-2 as it also takes the
  679. instruction as input.
  680. """
  681. _supports_attention_backend = False # adds position on attn weights before last matmul
  682. _supports_flash_attn = False
  683. _supports_sdpa = False
  684. _supports_flex_attn = False
  685. _can_record_outputs = {
  686. "hidden_states": InstructBlipQFormerLayer,
  687. "attentions": [
  688. OutputRecorder(InstructBlipQFormerMultiHeadAttention, index=1, layer_name=".attention"),
  689. ],
  690. "cross_attentions": [
  691. OutputRecorder(InstructBlipQFormerMultiHeadAttention, index=1, layer_name=".crossattention"),
  692. ],
  693. }
  694. def __init__(self, config: InstructBlipQFormerConfig):
  695. super().__init__(config)
  696. self.config = config
  697. self.embeddings = InstructBlipQFormerEmbeddings(config)
  698. self.encoder = InstructBlipQFormerEncoder(config)
  699. self.post_init()
  700. def get_input_embeddings(self):
  701. return self.embeddings.word_embeddings
  702. def set_input_embeddings(self, value):
  703. self.embeddings.word_embeddings = value
  704. def _prune_heads(self, heads_to_prune):
  705. """
  706. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  707. class PreTrainedModel
  708. """
  709. for layer, heads in heads_to_prune.items():
  710. self.encoder.layer[layer].attention.prune_heads(heads)
  711. def get_extended_attention_mask(
  712. self,
  713. attention_mask: torch.Tensor,
  714. input_shape: tuple[int],
  715. device: torch.device,
  716. has_query: bool = False,
  717. ) -> torch.Tensor:
  718. """
  719. Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
  720. Arguments:
  721. attention_mask (`torch.Tensor`):
  722. Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
  723. input_shape (`tuple[int]`):
  724. The shape of the input to the model.
  725. device: (`torch.device`):
  726. The device of the input to the model.
  727. Returns:
  728. `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
  729. """
  730. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  731. # ourselves in which case we just need to make it broadcastable to all heads.
  732. if attention_mask.dim() == 3:
  733. extended_attention_mask = attention_mask[:, None, :, :]
  734. elif attention_mask.dim() == 2:
  735. # Provided a padding mask of dimensions [batch_size, seq_length]
  736. # - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
  737. extended_attention_mask = attention_mask[:, None, None, :]
  738. else:
  739. raise ValueError(
  740. f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})",
  741. )
  742. # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
  743. # masked positions, this operation will create a tensor which is 0.0 for
  744. # positions we want to attend and -10000.0 for masked positions.
  745. # Since we are adding it to the raw scores before the softmax, this is
  746. # effectively the same as removing these entirely.
  747. extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
  748. extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
  749. return extended_attention_mask
  750. @check_model_inputs()
  751. @auto_docstring
  752. def forward(
  753. self,
  754. input_ids: torch.LongTensor,
  755. attention_mask: Optional[torch.FloatTensor] = None,
  756. position_ids: Optional[torch.LongTensor] = None,
  757. query_embeds: Optional[torch.Tensor] = None,
  758. head_mask: Optional[torch.FloatTensor] = None,
  759. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  760. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  761. **kwargs: Unpack[TransformersKwargs],
  762. ) -> Union[tuple[torch.FloatTensor], BaseModelOutputWithPoolingAndCrossAttentions]:
  763. r"""
  764. query_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  765. Hidden states to be used in the attention computation. If cross-attention,
  766. will be used for the query (i.e., key and value will use the encoder_hidden_states).
  767. """
  768. if input_ids is None and query_embeds is None:
  769. raise ValueError("You have to specify query_embeds when input_ids is None")
  770. query_length = query_embeds.shape[1] if query_embeds is not None else 0
  771. embedding_output = self.embeddings(
  772. input_ids=input_ids,
  773. position_ids=position_ids,
  774. query_embeds=query_embeds,
  775. )
  776. input_shape = embedding_output.size()[:-1]
  777. batch_size, seq_length = input_shape
  778. device = embedding_output.device
  779. if attention_mask is None:
  780. attention_mask = torch.ones(((batch_size, seq_length)), device=device)
  781. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  782. # ourselves in which case we just need to make it broadcastable to all heads.
  783. extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
  784. # If a 2D or 3D attention mask is provided for the cross-attention
  785. # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
  786. if encoder_hidden_states is not None:
  787. if isinstance(encoder_hidden_states, list):
  788. encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
  789. else:
  790. encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
  791. encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
  792. if isinstance(encoder_attention_mask, list):
  793. encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
  794. elif encoder_attention_mask is None:
  795. encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
  796. encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
  797. else:
  798. encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
  799. else:
  800. encoder_extended_attention_mask = None
  801. # Prepare head mask if needed
  802. # 1.0 in head_mask indicate we keep the head
  803. # attention_probs has shape bsz x n_heads x N x N
  804. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  805. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  806. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  807. encoder_outputs: BaseModelOutput = self.encoder(
  808. embedding_output,
  809. attention_mask=extended_attention_mask,
  810. head_mask=head_mask,
  811. encoder_hidden_states=encoder_hidden_states,
  812. encoder_attention_mask=encoder_extended_attention_mask,
  813. query_length=query_length,
  814. **kwargs,
  815. )
  816. sequence_output = encoder_outputs.last_hidden_state
  817. pooled_output = sequence_output[:, 0, :]
  818. return BaseModelOutputWithPoolingAndCrossAttentions(
  819. last_hidden_state=sequence_output,
  820. pooler_output=pooled_output,
  821. )
  822. @auto_docstring(
  823. custom_intro="""
  824. InstructBLIP base Model consisting of language model, qformer and vision encoder.
  825. """
  826. )
  827. class InstructBlipModel(InstructBlipPreTrainedModel):
  828. main_input_name = "pixel_values"
  829. _keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8
  830. def __init__(self, config: InstructBlipConfig):
  831. super().__init__(config)
  832. self.vision_model = InstructBlipVisionModel(config.vision_config)
  833. self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
  834. self.qformer = InstructBlipQFormerModel(config.qformer_config)
  835. self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
  836. self.language_model = AutoModel.from_config(config.text_config)
  837. if self.language_model._no_split_modules is not None:
  838. self._no_split_modules.extend(self.language_model._no_split_modules)
  839. if self.language_model._keep_in_fp32_modules is not None:
  840. self._keep_in_fp32_modules.extend(self.language_model._keep_in_fp32_modules)
  841. # Initialize weights and apply final processing
  842. self.post_init()
  843. def get_input_embeddings(self):
  844. return self.language_model.get_input_embeddings()
  845. def set_input_embeddings(self, value):
  846. self.language_model.set_input_embeddings(value)
  847. def _tie_weights(self):
  848. if not self.config.use_decoder_only_language_model:
  849. self.language_model.encoder.embed_tokens = self.language_model.shared
  850. self.language_model.decoder.embed_tokens = self.language_model.shared
  851. def _preprocess_accelerate(self):
  852. r"""
  853. Some pre-processing hacks to make the model `accelerate` compatible. Check
  854. https://github.com/huggingface/transformers/pull/21707 for more details.
  855. """
  856. hf_device_map = self.hf_device_map
  857. if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
  858. # warn users about unexpected behavior when using multi-GPU + InstructBLIP + `accelerate`.
  859. logger.warning(
  860. "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
  861. " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
  862. " Please pass a `device_map` that contains `language_model` to remove this warning."
  863. " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
  864. " more details on creating a `device_map` for large models.",
  865. )
  866. if hasattr(self.language_model, "_hf_hook"):
  867. self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
  868. def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor):
  869. """
  870. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`.
  871. """
  872. if input_ids is None:
  873. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  874. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  875. )
  876. special_image_mask = special_image_mask.all(-1)
  877. else:
  878. special_image_mask = input_ids == self.config.image_token_id
  879. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  880. return special_image_mask
  881. @can_return_tuple
  882. @auto_docstring
  883. def forward(
  884. self,
  885. pixel_values: torch.FloatTensor,
  886. qformer_input_ids: torch.FloatTensor,
  887. qformer_attention_mask: Optional[torch.LongTensor] = None,
  888. input_ids: Optional[torch.FloatTensor] = None,
  889. attention_mask: Optional[torch.LongTensor] = None,
  890. decoder_input_ids: Optional[torch.LongTensor] = None,
  891. decoder_attention_mask: Optional[torch.LongTensor] = None,
  892. inputs_embeds: Optional[torch.Tensor] = None,
  893. interpolate_pos_encoding: bool = False,
  894. **kwargs: Unpack[FlashAttentionKwargs],
  895. ) -> Union[tuple, InstructBlipForConditionalGenerationModelOutput]:
  896. r"""
  897. qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  898. Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
  899. to serve as text prompt, which the Q-Former model will encode.
  900. Indices can be obtained using [`InstructBlipProcessor`]. See [`InstructBlipProcessor.__call__`] for
  901. details.
  902. [What are input IDs?](../glossary#input-ids)
  903. qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  904. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  905. - 1 for tokens that are **not masked**,
  906. - 0 for tokens that are **masked**.
  907. [What are attention masks?](../glossary#attention-mask)
  908. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  909. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  910. be used by default.
  911. Only relevant in case an encoder-decoder language model (like T5) is used.
  912. """
  913. # step 1: forward the images through the vision encoder,
  914. # to get image embeddings of shape (batch_size, seq_len, hidden_size)
  915. vision_outputs = self.vision_model(
  916. pixel_values=pixel_values,
  917. interpolate_pos_encoding=interpolate_pos_encoding,
  918. **kwargs,
  919. )
  920. image_embeds = vision_outputs[0]
  921. # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
  922. image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
  923. # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
  924. query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
  925. query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
  926. if qformer_attention_mask is None:
  927. qformer_attention_mask = torch.ones_like(qformer_input_ids)
  928. qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
  929. query_outputs = self.qformer(
  930. input_ids=qformer_input_ids,
  931. attention_mask=qformer_attention_mask,
  932. query_embeds=query_tokens,
  933. encoder_hidden_states=image_embeds,
  934. encoder_attention_mask=image_attention_mask,
  935. **kwargs,
  936. )
  937. query_output = query_outputs[0][:, : query_tokens.size(1), :]
  938. if inputs_embeds is None:
  939. inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
  940. if attention_mask is None:
  941. attention_mask = torch.ones_like(input_ids)
  942. # step 3: use the language model, conditioned on the query outputs and the prompt
  943. language_model_inputs = self.language_projection(query_output)
  944. language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
  945. special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
  946. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
  947. if self.config.use_decoder_only_language_model:
  948. outputs = self.language_model(
  949. inputs_embeds=inputs_embeds,
  950. attention_mask=attention_mask,
  951. **kwargs,
  952. )
  953. else:
  954. outputs = self.language_model(
  955. inputs_embeds=inputs_embeds,
  956. attention_mask=attention_mask,
  957. decoder_input_ids=decoder_input_ids,
  958. decoder_attention_mask=decoder_attention_mask,
  959. **kwargs,
  960. )
  961. return InstructBlipForConditionalGenerationModelOutput(
  962. vision_outputs=vision_outputs,
  963. qformer_outputs=query_outputs,
  964. language_model_outputs=outputs,
  965. )
  966. @auto_docstring(
  967. custom_intro="""
  968. InstructBLIP Model for generating text given an image and an optional text prompt. The model consists of a vision
  969. encoder, Querying Transformer (Q-Former) and a language model.
  970. One can optionally pass `input_ids` to the model, which serve as a text prompt, to make the language model continue
  971. the prompt. Otherwise, the language model starts generating text from the [BOS] (beginning-of-sequence) token.
  972. """
  973. )
  974. class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, GenerationMixin):
  975. config: InstructBlipConfig
  976. main_input_name = "pixel_values"
  977. _can_compile_fullgraph = True
  978. _keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8
  979. def __init__(self, config: InstructBlipConfig):
  980. super().__init__(config)
  981. self.vision_model = InstructBlipVisionModel._from_config(config.vision_config)
  982. self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
  983. self.qformer = InstructBlipQFormerModel._from_config(config.qformer_config)
  984. self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
  985. if config.use_decoder_only_language_model:
  986. language_model = AutoModelForCausalLM.from_config(config.text_config)
  987. else:
  988. language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
  989. if language_model._no_split_modules is not None:
  990. self._no_split_modules.extend(language_model._no_split_modules)
  991. if language_model._keep_in_fp32_modules is not None:
  992. self._keep_in_fp32_modules.extend(language_model._keep_in_fp32_modules)
  993. self.language_model = language_model
  994. # Initialize weights and apply final processing
  995. self.post_init()
  996. def get_input_embeddings(self):
  997. return self.language_model.get_input_embeddings()
  998. def set_input_embeddings(self, value):
  999. self.language_model.set_input_embeddings(value)
  1000. def set_output_embeddings(self, new_embeddings):
  1001. self.language_model.set_output_embeddings(new_embeddings)
  1002. def get_output_embeddings(self) -> nn.Module:
  1003. return self.language_model.get_output_embeddings()
  1004. def get_encoder(self):
  1005. return self.language_model.get_encoder()
  1006. def get_decoder(self):
  1007. return self.language_model.get_decoder()
  1008. # Copied from transformers.models.instructblip.modeling_instructblip.InstructBlipModel._tie_weights
  1009. def _tie_weights(self):
  1010. if not self.config.use_decoder_only_language_model:
  1011. self.language_model.encoder.embed_tokens = self.language_model.shared
  1012. self.language_model.decoder.embed_tokens = self.language_model.shared
  1013. # Copied from transformers.models.instructblip.modeling_instructblip.InstructBlipModel._preprocess_accelerate
  1014. def _preprocess_accelerate(self):
  1015. r"""
  1016. Some pre-processing hacks to make the model `accelerate` compatible. Check
  1017. https://github.com/huggingface/transformers/pull/21707 for more details.
  1018. """
  1019. hf_device_map = self.hf_device_map
  1020. if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
  1021. # warn users about unexpected behavior when using multi-GPU + InstructBLIP + `accelerate`.
  1022. logger.warning(
  1023. "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
  1024. " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
  1025. " Please pass a `device_map` that contains `language_model` to remove this warning."
  1026. " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
  1027. " more details on creating a `device_map` for large models.",
  1028. )
  1029. if hasattr(self.language_model, "_hf_hook"):
  1030. self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
  1031. def get_image_features(
  1032. self,
  1033. pixel_values: torch.FloatTensor,
  1034. qformer_input_ids: torch.LongTensor,
  1035. qformer_attention_mask: Optional[torch.LongTensor] = None,
  1036. interpolate_pos_encoding: Optional[bool] = False,
  1037. return_dict: Optional[bool] = False,
  1038. ):
  1039. """
  1040. Encodes images into continuous embeddings that can be forwarded to the language model.
  1041. Args:
  1042. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  1043. The tensors corresponding to the input images.
  1044. """
  1045. # step 1: forward the images through the vision encoder,
  1046. # to get image embeddings of shape (batch_size, seq_len, hidden_size)
  1047. vision_outputs = self.vision_model(
  1048. pixel_values=pixel_values,
  1049. interpolate_pos_encoding=interpolate_pos_encoding,
  1050. return_dict=True,
  1051. )
  1052. image_embeds = vision_outputs[0]
  1053. # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
  1054. image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
  1055. # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
  1056. query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
  1057. query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
  1058. if qformer_attention_mask is None:
  1059. qformer_attention_mask = torch.ones_like(qformer_input_ids)
  1060. qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
  1061. query_outputs = self.qformer(
  1062. input_ids=qformer_input_ids,
  1063. attention_mask=qformer_attention_mask,
  1064. query_embeds=query_tokens,
  1065. encoder_hidden_states=image_embeds,
  1066. encoder_attention_mask=image_attention_mask,
  1067. return_dict=True,
  1068. )
  1069. query_output = query_outputs[0][:, : query_tokens.size(1), :]
  1070. # step 3: use the language model, conditioned on the query outputs and the prompt
  1071. language_model_inputs = self.language_projection(query_output)
  1072. if return_dict:
  1073. return language_model_inputs, vision_outputs, query_outputs
  1074. return language_model_inputs
  1075. def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor):
  1076. """
  1077. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`.
  1078. """
  1079. if input_ids is None:
  1080. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  1081. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  1082. )
  1083. special_image_mask = special_image_mask.all(-1)
  1084. else:
  1085. special_image_mask = input_ids == self.config.image_token_id
  1086. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  1087. return special_image_mask
  1088. @can_return_tuple
  1089. @auto_docstring
  1090. def forward(
  1091. self,
  1092. pixel_values: torch.FloatTensor,
  1093. qformer_input_ids: torch.FloatTensor,
  1094. qformer_attention_mask: Optional[torch.LongTensor] = None,
  1095. input_ids: Optional[torch.FloatTensor] = None,
  1096. attention_mask: Optional[torch.LongTensor] = None,
  1097. decoder_input_ids: Optional[torch.LongTensor] = None,
  1098. decoder_attention_mask: Optional[torch.LongTensor] = None,
  1099. inputs_embeds: Optional[torch.FloatTensor] = None,
  1100. labels: Optional[torch.LongTensor] = None,
  1101. interpolate_pos_encoding: bool = False,
  1102. **kwargs: Unpack[TransformersKwargs],
  1103. ) -> Union[tuple, InstructBlipForConditionalGenerationModelOutput]:
  1104. r"""
  1105. qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1106. Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
  1107. to serve as text prompt, which the Q-Former model will encode.
  1108. Indices can be obtained using [`InstructBlipProcessor`]. See [`InstructBlipProcessor.__call__`] for
  1109. details.
  1110. [What are input IDs?](../glossary#input-ids)
  1111. qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1112. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  1113. - 1 for tokens that are **not masked**,
  1114. - 0 for tokens that are **masked**.
  1115. [What are attention masks?](../glossary#attention-mask)
  1116. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1117. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1118. be used by default.
  1119. Only relevant in case an encoder-decoder language model (like T5) is used.
  1120. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1121. Labels for computing the language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size -
  1122. 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
  1123. config.vocab_size]`
  1124. Examples:
  1125. ```python
  1126. >>> from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
  1127. >>> import torch
  1128. >>> from PIL import Image
  1129. >>> import requests
  1130. >>> model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b")
  1131. >>> processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
  1132. >>> device = "cuda" if torch.cuda.is_available() else "cpu"
  1133. >>> model.to(device) # doctest: +IGNORE_RESULT
  1134. >>> url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/docs/_static/Confusing-Pictures.jpg"
  1135. >>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
  1136. >>> prompt = "What is unusual about this image?"
  1137. >>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
  1138. >>> outputs = model.generate(
  1139. ... **inputs,
  1140. ... do_sample=False,
  1141. ... num_beams=5,
  1142. ... max_length=256,
  1143. ... min_length=1,
  1144. ... top_p=0.9,
  1145. ... repetition_penalty=1.5,
  1146. ... length_penalty=1.0,
  1147. ... temperature=1,
  1148. ... )
  1149. >>> generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
  1150. >>> print(generated_text)
  1151. The unusual aspect of this image is that a man is ironing clothes on the back of a yellow SUV, which is parked in the middle of a busy city street. This is an unconventional approach to ironing clothes, as it requires the man to balance himself and his ironing equipment on top of the vehicle while navigating through traffic. Additionally, the presence of taxis and other vehicles in the scene further emphasizes the unusual nature of this situation.
  1152. ```"""
  1153. language_model_inputs, vision_outputs, query_outputs = self.get_image_features(
  1154. pixel_values,
  1155. qformer_input_ids=qformer_input_ids,
  1156. qformer_attention_mask=qformer_attention_mask,
  1157. interpolate_pos_encoding=interpolate_pos_encoding,
  1158. return_dict=True,
  1159. )
  1160. if inputs_embeds is None:
  1161. inputs_embeds = self.get_input_embeddings()(input_ids)
  1162. if attention_mask is None:
  1163. attention_mask = torch.ones_like(input_ids)
  1164. language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
  1165. special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
  1166. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
  1167. if self.config.use_decoder_only_language_model:
  1168. outputs = self.language_model(
  1169. inputs_embeds=inputs_embeds,
  1170. attention_mask=attention_mask,
  1171. **kwargs,
  1172. )
  1173. logits = outputs[0]
  1174. loss = None
  1175. if labels is not None:
  1176. loss = self.loss_function(
  1177. logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
  1178. )
  1179. else:
  1180. kwargs["return_dict"] = True
  1181. outputs = self.language_model(
  1182. inputs_embeds=inputs_embeds,
  1183. attention_mask=attention_mask,
  1184. decoder_input_ids=decoder_input_ids,
  1185. decoder_attention_mask=decoder_attention_mask,
  1186. labels=labels,
  1187. **kwargs,
  1188. )
  1189. loss = outputs.loss
  1190. logits = outputs.logits
  1191. return InstructBlipForConditionalGenerationModelOutput(
  1192. loss=loss,
  1193. logits=logits,
  1194. vision_outputs=vision_outputs,
  1195. qformer_outputs=query_outputs,
  1196. language_model_outputs=outputs,
  1197. )
  1198. @torch.no_grad()
  1199. def generate(
  1200. self,
  1201. pixel_values: torch.FloatTensor,
  1202. qformer_input_ids: Optional[torch.LongTensor] = None,
  1203. qformer_attention_mask: Optional[torch.LongTensor] = None,
  1204. input_ids: Optional[torch.LongTensor] = None,
  1205. attention_mask: Optional[torch.LongTensor] = None,
  1206. inputs_embeds: Optional[torch.FloatTensor] = None,
  1207. interpolate_pos_encoding: bool = False,
  1208. **generate_kwargs,
  1209. ) -> torch.LongTensor:
  1210. """
  1211. Overrides `generate` function to be able to use the model as a conditional generator.
  1212. Args:
  1213. pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width)):
  1214. Input images to be processed.
  1215. qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
  1216. The sequence used as a prompt to be fed to the Q-Former module.
  1217. qformer_attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
  1218. Mask to avoid performing attention on padding token indices.
  1219. input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
  1220. The sequence used as a prompt for the generation.
  1221. attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
  1222. Mask to avoid performing attention on padding token indices.
  1223. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  1224. Embedded representation of the inputs. Should be float, not int tokens.
  1225. interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
  1226. Whether to interpolate the positional encoding of the image embeddings.
  1227. Returns:
  1228. captions (list): A list of strings of length batch_size * num_captions.
  1229. """
  1230. if hasattr(self, "hf_device_map"):
  1231. # preprocess for `accelerate`
  1232. self._preprocess_accelerate()
  1233. batch_size = pixel_values.shape[0]
  1234. language_model_inputs, vision_outputs, query_outputs = self.get_image_features(
  1235. pixel_values,
  1236. qformer_input_ids=qformer_input_ids,
  1237. qformer_attention_mask=qformer_attention_mask,
  1238. interpolate_pos_encoding=interpolate_pos_encoding,
  1239. return_dict=True,
  1240. )
  1241. if inputs_embeds is None:
  1242. if input_ids is None:
  1243. image_tokens = [self.config.image_token_index] * self.config.num_query_tokens
  1244. start_tokens = image_tokens + [self.config.text_config.bos_token_id]
  1245. input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
  1246. input_ids = input_ids.repeat(batch_size, 1)
  1247. inputs_embeds = self.get_input_embeddings()(input_ids)
  1248. if attention_mask is None:
  1249. attention_mask = torch.ones_like(input_ids)
  1250. language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
  1251. special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
  1252. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
  1253. inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
  1254. if not self.language_model.config.is_encoder_decoder:
  1255. inputs["input_ids"] = input_ids
  1256. outputs = self.language_model.generate(**inputs, **generate_kwargs)
  1257. return outputs
# Public API of this module: the names exported by `from <this module> import *`.
# NOTE(review): kept deliberately as a plain list literal of strings. Library tooling may read `__all__`
# statically, so confirm that before restructuring or reordering it.
__all__ = [
    "InstructBlipQFormerModel",
    "InstructBlipPreTrainedModel",
    "InstructBlipModel",
    "InstructBlipForConditionalGeneration",
    "InstructBlipVisionModel",
]