modeling_blip.py 50 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308
  1. # coding=utf-8
  2. # Copyright 2022 The Salesforce Team Authors and The HuggingFace Team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch BLIP model."""
  16. import warnings
  17. from dataclasses import dataclass
  18. from typing import Any, Optional, Union
  19. import torch
  20. from torch import nn
  21. from torch.nn.functional import normalize
  22. from ...activations import ACT2FN
  23. from ...generation import GenerationMixin
  24. from ...modeling_layers import GradientCheckpointingLayer
  25. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
  26. from ...modeling_utils import PreTrainedModel
  27. from ...processing_utils import Unpack
  28. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int
  29. from ...utils.generic import check_model_inputs
  30. from .configuration_blip import BlipConfig, BlipTextConfig, BlipVisionConfig
  31. from .modeling_blip_text import BlipTextLMHeadModel, BlipTextModel
  32. logger = logging.get_logger(__name__)
  33. # Copied from transformers.models.clip.modeling_clip.contrastive_loss
  34. def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
  35. return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
  36. # Copied from transformers.models.clip.modeling_clip.clip_loss with clip->blip
  37. def blip_loss(similarity: torch.Tensor) -> torch.Tensor:
  38. caption_loss = contrastive_loss(similarity)
  39. image_loss = contrastive_loss(similarity.t())
  40. return (caption_loss + image_loss) / 2.0
  41. @dataclass
  42. @auto_docstring(
  43. custom_intro="""
  44. Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the
  45. last hidden states. This class also adds the loss term from the text decoder.
  46. """
  47. )
  48. class BlipForConditionalGenerationModelOutput(ModelOutput):
  49. r"""
  50. loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
  51. Language modeling loss from the text decoder.
  52. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*):
  53. Prediction scores of the language modeling head of the text decoder model.
  54. image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*):
  55. The image embeddings obtained after applying the Vision Transformer model to the input image.
  56. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
  57. Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
  58. one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
  59. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
  60. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed):
  61. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  62. sequence_length)`.
  63. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
  64. heads.
  65. """
  66. loss: Optional[tuple[torch.FloatTensor]] = None
  67. logits: Optional[tuple[torch.FloatTensor]] = None
  68. image_embeds: Optional[torch.FloatTensor] = None
  69. last_hidden_state: Optional[torch.FloatTensor] = None
  70. hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  71. attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  72. @property
  73. def decoder_logits(self):
  74. warnings.warn(
  75. "`decoder_logits` attribute is deprecated and will be removed in version 5 of Transformers."
  76. " Please use the `logits` attribute to retrieve the final output instead.",
  77. FutureWarning,
  78. )
  79. return self.logits
  80. @dataclass
  81. @auto_docstring(
  82. custom_intro="""
  83. Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the
  84. last hidden states. This class also adds the loss term from the text decoder.
  85. """
  86. )
  87. class BlipTextVisionModelOutput(ModelOutput):
  88. r"""
  89. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  90. Language modeling loss from the text decoder.
  91. image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
  92. The image embeddings obtained by applying the projection layer to the pooler_output.
  93. """
  94. loss: Optional[torch.FloatTensor] = None
  95. image_embeds: Optional[torch.FloatTensor] = None
  96. last_hidden_state: Optional[torch.FloatTensor] = None
  97. hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  98. attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  99. @dataclass
  100. @auto_docstring(
  101. custom_intro="""
  102. Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the
  103. last hidden states. This class also adds the loss term from the text decoder as well as the image-text similarity
  104. scores.
  105. """
  106. )
  107. class BlipImageTextMatchingModelOutput(ModelOutput):
  108. r"""
  109. itm_score (`torch.FloatTensor`):
  110. The image-text similarity scores.
  111. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  112. Language modeling loss from the text decoder.
  113. image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
  114. The image embeddings obtained by applying the projection layer to the pooler_output.
  115. vision_pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*):
  116. Last layer hidden-state of the vision of the vision-only branch of the model.
  117. question_embeds (`torch.FloatTensor`):
  118. The question embeddings obtained by the text projection layer.
  119. """
  120. itm_score: Optional[torch.FloatTensor] = None
  121. loss: Optional[torch.FloatTensor] = None
  122. image_embeds: Optional[torch.FloatTensor] = None
  123. last_hidden_state: Optional[torch.FloatTensor] = None
  124. hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  125. vision_pooler_output: Optional[torch.FloatTensor] = None
  126. attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  127. question_embeds: Optional[tuple[torch.FloatTensor]] = None
  128. @dataclass
  129. @auto_docstring
  130. class BlipOutput(ModelOutput):
  131. r"""
  132. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
  133. Contrastive loss for image-text similarity.
  134. logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
  135. The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
  136. similarity scores.
  137. logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
  138. The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
  139. similarity scores.
  140. text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  141. The text embeddings obtained by applying the projection layer to the pooled output of [`BlipTextModel`].
  142. image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  143. The image embeddings obtained by applying the projection layer to the pooled output of [`BlipVisionModel`].
  144. text_model_output (`BaseModelOutputWithPooling`):
  145. The output of the [`BlipTextModel`].
  146. vision_model_output (`BaseModelOutputWithPooling`):
  147. The output of the [`BlipVisionModel`].
  148. """
  149. loss: Optional[torch.FloatTensor] = None
  150. logits_per_image: Optional[torch.FloatTensor] = None
  151. logits_per_text: Optional[torch.FloatTensor] = None
  152. text_embeds: Optional[torch.FloatTensor] = None
  153. image_embeds: Optional[torch.FloatTensor] = None
  154. text_model_output: BaseModelOutputWithPooling = None
  155. vision_model_output: BaseModelOutputWithPooling = None
  156. def to_tuple(self) -> tuple[Any]:
  157. return tuple(
  158. self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
  159. for k in self.keys()
  160. )
  161. class BlipVisionEmbeddings(nn.Module):
  162. def __init__(self, config: BlipVisionConfig):
  163. super().__init__()
  164. self.config = config
  165. self.embed_dim = config.hidden_size
  166. self.image_size = config.image_size
  167. self.patch_size = config.patch_size
  168. self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
  169. self.patch_embedding = nn.Conv2d(
  170. in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
  171. )
  172. self.num_patches = (self.image_size // self.patch_size) ** 2
  173. self.num_positions = self.num_patches + 1
  174. self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
  175. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  176. """
  177. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  178. images. This method is also adapted to support torch.jit tracing.
  179. Adapted from:
  180. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  181. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  182. """
  183. num_patches = embeddings.shape[1] - 1
  184. num_positions = self.position_embedding.shape[1] - 1
  185. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  186. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  187. return self.position_embedding
  188. class_pos_embed = self.position_embedding[:, :1]
  189. patch_pos_embed = self.position_embedding[:, 1:]
  190. dim = embeddings.shape[-1]
  191. new_height = height // self.patch_size
  192. new_width = width // self.patch_size
  193. sqrt_num_positions = torch_int(num_positions**0.5)
  194. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  195. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  196. patch_pos_embed = nn.functional.interpolate(
  197. patch_pos_embed,
  198. size=(new_height, new_width),
  199. mode="bicubic",
  200. align_corners=False,
  201. )
  202. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  203. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  204. def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
  205. batch_size, _, height, width = pixel_values.shape
  206. target_dtype = self.patch_embedding.weight.dtype
  207. patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
  208. patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
  209. class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
  210. embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
  211. if interpolate_pos_encoding:
  212. position_embedding = self.interpolate_pos_encoding(embeddings, height, width)
  213. else:
  214. position_embedding = self.position_embedding
  215. embeddings = embeddings + position_embedding[:, : embeddings.size(1), :].to(target_dtype)
  216. return embeddings
  217. # Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Blip
  218. class BlipTextEmbeddings(nn.Module):
  219. def __init__(self, config: BlipTextConfig):
  220. super().__init__()
  221. embed_dim = config.hidden_size
  222. self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
  223. self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
  224. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  225. self.register_buffer(
  226. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  227. )
  228. def forward(
  229. self,
  230. input_ids: Optional[torch.LongTensor] = None,
  231. position_ids: Optional[torch.LongTensor] = None,
  232. inputs_embeds: Optional[torch.FloatTensor] = None,
  233. ) -> torch.Tensor:
  234. seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
  235. max_position_embedding = self.position_embedding.weight.shape[0]
  236. if seq_length > max_position_embedding:
  237. raise ValueError(
  238. f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
  239. f"{seq_length} and max_position_embeddings: {max_position_embedding}"
  240. )
  241. if position_ids is None:
  242. position_ids = self.position_ids[:, :seq_length]
  243. if inputs_embeds is None:
  244. inputs_embeds = self.token_embedding(input_ids)
  245. position_embeddings = self.position_embedding(position_ids)
  246. embeddings = inputs_embeds + position_embeddings
  247. return embeddings
  248. class BlipAttention(nn.Module):
  249. """Multi-headed attention from 'Attention Is All You Need' paper"""
  250. def __init__(self, config):
  251. super().__init__()
  252. self.config = config
  253. self.embed_dim = config.hidden_size
  254. self.num_heads = config.num_attention_heads
  255. self.head_dim = self.embed_dim // self.num_heads
  256. if self.head_dim * self.num_heads != self.embed_dim:
  257. raise ValueError(
  258. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  259. f" {self.num_heads})."
  260. )
  261. self.scale = self.head_dim**-0.5
  262. self.dropout = nn.Dropout(config.attention_dropout)
  263. self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim)
  264. self.projection = nn.Linear(self.embed_dim, self.embed_dim)
  265. def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
  266. return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
  267. def forward(
  268. self,
  269. hidden_states: torch.Tensor,
  270. head_mask: Optional[torch.Tensor] = None,
  271. **kwargs: Unpack[TransformersKwargs],
  272. ) -> tuple[torch.Tensor, torch.Tensor]:
  273. """Input shape: Batch x Time x Channel"""
  274. bsz, tgt_len, embed_dim = hidden_states.size()
  275. mixed_qkv = (
  276. self.qkv(hidden_states)
  277. .reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads)
  278. .permute(2, 0, 3, 1, 4)
  279. )
  280. query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
  281. # Take the dot product between "query" and "key" to get the raw attention scores.
  282. attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
  283. attention_scores = attention_scores * self.scale
  284. # Normalize the attention scores to probabilities.
  285. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  286. # This is actually dropping out entire tokens to attend to, which might
  287. # seem a bit unusual, but is taken from the original Transformer paper.
  288. attention_probs = self.dropout(attention_probs)
  289. # Mask heads if we want to
  290. if head_mask is not None:
  291. attention_probs = attention_probs * head_mask
  292. context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)
  293. new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,)
  294. context_layer = context_layer.reshape(new_context_layer_shape)
  295. output = self.projection(context_layer)
  296. return output, attention_probs
  297. # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Blip
  298. class BlipMLP(nn.Module):
  299. def __init__(self, config):
  300. super().__init__()
  301. self.config = config
  302. self.activation_fn = ACT2FN[config.hidden_act]
  303. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  304. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  305. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  306. hidden_states = self.fc1(hidden_states)
  307. hidden_states = self.activation_fn(hidden_states)
  308. hidden_states = self.fc2(hidden_states)
  309. return hidden_states
  310. class BlipEncoderLayer(GradientCheckpointingLayer):
  311. def __init__(self, config: BlipConfig):
  312. super().__init__()
  313. self.embed_dim = config.hidden_size
  314. self.self_attn = BlipAttention(config)
  315. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  316. self.mlp = BlipMLP(config)
  317. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  318. @auto_docstring
  319. def forward(
  320. self,
  321. hidden_states: torch.Tensor,
  322. attention_mask: torch.Tensor,
  323. **kwargs: Unpack[TransformersKwargs],
  324. ) -> torch.FloatTensor:
  325. residual = hidden_states
  326. hidden_states = self.layer_norm1(hidden_states)
  327. hidden_states, _ = self.self_attn(
  328. hidden_states=hidden_states,
  329. head_mask=attention_mask,
  330. **kwargs,
  331. )
  332. hidden_states = hidden_states + residual
  333. residual = hidden_states
  334. hidden_states = self.layer_norm2(hidden_states)
  335. hidden_states = self.mlp(hidden_states)
  336. hidden_states = hidden_states + residual
  337. return hidden_states
  338. @auto_docstring
  339. class BlipPreTrainedModel(PreTrainedModel):
  340. config: BlipConfig
  341. base_model_prefix = "blip"
  342. supports_gradient_checkpointing = True
  343. _no_split_modules = ["BlipEncoderLayer", "BlipTextEmbeddings"]
  344. _skip_keys_device_placement = ["past_key_values"]
  345. def _init_weights(self, module):
  346. """Initialize the weights"""
  347. factor = self.config.initializer_range
  348. if isinstance(module, (nn.Conv2d, nn.Embedding, nn.Linear)):
  349. module.weight.data.normal_(mean=0.0, std=factor)
  350. if hasattr(module, "bias") and module.bias is not None:
  351. module.bias.data.zero_()
  352. if isinstance(module, BlipVisionEmbeddings):
  353. if hasattr(self.config, "vision_config"):
  354. factor = self.config.vision_config.initializer_range
  355. nn.init.trunc_normal_(
  356. module.position_embedding,
  357. mean=0.0,
  358. std=factor,
  359. )
  360. nn.init.trunc_normal_(
  361. module.class_embedding,
  362. mean=0.0,
  363. std=factor,
  364. )
  365. elif isinstance(module, nn.LayerNorm):
  366. module.bias.data.zero_()
  367. module.weight.data.fill_(1.0)
  368. elif isinstance(module, nn.Linear) and module.bias is not None:
  369. module.bias.data.zero_()
  370. class BlipEncoder(nn.Module):
  371. """
  372. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  373. [`BlipEncoderLayer`].
  374. Args:
  375. config (`BlipConfig`):
  376. The corresponding vision configuration for the `BlipEncoder`.
  377. """
  378. def __init__(self, config: BlipConfig):
  379. super().__init__()
  380. self.config = config
  381. self.layers = nn.ModuleList([BlipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  382. self.gradient_checkpointing = False
  383. @auto_docstring
  384. def forward(
  385. self,
  386. inputs_embeds,
  387. attention_mask: Optional[torch.Tensor] = None,
  388. **kwargs: Unpack[TransformersKwargs],
  389. ) -> Union[tuple, BaseModelOutput]:
  390. hidden_states = inputs_embeds
  391. for encoder_layer in self.layers:
  392. hidden_states = encoder_layer(
  393. hidden_states,
  394. attention_mask=attention_mask,
  395. **kwargs,
  396. )
  397. return BaseModelOutput(last_hidden_state=hidden_states)
  398. class BlipVisionModel(BlipPreTrainedModel):
  399. main_input_name = "pixel_values"
  400. config: BlipVisionConfig
  401. _can_record_outputs = {
  402. "hidden_states": BlipEncoderLayer,
  403. "attentions": BlipAttention,
  404. }
  405. def __init__(self, config: BlipVisionConfig):
  406. super().__init__(config)
  407. self.config = config
  408. embed_dim = config.hidden_size
  409. self.embeddings = BlipVisionEmbeddings(config)
  410. self.encoder = BlipEncoder(config)
  411. self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  412. self.post_init()
  413. @check_model_inputs(tie_last_hidden_states=False)
  414. @auto_docstring
  415. def forward(
  416. self,
  417. pixel_values: Optional[torch.FloatTensor] = None,
  418. interpolate_pos_encoding: bool = False,
  419. **kwargs: Unpack[TransformersKwargs],
  420. ) -> Union[tuple, BaseModelOutputWithPooling]:
  421. if pixel_values is None:
  422. raise ValueError("You have to specify pixel_values")
  423. hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  424. encoder_outputs: BaseModelOutput = self.encoder(
  425. inputs_embeds=hidden_states,
  426. **kwargs,
  427. )
  428. last_hidden_state = encoder_outputs.last_hidden_state
  429. last_hidden_state = self.post_layernorm(last_hidden_state)
  430. pooled_output = last_hidden_state[:, 0, :]
  431. pooled_output = self.post_layernorm(pooled_output)
  432. return BaseModelOutputWithPooling(
  433. last_hidden_state=last_hidden_state,
  434. pooler_output=pooled_output,
  435. )
  436. def get_input_embeddings(self):
  437. return self.embeddings
  438. @auto_docstring(
  439. custom_intro="""
  440. This model is going to be deprecated in future versions. Please use `BlipForConditionalGeneration`, `BlipForQuestionAnswering` or `BlipForImageTextRetrieval` depending on your usecase.
  441. """
  442. )
  443. class BlipModel(BlipPreTrainedModel):
  444. config: BlipConfig
  445. def __init__(self, config: BlipConfig):
  446. super().__init__(config)
  447. if not isinstance(config.text_config, BlipTextConfig):
  448. raise TypeError(
  449. "config.text_config is expected to be of type BlipTextConfig but is of type"
  450. f" {type(config.text_config)}."
  451. )
  452. if not isinstance(config.vision_config, BlipVisionConfig):
  453. raise TypeError(
  454. "config.vision_config is expected to be of type BlipVisionConfig but is of type"
  455. f" {type(config.vision_config)}."
  456. )
  457. text_config = config.text_config
  458. vision_config = config.vision_config
  459. self.projection_dim = config.projection_dim
  460. self.text_embed_dim = text_config.hidden_size
  461. self.vision_embed_dim = vision_config.hidden_size
  462. self.text_model = BlipTextModel(text_config)
  463. self.vision_model = BlipVisionModel(vision_config)
  464. self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
  465. self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
  466. self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
  467. logger.warning(
  468. "`BlipModel` is going to be deprecated in future release, please use `BlipForConditionalGeneration`, `BlipForQuestionAnswering` or `BlipForImageTextRetrieval` depending on your usecase."
  469. )
  470. # Initialize weights and apply final processing
  471. self.post_init()
  472. def get_input_embeddings(self):
  473. return self.text_model.get_input_embeddings()
  474. def set_input_embeddings(self, value):
  475. self.text_model.set_input_embeddings(value)
  476. @auto_docstring
  477. def get_text_features(
  478. self,
  479. input_ids: Optional[torch.Tensor] = None,
  480. attention_mask: Optional[torch.Tensor] = None,
  481. position_ids: Optional[torch.Tensor] = None,
  482. ) -> torch.FloatTensor:
  483. r"""
  484. Returns:
  485. text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
  486. applying the projection layer to the pooled output of [`BlipTextModel`].
  487. Examples:
  488. ```python
  489. >>> from transformers import AutoProcessor, BlipModel
  490. >>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
  491. >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
  492. >>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
  493. >>> text_features = model.get_text_features(**inputs)
  494. ```"""
  495. text_outputs = self.text_model(
  496. input_ids=input_ids,
  497. attention_mask=attention_mask,
  498. position_ids=position_ids,
  499. )
  500. pooled_output = text_outputs[1]
  501. text_features = self.text_projection(pooled_output)
  502. return text_features
  503. @auto_docstring
  504. def get_image_features(
  505. self,
  506. pixel_values: Optional[torch.FloatTensor] = None,
  507. interpolate_pos_encoding: bool = False,
  508. ) -> torch.FloatTensor:
  509. r"""
  510. Returns:
  511. image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
  512. applying the projection layer to the pooled output of [`BlipVisionModel`].
  513. Examples:
  514. ```python
  515. >>> from PIL import Image
  516. >>> import requests
  517. >>> from transformers import AutoProcessor, BlipModel
  518. >>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
  519. >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
  520. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  521. >>> image = Image.open(requests.get(url, stream=True).raw)
  522. >>> inputs = processor(images=image, return_tensors="pt")
  523. >>> image_features = model.get_image_features(**inputs)
  524. ```"""
  525. vision_outputs = self.vision_model(
  526. pixel_values=pixel_values,
  527. interpolate_pos_encoding=interpolate_pos_encoding,
  528. )
  529. pooled_output = vision_outputs[1] # pooled_output
  530. image_features = self.visual_projection(pooled_output)
  531. return image_features
  532. @auto_docstring
  533. def get_multimodal_features(
  534. self,
  535. input_ids: Optional[torch.LongTensor] = None,
  536. pixel_values: Optional[torch.FloatTensor] = None,
  537. attention_mask: Optional[torch.Tensor] = None,
  538. interpolate_pos_encoding: bool = False,
  539. ) -> torch.FloatTensor:
  540. r"""
  541. Returns:
  542. multimodal_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The multimodal embeddings
  543. obtained by applying the image embeddings to the text encoder using the cross-attention mechanism.
  544. Examples:
  545. ```python
  546. >>> from PIL import Image
  547. >>> import requests
  548. >>> from transformers import AutoProcessor, BlipModel
  549. >>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
  550. >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
  551. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  552. >>> image = Image.open(requests.get(url, stream=True).raw)
  553. >>> texts = ["a photo of a cat", "a photo of a dog"]
  554. >>> inputs = processor(images=image, text=texts, padding=True, return_tensors="pt")
  555. >>> multimodal_features = model.get_multimodal_features(**inputs)
  556. ```"""
  557. vision_outputs = self.vision_model(
  558. pixel_values=pixel_values,
  559. interpolate_pos_encoding=interpolate_pos_encoding,
  560. )
  561. image_embeds = vision_outputs[0]
  562. image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long)
  563. text_outputs = self.text_model(
  564. input_ids=input_ids,
  565. attention_mask=attention_mask,
  566. encoder_hidden_states=image_embeds,
  567. encoder_attention_mask=image_atts,
  568. )
  569. pooled_output = text_outputs[1] # pooled_output
  570. multimodal_features = self.text_projection(pooled_output)
  571. return multimodal_features
  572. @can_return_tuple
  573. @auto_docstring
  574. def forward(
  575. self,
  576. input_ids: Optional[torch.LongTensor] = None,
  577. pixel_values: Optional[torch.FloatTensor] = None,
  578. attention_mask: Optional[torch.Tensor] = None,
  579. position_ids: Optional[torch.LongTensor] = None,
  580. return_loss: Optional[bool] = None,
  581. interpolate_pos_encoding: bool = False,
  582. **kwargs: Unpack[TransformersKwargs],
  583. ) -> Union[tuple, BlipOutput]:
  584. r"""
  585. return_loss (`bool`, *optional*):
  586. Whether or not to return the contrastive loss.
  587. Examples:
  588. ```python
  589. >>> from PIL import Image
  590. >>> import requests
  591. >>> from transformers import AutoProcessor, BlipModel
  592. >>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
  593. >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
  594. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  595. >>> image = Image.open(requests.get(url, stream=True).raw)
  596. >>> inputs = processor(
  597. ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
  598. ... )
  599. >>> outputs = model(**inputs)
  600. >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
  601. >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
  602. ```"""
  603. vision_outputs = self.vision_model(
  604. pixel_values=pixel_values,
  605. interpolate_pos_encoding=interpolate_pos_encoding,
  606. **kwargs,
  607. )
  608. text_outputs = self.text_model(
  609. input_ids=input_ids,
  610. attention_mask=attention_mask,
  611. position_ids=position_ids,
  612. **kwargs,
  613. )
  614. image_embeds = vision_outputs.pooler_output
  615. image_embeds = self.visual_projection(image_embeds)
  616. text_embeds = text_outputs.pooler_output
  617. text_embeds = self.text_projection(text_embeds)
  618. # normalized features
  619. image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
  620. text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
  621. # cosine similarity as logits
  622. logit_scale = self.logit_scale.exp().to(device=text_embeds.device)
  623. image_embeds = image_embeds.to(device=text_embeds.device, dtype=text_embeds.dtype)
  624. logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
  625. logits_per_image = logits_per_text.t()
  626. loss = None
  627. if return_loss:
  628. loss = blip_loss(logits_per_text)
  629. return BlipOutput(
  630. loss=loss,
  631. logits_per_image=logits_per_image,
  632. logits_per_text=logits_per_text,
  633. text_embeds=text_embeds,
  634. image_embeds=image_embeds,
  635. text_model_output=text_outputs,
  636. vision_model_output=vision_outputs,
  637. )
  638. @auto_docstring(
  639. custom_intro="""
  640. BLIP Model for image captioning. The model consists of a vision encoder and a text decoder. One can optionally pass
  641. `input_ids` to the model, which serve as a text prompt, to make the text decoder continue the prompt. Otherwise,
  642. the decoder starts generating text from the [BOS] (beginning-of-sequence) token. will start generating the caption
  643. from the text input. If no text input is provided, the decoder will start with the [BOS] token only.
  644. """
  645. )
  646. class BlipForConditionalGeneration(BlipPreTrainedModel, GenerationMixin):
  647. config: BlipConfig
  648. _tied_weights_keys = ["text_decoder.cls.predictions.decoder.bias"]
  649. main_input_name = "pixel_values"
  650. def __init__(self, config: BlipConfig):
  651. super().__init__(config)
  652. self.vision_model = BlipVisionModel(config.vision_config)
  653. self.text_decoder = BlipTextLMHeadModel(config.text_config)
  654. self.decoder_input_ids = config.text_config.bos_token_id
  655. self.decoder_pad_token_id = config.text_config.pad_token_id
  656. # Initialize weights and apply final processing
  657. self.post_init()
  658. def get_input_embeddings(self):
  659. return self.text_decoder.get_input_embeddings()
  660. def set_input_embeddings(self, value):
  661. self.text_decoder.set_input_embeddings(value)
  662. @can_return_tuple
  663. @auto_docstring
  664. def forward(
  665. self,
  666. pixel_values: torch.FloatTensor,
  667. input_ids: Optional[torch.LongTensor] = None,
  668. attention_mask: Optional[torch.LongTensor] = None,
  669. labels: Optional[torch.LongTensor] = None,
  670. interpolate_pos_encoding: bool = False,
  671. **kwargs: Unpack[TransformersKwargs],
  672. ) -> Union[tuple, BlipForConditionalGenerationModelOutput]:
  673. r"""
  674. Examples:
  675. ```python
  676. >>> from PIL import Image
  677. >>> import requests
  678. >>> from transformers import AutoProcessor, BlipForConditionalGeneration
  679. >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
  680. >>> model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
  681. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  682. >>> image = Image.open(requests.get(url, stream=True).raw)
  683. >>> text = "A picture of"
  684. >>> inputs = processor(images=image, text=text, return_tensors="pt")
  685. >>> outputs = model(**inputs)
  686. ```"""
  687. vision_outputs = self.vision_model(
  688. pixel_values=pixel_values,
  689. interpolate_pos_encoding=interpolate_pos_encoding,
  690. **kwargs,
  691. )
  692. image_embeds = vision_outputs.last_hidden_state
  693. outputs = self.text_decoder(
  694. input_ids=input_ids,
  695. attention_mask=attention_mask,
  696. encoder_hidden_states=image_embeds,
  697. labels=labels,
  698. reduction="mean",
  699. **kwargs,
  700. )
  701. return BlipForConditionalGenerationModelOutput(
  702. loss=outputs.loss,
  703. logits=outputs.logits,
  704. image_embeds=image_embeds,
  705. last_hidden_state=vision_outputs.last_hidden_state,
  706. hidden_states=vision_outputs.hidden_states,
  707. attentions=vision_outputs.attentions,
  708. )
  709. @torch.no_grad()
  710. def generate(
  711. self,
  712. pixel_values: torch.FloatTensor,
  713. input_ids: Optional[torch.LongTensor] = None,
  714. attention_mask: Optional[torch.LongTensor] = None,
  715. interpolate_pos_encoding: bool = False,
  716. **generate_kwargs,
  717. ) -> torch.LongTensor:
  718. r"""
  719. Overrides *generate* function to be able to use the model as a conditional generator
  720. Parameters:
  721. pixel_values (*torch.FloatTensor* of shape *(batch_size, num_channels, image_height, image_width)*:
  722. Input image to be processed
  723. input_ids (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
  724. The sequence used as a prompt for the generation.
  725. attention_mask (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
  726. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  727. Examples:
  728. ```python
  729. >>> from PIL import Image
  730. >>> import requests
  731. >>> from transformers import AutoProcessor, BlipForConditionalGeneration
  732. >>> model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
  733. >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
  734. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  735. >>> image = Image.open(requests.get(url, stream=True).raw)
  736. >>> inputs = processor(images=image, return_tensors="pt")
  737. >>> outputs = model.generate(**inputs)
  738. >>> print(processor.decode(outputs[0], skip_special_tokens=True))
  739. two cats sleeping on a couch
  740. ```
  741. """
  742. batch_size = pixel_values.shape[0]
  743. vision_outputs = self.vision_model(
  744. pixel_values=pixel_values,
  745. interpolate_pos_encoding=interpolate_pos_encoding,
  746. )
  747. image_embeds = vision_outputs[0]
  748. image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
  749. if isinstance(input_ids, list):
  750. input_ids = torch.LongTensor(input_ids)
  751. elif input_ids is None:
  752. input_ids = (
  753. torch.LongTensor([[self.decoder_input_ids, self.config.text_config.eos_token_id]])
  754. .repeat(batch_size, 1)
  755. .to(image_embeds.device)
  756. )
  757. input_ids[:, 0] = self.config.text_config.bos_token_id
  758. attention_mask = attention_mask[:, :-1] if attention_mask is not None else None
  759. outputs = self.text_decoder.generate(
  760. input_ids=input_ids[:, :-1],
  761. eos_token_id=self.config.text_config.sep_token_id,
  762. pad_token_id=self.config.text_config.pad_token_id,
  763. attention_mask=attention_mask,
  764. encoder_hidden_states=image_embeds,
  765. encoder_attention_mask=image_attention_mask,
  766. **generate_kwargs,
  767. )
  768. return outputs
  769. @auto_docstring(
  770. custom_intro="""
  771. BLIP Model for visual question answering. The model consists of a vision encoder, a text encoder as well as a text
  772. decoder. The vision encoder will encode the input image, the text encoder will encode the input question together
  773. with the encoding of the image, and the text decoder will output the answer to the question.
  774. """
  775. )
  776. class BlipForQuestionAnswering(BlipPreTrainedModel, GenerationMixin):
  777. config: BlipConfig
  778. _tied_weights_keys = ["text_decoder.cls.predictions.decoder.bias"]
  779. def __init__(self, config: BlipConfig):
  780. super().__init__(config)
  781. self.vision_model = BlipVisionModel(config.vision_config)
  782. self.text_encoder = BlipTextModel(config.text_config, add_pooling_layer=False)
  783. self.text_decoder = BlipTextLMHeadModel(config.text_config)
  784. self.decoder_pad_token_id = config.text_config.pad_token_id
  785. self.decoder_start_token_id = config.text_config.bos_token_id
  786. # Initialize weights and apply final processing
  787. self.post_init()
  788. def set_input_embeddings(self, value):
  789. self.text_encoder.set_input_embeddings(value)
  790. def get_input_embeddings(self):
  791. # This will return shared embeddings if they are shared else specific to encoder.
  792. return self.text_encoder.get_input_embeddings()
  793. @can_return_tuple
  794. @auto_docstring
  795. def forward(
  796. self,
  797. input_ids: torch.LongTensor,
  798. pixel_values: torch.FloatTensor,
  799. decoder_input_ids: Optional[torch.LongTensor] = None,
  800. decoder_attention_mask: Optional[torch.LongTensor] = None,
  801. attention_mask: Optional[torch.LongTensor] = None,
  802. labels: Optional[torch.LongTensor] = None,
  803. interpolate_pos_encoding: bool = False,
  804. **kwargs: Unpack[TransformersKwargs],
  805. ) -> Union[tuple, BlipTextVisionModelOutput]:
  806. r"""
  807. Examples:
  808. ```python
  809. >>> from PIL import Image
  810. >>> import requests
  811. >>> from transformers import AutoProcessor, BlipForQuestionAnswering
  812. >>> model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
  813. >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")
  814. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  815. >>> image = Image.open(requests.get(url, stream=True).raw)
  816. >>> # training
  817. >>> text = "How many cats are in the picture?"
  818. >>> label = "2"
  819. >>> inputs = processor(images=image, text=text, return_tensors="pt")
  820. >>> labels = processor(text=label, return_tensors="pt").input_ids
  821. >>> inputs["labels"] = labels
  822. >>> outputs = model(**inputs)
  823. >>> loss = outputs.loss
  824. >>> loss.backward()
  825. >>> # inference
  826. >>> text = "How many cats are in the picture?"
  827. >>> inputs = processor(images=image, text=text, return_tensors="pt")
  828. >>> outputs = model.generate(**inputs)
  829. >>> print(processor.decode(outputs[0], skip_special_tokens=True))
  830. 2
  831. ```"""
  832. if labels is None and decoder_input_ids is None:
  833. raise ValueError(
  834. "Either `decoder_input_ids` or `labels` should be passed when calling `forward` with"
  835. " `BlipForQuestionAnswering`. if you are training the model make sure that `labels` is passed, if you"
  836. " are using the model for inference make sure that `decoder_input_ids` is passed or call `generate`"
  837. )
  838. vision_outputs = self.vision_model(
  839. pixel_values=pixel_values,
  840. interpolate_pos_encoding=interpolate_pos_encoding,
  841. **kwargs,
  842. )
  843. image_embeds = vision_outputs.last_hidden_state
  844. image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long)
  845. question_embeds = self.text_encoder(
  846. input_ids=input_ids,
  847. attention_mask=attention_mask,
  848. encoder_hidden_states=image_embeds,
  849. encoder_attention_mask=image_attention_mask,
  850. **kwargs,
  851. )
  852. if labels is not None and decoder_input_ids is None:
  853. # labels are already shifted right, see: https://github.com/huggingface/transformers/pull/23153
  854. decoder_input_ids = labels
  855. question_embeds = question_embeds[0]
  856. answer_output = self.text_decoder(
  857. input_ids=decoder_input_ids,
  858. attention_mask=decoder_attention_mask,
  859. encoder_hidden_states=question_embeds,
  860. encoder_attention_mask=attention_mask,
  861. labels=labels,
  862. reduction="mean",
  863. **kwargs,
  864. )
  865. if labels is not None:
  866. decoder_loss = answer_output.loss.mean()
  867. else:
  868. decoder_loss = None
  869. return BlipTextVisionModelOutput(
  870. loss=decoder_loss,
  871. image_embeds=image_embeds,
  872. last_hidden_state=vision_outputs.last_hidden_state,
  873. hidden_states=vision_outputs.hidden_states,
  874. attentions=vision_outputs.attentions,
  875. )
  876. @torch.no_grad()
  877. def generate(
  878. self,
  879. input_ids: torch.LongTensor,
  880. pixel_values: torch.FloatTensor,
  881. attention_mask: Optional[torch.LongTensor] = None,
  882. interpolate_pos_encoding: bool = False,
  883. **generate_kwargs,
  884. ) -> torch.LongTensor:
  885. r"""
  886. Overrides *generate* function to be able to use the model as a conditional generator
  887. Parameters:
  888. input_ids (*torch.LongTensor* of shape *(batch_size, sequence_length)*):
  889. The sequence used as a prompt for the generation.
  890. pixel_values (*torch.FloatTensor* of shape *(batch_size, num_channels, image_height, image_width)*:
  891. Input image to be processed
  892. attention_mask (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
  893. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`. `1` for
  894. tokens that are NOT MASKED, `0` for MASKED tokens.
  895. **generate_kwargs:
  896. Additional arguments passed to the *generate* function of the decoder
  897. Examples:
  898. ```python
  899. >>> from PIL import Image
  900. >>> import requests
  901. >>> from transformers import AutoProcessor, BlipForQuestionAnswering
  902. >>> model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
  903. >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")
  904. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  905. >>> image = Image.open(requests.get(url, stream=True).raw)
  906. >>> text = "How many cats are in the picture?"
  907. >>> inputs = processor(images=image, text=text, return_tensors="pt")
  908. >>> outputs = model.generate(**inputs)
  909. >>> print(processor.decode(outputs[0], skip_special_tokens=True))
  910. 2
  911. ```
  912. """
  913. vision_outputs = self.vision_model(
  914. pixel_values=pixel_values,
  915. interpolate_pos_encoding=interpolate_pos_encoding,
  916. )
  917. image_embeds = vision_outputs[0]
  918. image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
  919. if isinstance(input_ids, list):
  920. input_ids = torch.LongTensor(input_ids)
  921. question_outputs = self.text_encoder(
  922. input_ids=input_ids,
  923. attention_mask=attention_mask,
  924. encoder_hidden_states=image_embeds,
  925. encoder_attention_mask=image_attention_mask,
  926. return_dict=False,
  927. )
  928. question_embeds = question_outputs[0]
  929. question_attention_mask = torch.ones(
  930. question_embeds.size()[:-1], dtype=torch.long, device=question_embeds.device
  931. )
  932. bos_ids = torch.full(
  933. (question_embeds.size(0), 1), fill_value=self.decoder_start_token_id, device=question_embeds.device
  934. )
  935. outputs = self.text_decoder.generate(
  936. input_ids=bos_ids,
  937. eos_token_id=self.config.text_config.sep_token_id,
  938. pad_token_id=self.config.text_config.pad_token_id,
  939. encoder_hidden_states=question_embeds,
  940. encoder_attention_mask=question_attention_mask,
  941. **generate_kwargs,
  942. )
  943. return outputs
  944. @auto_docstring(
  945. custom_intro="""
  946. BLIP Model with a vision and text projector, and a classification head on top. The model is used in the context of
  947. image-text retrieval. Given an image and a text, the model returns the probability of the text being relevant to
  948. the image.
  949. """
  950. )
  951. class BlipForImageTextRetrieval(BlipPreTrainedModel):
  952. config: BlipConfig
  953. def __init__(self, config: BlipConfig):
  954. super().__init__(config)
  955. self.vision_model = BlipVisionModel(config.vision_config)
  956. self.text_encoder = BlipTextModel(config.text_config, add_pooling_layer=False)
  957. # vision projection layer
  958. self.vision_proj = nn.Linear(config.vision_config.hidden_size, config.image_text_hidden_size)
  959. # text projection layer
  960. self.text_proj = nn.Linear(config.text_config.hidden_size, config.image_text_hidden_size)
  961. # image text matching head
  962. self.itm_head = nn.Linear(config.text_config.hidden_size, 2)
  963. self.decoder_pad_token_id = (
  964. config.text_config.pad_token_id
  965. if not hasattr(config, "decoder_pad_token_id")
  966. else config.decoder_pad_token_id
  967. )
  968. self.decoder_start_token_id = (
  969. config.text_config.bos_token_id
  970. if not hasattr(config, "decoder_start_token_id")
  971. else config.decoder_start_token_id
  972. )
  973. # Initialize weights and apply final processing
  974. self.post_init()
  975. def get_input_embeddings(self):
  976. return self.text_encoder.get_input_embeddings()
  977. def set_input_embeddings(self, value):
  978. self.text_encoder.set_input_embeddings(value)
  979. @can_return_tuple
  980. @auto_docstring
  981. def forward(
  982. self,
  983. input_ids: torch.LongTensor,
  984. pixel_values: torch.FloatTensor,
  985. use_itm_head: Optional[bool] = True,
  986. attention_mask: Optional[torch.LongTensor] = None,
  987. interpolate_pos_encoding: bool = False,
  988. **kwargs: Unpack[TransformersKwargs],
  989. ) -> Union[tuple, BlipTextVisionModelOutput]:
  990. r"""
  991. use_itm_head (`bool`, *optional*, defaults to `True`):
  992. Whether or not to use the image-text matching head.
  993. Examples:
  994. ```python
  995. >>> from PIL import Image
  996. >>> import requests
  997. >>> from transformers import AutoProcessor, BlipForImageTextRetrieval
  998. >>> model = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco")
  999. >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-itm-base-coco")
  1000. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1001. >>> image = Image.open(requests.get(url, stream=True).raw)
  1002. >>> text = "an image of a cat"
  1003. >>> inputs = processor(images=image, text=text, return_tensors="pt")
  1004. >>> outputs = model(**inputs)
  1005. ```
  1006. """
  1007. vision_outputs = self.vision_model(
  1008. pixel_values=pixel_values,
  1009. interpolate_pos_encoding=interpolate_pos_encoding,
  1010. **kwargs,
  1011. )
  1012. image_embeds = vision_outputs.last_hidden_state
  1013. image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long)
  1014. if use_itm_head:
  1015. question_embeds = self.text_encoder(
  1016. input_ids=input_ids,
  1017. attention_mask=attention_mask,
  1018. encoder_hidden_states=image_embeds,
  1019. encoder_attention_mask=image_atts,
  1020. **kwargs,
  1021. )
  1022. question_embeds = question_embeds.last_hidden_state
  1023. output = self.itm_head(question_embeds[:, 0, :])
  1024. else:
  1025. question_embeds = self.text_encoder(
  1026. input_ids=input_ids,
  1027. attention_mask=attention_mask,
  1028. **kwargs,
  1029. )
  1030. question_embeds = question_embeds.last_hidden_state
  1031. image_feat = normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)
  1032. text_feat = normalize(self.text_proj(question_embeds[:, 0, :]), dim=-1)
  1033. output = image_feat @ text_feat.t()
  1034. return BlipImageTextMatchingModelOutput(
  1035. itm_score=output,
  1036. last_hidden_state=vision_outputs.last_hidden_state,
  1037. hidden_states=vision_outputs.hidden_states,
  1038. attentions=vision_outputs.attentions,
  1039. question_embeds=question_embeds,
  1040. )
  1041. __all__ = [
  1042. "BlipModel",
  1043. "BlipPreTrainedModel",
  1044. "BlipForConditionalGeneration",
  1045. "BlipForQuestionAnswering",
  1046. "BlipVisionModel",
  1047. "BlipTextModel",
  1048. "BlipForImageTextRetrieval",
  1049. ]