modeling_bridgetower.py 83 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875
  1. # coding=utf-8
  2. # Copyright 2023 The Intel Labs Team Authors, The Microsoft Research Team Authors and HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch BridgeTower Model"""
  16. import math
  17. from collections import OrderedDict
  18. from dataclasses import dataclass
  19. from typing import Optional, Union
  20. import torch
  21. from torch import nn
  22. from torch.nn import CrossEntropyLoss
  23. from ...activations import ACT2FN, QuickGELUActivation
  24. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  25. from ...modeling_layers import GradientCheckpointingLayer
  26. from ...modeling_outputs import (
  27. BaseModelOutputWithPastAndCrossAttentions,
  28. BaseModelOutputWithPoolingAndCrossAttentions,
  29. MaskedLMOutput,
  30. ModelOutput,
  31. SequenceClassifierOutput,
  32. )
  33. from ...modeling_utils import PreTrainedModel
  34. from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
  35. from ...utils import auto_docstring, logging, torch_int
  36. from ...utils.deprecation import deprecate_kwarg
  37. from .configuration_bridgetower import BridgeTowerConfig, BridgeTowerTextConfig, BridgeTowerVisionConfig
  38. logger = logging.get_logger(__name__)
  39. _TOKENIZER_FOR_DOC = "RobertaTokenizer"
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`BridgeTowerModel`].
    """
)
class BridgeTowerModelOutput(ModelOutput):
    r"""
    text_features (`torch.FloatTensor` of shape `(batch_size, text_sequence_length, hidden_size)`):
        Sequence of hidden-states at the text output of the last layer of the model.
    image_features (`torch.FloatTensor` of shape `(batch_size, image_sequence_length, hidden_size)`):
        Sequence of hidden-states at the image output of the last layer of the model.
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size x 2)`):
        Concatenation of last layer hidden-state of the first token of the text and image sequence (classification
        token), respectively, after further processing through layers used for auxiliary pretraining tasks.
    """

    # NOTE(review): the field declaration order is presumably what ModelOutput uses for index/tuple access,
    # so do not reorder these fields without checking ModelOutput -- confirm.
    text_features: Optional[torch.FloatTensor] = None
    image_features: Optional[torch.FloatTensor] = None
    pooler_output: Optional[torch.FloatTensor] = None
    # hidden_states / attentions are not described in the docstring above; assumed to be filled in by
    # `auto_docstring` from shared argument descriptions -- TODO confirm.
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
  61. @dataclass
  62. @auto_docstring(
  63. custom_intro="""
  64. Output type of ['BridgeTowerForContrastiveLearning']
  65. """
  66. )
  67. class BridgeTowerContrastiveOutput(ModelOutput):
  68. r"""
  69. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
  70. Image-text contrastive loss.
  71. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  72. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  73. text_embeds (`torch.FloatTensor)`, *optional*, returned when model is initialized with `with_projection=True`):
  74. The text embeddings obtained by applying the projection layer to the pooler_output.
  75. image_embeds (`torch.FloatTensor)`, *optional*, returned when model is initialized with `with_projection=True`):
  76. The image embeddings obtained by applying the projection layer to the pooler_output.
  77. cross_embeds (`torch.FloatTensor)`, *optional*, returned when model is initialized with `with_projection=True`):
  78. The text-image cross-modal embeddings obtained by applying the projection layer to the pooler_output.
  79. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  80. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  81. sequence_length)`.
  82. """
  83. loss: Optional[torch.FloatTensor] = None
  84. logits: Optional[torch.FloatTensor] = None
  85. text_embeds: Optional[tuple[torch.FloatTensor]] = None
  86. image_embeds: Optional[tuple[torch.FloatTensor]] = None
  87. cross_embeds: Optional[tuple[torch.FloatTensor]] = None
  88. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  89. attentions: Optional[tuple[torch.FloatTensor]] = None
  90. class BridgeTowerResidualAttention(nn.Module):
  91. def __init__(self, config):
  92. super().__init__()
  93. self.attn = nn.MultiheadAttention(config.hidden_size, config.hidden_size // 64)
  94. self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  95. self.mlp = nn.ModuleDict(
  96. OrderedDict(
  97. [
  98. ("c_fc", nn.Linear(config.hidden_size, config.hidden_size * 4)),
  99. ("gelu", QuickGELUActivation()),
  100. ("c_proj", nn.Linear(config.hidden_size * 4, config.hidden_size)),
  101. ]
  102. )
  103. )
  104. self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  105. self.attn_mask = None
  106. def attention(self, hidden_state: torch.Tensor, attention_mask: torch.Tensor):
  107. if attention_mask is not None:
  108. attention_mask = attention_mask.to(dtype=torch.bool, device=hidden_state.device)
  109. self.attn_mask = (
  110. self.attn_mask.to(dtype=hidden_state.dtype, device=hidden_state.device)
  111. if self.attn_mask is not None
  112. else None
  113. )
  114. return self.attn(
  115. hidden_state,
  116. hidden_state,
  117. hidden_state,
  118. need_weights=False,
  119. attn_mask=self.attn_mask,
  120. key_padding_mask=attention_mask,
  121. )[0]
  122. def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
  123. residual_state = hidden_state + self.attention(self.ln_1(hidden_state), attention_mask)
  124. hidden_state = self.ln_2(residual_state)
  125. for layer in self.mlp.values():
  126. hidden_state = layer(hidden_state)
  127. hidden_state = residual_state + hidden_state
  128. return hidden_state
  129. class BridgeTowerTransformer(nn.Module):
  130. def __init__(self, config):
  131. super().__init__()
  132. self.hidden_size = config.hidden_size
  133. self.num_hidden_layers = config.num_hidden_layers
  134. if config.remove_last_layer:
  135. self.resblocks = nn.ModuleList(
  136. [BridgeTowerResidualAttention(config) for _ in range(self.num_hidden_layers - 1)]
  137. )
  138. else:
  139. self.resblocks = nn.ModuleList(
  140. [BridgeTowerResidualAttention(config) for _ in range(self.num_hidden_layers)]
  141. )
  142. self.stop_gradient = config.stop_gradient
  143. def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
  144. hidden_states = []
  145. for block in self.resblocks:
  146. hidden_state = block(hidden_state, attention_mask)
  147. if self.stop_gradient:
  148. hidden_states.append(hidden_state.detach())
  149. else:
  150. hidden_states.append(hidden_state)
  151. return hidden_states
  152. # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->BridgeTower
  153. class BridgeTowerVisionEmbeddings(nn.Module):
  154. def __init__(self, config: BridgeTowerVisionConfig):
  155. super().__init__()
  156. self.config = config
  157. self.embed_dim = config.hidden_size
  158. self.image_size = config.image_size
  159. self.patch_size = config.patch_size
  160. self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
  161. self.patch_embedding = nn.Conv2d(
  162. in_channels=config.num_channels,
  163. out_channels=self.embed_dim,
  164. kernel_size=self.patch_size,
  165. stride=self.patch_size,
  166. bias=False,
  167. )
  168. self.num_patches = (self.image_size // self.patch_size) ** 2
  169. self.num_positions = self.num_patches + 1
  170. self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
  171. self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
  172. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  173. """
  174. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  175. images. This method is also adapted to support torch.jit tracing.
  176. Adapted from:
  177. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  178. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  179. """
  180. num_patches = embeddings.shape[1] - 1
  181. position_embedding = self.position_embedding.weight.unsqueeze(0)
  182. num_positions = position_embedding.shape[1] - 1
  183. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  184. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  185. return self.position_embedding(self.position_ids)
  186. class_pos_embed = position_embedding[:, :1]
  187. patch_pos_embed = position_embedding[:, 1:]
  188. dim = embeddings.shape[-1]
  189. new_height = height // self.patch_size
  190. new_width = width // self.patch_size
  191. sqrt_num_positions = torch_int(num_positions**0.5)
  192. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  193. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  194. patch_pos_embed = nn.functional.interpolate(
  195. patch_pos_embed,
  196. size=(new_height, new_width),
  197. mode="bicubic",
  198. align_corners=False,
  199. )
  200. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  201. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  202. def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
  203. batch_size, _, height, width = pixel_values.shape
  204. if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
  205. raise ValueError(
  206. f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
  207. )
  208. target_dtype = self.patch_embedding.weight.dtype
  209. patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
  210. patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
  211. class_embeds = self.class_embedding.expand(batch_size, 1, -1)
  212. embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
  213. if interpolate_pos_encoding:
  214. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  215. else:
  216. embeddings = embeddings + self.position_embedding(self.position_ids)
  217. return embeddings
  218. class BridgeTowerVisionTransformer(nn.Module):
  219. def __init__(self, config):
  220. super().__init__()
  221. self.embeddings = BridgeTowerVisionEmbeddings(config)
  222. self.ln_pre = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  223. self.transformer = BridgeTowerTransformer(config)
  224. self.ln_post = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  225. self.share_layernorm = config.share_layernorm
  226. if not config.share_layernorm:
  227. self.ln_separate = nn.ModuleList(
  228. [nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) for _ in range(config.num_hidden_layers)]
  229. )
  230. def forward(
  231. self,
  232. pixel_values: torch.Tensor,
  233. attention_mask,
  234. interpolate_pos_encoding: bool = False,
  235. ):
  236. hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding)
  237. hidden_states = self.ln_pre(hidden_states)
  238. # NLD -> LND
  239. hidden_states = hidden_states.permute(1, 0, 2)
  240. hidden_states = self.transformer(hidden_states, attention_mask)
  241. # shape = [num_hidden_layers, hidden_size, *, grid ** 2]
  242. hidden_states = torch.stack(hidden_states, dim=0)
  243. # shape = [num_hidden_layers, *, hidden_size, grid ** 2]
  244. hidden_states = hidden_states.permute(0, 2, 1, 3)
  245. if self.share_layernorm:
  246. hidden_states = self.ln_post(hidden_states)
  247. else:
  248. hidden_states_stack = []
  249. for hidden_states, ln in zip(hidden_states, self.ln_separate):
  250. hidden_states = ln(hidden_states)
  251. hidden_states_stack.append(hidden_states)
  252. # shape = [num_hidden_layers, *, hidden_size, grid ** 2]
  253. hidden_states = torch.stack(hidden_states_stack, dim=0)
  254. return hidden_states
  255. def forward_pre(
  256. self,
  257. pixel_values: torch.Tensor,
  258. interpolate_pos_encoding: bool = False,
  259. ):
  260. hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  261. hidden_states = self.ln_pre(hidden_states)
  262. # NLD -> LND
  263. hidden_states = hidden_states.permute(1, 0, 2)
  264. return hidden_states
  265. def forward_post(self, hidden_state: torch.Tensor):
  266. visual_output_post = hidden_state.permute(1, 0, 2)
  267. visual_output_post = self.ln_post(visual_output_post)
  268. return visual_output_post
  269. class BridgeTowerLinkTower(nn.Module):
  270. def __init__(self, config):
  271. super().__init__()
  272. self.link_tower_type = config.link_tower_type
  273. self.hidden_size = config.hidden_size
  274. if config.link_tower_type in ["add", "scaled_add", "interpolate"]:
  275. if config.link_tower_type == "scaled_add":
  276. self.scaled_factor = nn.Parameter(torch.tensor(1.0))
  277. elif config.link_tower_type == "interpolate":
  278. self.beta = nn.Parameter(torch.tensor(0.5))
  279. self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
  280. else:
  281. raise NotImplementedError(f"link_tower_type {config.link_tower_type} is not implemented")
  282. def forward(self, hidden_states, cross_modal_hidden_states, attention_mask):
  283. if self.link_tower_type == "add":
  284. return self.LayerNorm(hidden_states + cross_modal_hidden_states)
  285. elif self.link_tower_type == "scaled_add":
  286. return self.LayerNorm(hidden_states * self.scaled_factor + cross_modal_hidden_states)
  287. elif self.link_tower_type == "interpolate":
  288. return self.LayerNorm(hidden_states * (1 - self.beta) + cross_modal_hidden_states * self.beta)
  289. else:
  290. raise NotImplementedError(f"link_tower_type {self.link_tower_type} is not implemented")
  291. # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->BridgeTower
  292. class BridgeTowerSelfOutput(nn.Module):
  293. def __init__(self, config):
  294. super().__init__()
  295. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  296. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  297. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  298. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  299. hidden_states = self.dense(hidden_states)
  300. hidden_states = self.dropout(hidden_states)
  301. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  302. return hidden_states
  303. # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->BridgeTower
  304. class BridgeTowerIntermediate(nn.Module):
  305. def __init__(self, config):
  306. super().__init__()
  307. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  308. if isinstance(config.hidden_act, str):
  309. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  310. else:
  311. self.intermediate_act_fn = config.hidden_act
  312. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  313. hidden_states = self.dense(hidden_states)
  314. hidden_states = self.intermediate_act_fn(hidden_states)
  315. return hidden_states
  316. # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->BridgeTower
  317. class BridgeTowerOutput(nn.Module):
  318. def __init__(self, config):
  319. super().__init__()
  320. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  321. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  322. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  323. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  324. hidden_states = self.dense(hidden_states)
  325. hidden_states = self.dropout(hidden_states)
  326. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  327. return hidden_states
  328. # Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->BridgeTower
  329. class BridgeTowerPooler(nn.Module):
  330. def __init__(self, config):
  331. super().__init__()
  332. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  333. self.activation = nn.Tanh()
  334. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  335. # We "pool" the model by simply taking the hidden state corresponding
  336. # to the first token.
  337. first_token_tensor = hidden_states[:, 0]
  338. pooled_output = self.dense(first_token_tensor)
  339. pooled_output = self.activation(pooled_output)
  340. return pooled_output
  341. # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->BridgeTower
class BridgeTowerSelfAttention(nn.Module):
    """Eager multi-head attention (self- or cross-attention) with optional key/value caching.

    Raw scores are query-key dot products, optionally plus learned relative-position terms
    (`position_embedding_type` "relative_key" / "relative_key_query"), scaled by
    1/sqrt(head_size), plus an additive mask, then softmax-normalized.
    """

    def __init__(self, config, position_embedding_type=None, layer_idx=None):
        super().__init__()
        # The divisibility check is skipped when the config defines `embedding_size`.
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        # Explicit argument wins; otherwise fall back to the config attribute, then to "absolute".
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            # One row per signed distance in [-(max-1), max-1]; forward() shifts distances by max-1 to index it.
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder
        # Index of this layer inside the cache objects.
        self.layer_idx = layer_idx

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        cache_position: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor]:
        """Attend from `hidden_states` to itself, or to `encoder_hidden_states` when given.

        Args:
            hidden_states: queries; unpacked as (batch_size, seq_length, hidden).
            attention_mask: optional mask ADDED to the scaled scores before softmax.
            head_mask: optional mask MULTIPLIED into the attention probabilities after dropout.
            encoder_hidden_states: if given, keys/values are computed from it (cross-attention).
            past_key_values: optional `Cache`; when it is an `EncoderDecoderCache`, the self- or
                cross-attention sub-cache is selected. New keys/values are written via `.update(...)`.
            output_attentions: accepted but not read in this method; the probabilities are always returned.
            cache_position: passed to the cache update for self-attention only.

        Returns:
            `(context_layer, attention_probs)`; context_layer is (batch_size, seq_length, all_head_size).
        """
        batch_size, seq_length, _ = hidden_states.shape
        query_layer = self.query(hidden_states)
        # (batch, seq, all_head) -> (batch, heads, seq, head_size)
        query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
            1, 2
        )

        is_updated = False
        is_cross_attention = encoder_hidden_states is not None
        if past_key_values is not None:
            if isinstance(past_key_values, EncoderDecoderCache):
                is_updated = past_key_values.is_updated.get(self.layer_idx)
                if is_cross_attention:
                    # after the first generated id, we can subsequently re-use all key/value_layer from cache
                    curr_past_key_value = past_key_values.cross_attention_cache
                else:
                    curr_past_key_value = past_key_values.self_attention_cache
            else:
                curr_past_key_value = past_key_values

        current_states = encoder_hidden_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_values is not None and is_updated:
            # reuse k,v, cross_attentions
            key_layer = curr_past_key_value.layers[self.layer_idx].keys
            value_layer = curr_past_key_value.layers[self.layer_idx].values
        else:
            key_layer = self.key(current_states)
            key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
                1, 2
            )
            value_layer = self.value(current_states)
            value_layer = value_layer.view(
                batch_size, -1, self.num_attention_heads, self.attention_head_size
            ).transpose(1, 2)

            if past_key_values is not None:
                # save all key/value_layer to cache to be re-used for fast auto-regressive generation
                cache_position = cache_position if not is_cross_attention else None
                key_layer, value_layer = curr_past_key_value.update(
                    key_layer, value_layer, self.layer_idx, {"cache_position": cache_position}
                )
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
                    past_key_values.is_updated[self.layer_idx] = True

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            if past_key_values is not None:
                # With a cache, the query row position is taken to be the last key position (key_length - 1).
                # NOTE(review): that is only exact when one new token is processed per call -- confirm callers.
                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
                    -1, 1
                )
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r

            # Shift signed distances into [0, 2*max-2] to index the distance table.
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the additive attention mask (presumably precomputed once for all layers in the model's forward()).
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        # (batch, heads, seq, head_size) -> (batch, seq, heads * head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        return context_layer, attention_probs
  454. BRIDGE_TOWER_SELF_ATTENTION_CLASSES = {
  455. "eager": BridgeTowerSelfAttention,
  456. }
  457. # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->BridgeTower,BERT->BRIDGE_TOWER
  458. class BridgeTowerAttention(nn.Module):
  459. def __init__(self, config, position_embedding_type=None, layer_idx=None):
  460. super().__init__()
  461. self.self = BRIDGE_TOWER_SELF_ATTENTION_CLASSES[config._attn_implementation](
  462. config,
  463. position_embedding_type=position_embedding_type,
  464. layer_idx=layer_idx,
  465. )
  466. self.output = BridgeTowerSelfOutput(config)
  467. self.pruned_heads = set()
  468. def prune_heads(self, heads):
  469. if len(heads) == 0:
  470. return
  471. heads, index = find_pruneable_heads_and_indices(
  472. heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
  473. )
  474. # Prune linear layers
  475. self.self.query = prune_linear_layer(self.self.query, index)
  476. self.self.key = prune_linear_layer(self.self.key, index)
  477. self.self.value = prune_linear_layer(self.self.value, index)
  478. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  479. # Update hyper params and store pruned heads
  480. self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
  481. self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
  482. self.pruned_heads = self.pruned_heads.union(heads)
  483. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  484. def forward(
  485. self,
  486. hidden_states: torch.Tensor,
  487. attention_mask: Optional[torch.FloatTensor] = None,
  488. head_mask: Optional[torch.FloatTensor] = None,
  489. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  490. past_key_values: Optional[Cache] = None,
  491. output_attentions: Optional[bool] = False,
  492. cache_position: Optional[torch.Tensor] = None,
  493. ) -> tuple[torch.Tensor]:
  494. self_outputs = self.self(
  495. hidden_states,
  496. attention_mask=attention_mask,
  497. head_mask=head_mask,
  498. encoder_hidden_states=encoder_hidden_states,
  499. past_key_values=past_key_values,
  500. output_attentions=output_attentions,
  501. cache_position=cache_position,
  502. )
  503. attention_output = self.output(self_outputs[0], hidden_states)
  504. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  505. return outputs
  506. class BridgeTowerBertCrossLayer(nn.Module):
  507. def __init__(self, config, layer_idx=None):
  508. super().__init__()
  509. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  510. self.seq_len_dim = 1
  511. self.attention = BridgeTowerAttention(config, layer_idx=layer_idx)
  512. self.is_decoder = config.is_decoder
  513. self.add_cross_attention = config.add_cross_attention
  514. self.crossattention = BridgeTowerAttention(config, layer_idx=layer_idx)
  515. self.intermediate = BridgeTowerIntermediate(config)
  516. self.output = BridgeTowerOutput(config)
  517. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  518. def forward(
  519. self,
  520. hidden_states,
  521. encoder_hidden_states,
  522. attention_mask=None,
  523. head_mask=None,
  524. encoder_attention_mask=None,
  525. past_key_values=None,
  526. output_attentions=False,
  527. cache_position=None,
  528. ):
  529. # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
  530. self_attention_outputs = self.attention(
  531. hidden_states,
  532. attention_mask=attention_mask,
  533. head_mask=None,
  534. output_attentions=output_attentions,
  535. past_key_values=None,
  536. )
  537. attention_output = self_attention_outputs[0]
  538. # if decoder, the last output is tuple of self-attn cache
  539. # add self attentions if we output attention weights
  540. outputs = self_attention_outputs[1:]
  541. cross_attention_outputs = self.crossattention(
  542. attention_output,
  543. attention_mask=encoder_attention_mask,
  544. head_mask=head_mask,
  545. encoder_hidden_states=encoder_hidden_states,
  546. past_key_values=past_key_values,
  547. output_attentions=output_attentions,
  548. cache_position=cache_position,
  549. )
  550. attention_output = cross_attention_outputs[0]
  551. # add cross attentions if we output attention weights
  552. outputs = outputs + cross_attention_outputs[1:]
  553. layer_output = apply_chunking_to_forward(
  554. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  555. )
  556. outputs = (layer_output,) + outputs
  557. return outputs
  558. def feed_forward_chunk(self, attention_output):
  559. intermediate_output = self.intermediate(attention_output)
  560. layer_output = self.output(intermediate_output, attention_output)
  561. return layer_output
  562. class BridgeTowerTextLayer(GradientCheckpointingLayer):
  563. def __init__(self, config, layer_idx=None):
  564. super().__init__()
  565. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  566. self.seq_len_dim = 1
  567. self.attention = BridgeTowerAttention(config, layer_idx=layer_idx)
  568. self.is_decoder = config.is_decoder
  569. self.add_cross_attention = config.add_cross_attention
  570. if self.add_cross_attention:
  571. if not self.is_decoder:
  572. raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
  573. self.crossattention = BridgeTowerAttention(config, position_embedding_type="absolute", layer_idx=layer_idx)
  574. self.intermediate = BridgeTowerIntermediate(config)
  575. self.output = BridgeTowerOutput(config)
  576. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  577. def forward(
  578. self,
  579. hidden_states: torch.Tensor,
  580. attention_mask: Optional[torch.FloatTensor] = None,
  581. head_mask: Optional[torch.FloatTensor] = None,
  582. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  583. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  584. past_key_values: Optional[Cache] = None,
  585. output_attentions: Optional[bool] = False,
  586. cache_position: Optional[torch.Tensor] = None,
  587. ) -> tuple[torch.Tensor]:
  588. # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
  589. self_attention_outputs = self.attention(
  590. hidden_states,
  591. attention_mask=attention_mask,
  592. head_mask=head_mask,
  593. output_attentions=output_attentions,
  594. past_key_values=past_key_values,
  595. cache_position=cache_position,
  596. )
  597. attention_output = self_attention_outputs[0]
  598. # if decoder, the last output is tuple of self-attn cache
  599. if self.is_decoder:
  600. outputs = self_attention_outputs[1:-1]
  601. else:
  602. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  603. if self.is_decoder and encoder_hidden_states is not None:
  604. if not hasattr(self, "crossattention"):
  605. raise ValueError(
  606. f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
  607. " by setting `config.add_cross_attention=True`"
  608. )
  609. cross_attention_outputs = self.crossattention(
  610. attention_output,
  611. attention_mask=encoder_attention_mask,
  612. head_mask=head_mask,
  613. encoder_hidden_states=encoder_hidden_states,
  614. past_key_values=past_key_values,
  615. output_attentions=output_attentions,
  616. cache_position=cache_position,
  617. )
  618. attention_output = cross_attention_outputs[0]
  619. outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
  620. layer_output = apply_chunking_to_forward(
  621. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  622. )
  623. return (layer_output,) + outputs
  624. def feed_forward_chunk(self, attention_output):
  625. intermediate_output = self.intermediate(attention_output)
  626. layer_output = self.output(intermediate_output, attention_output)
  627. return layer_output
  628. # Copied from transformers.models.roberta.modeling_roberta.RobertaEncoder with Roberta->BridgeTowerText
  629. class BridgeTowerTextEncoder(nn.Module):
  630. def __init__(self, config, layer_idx=None):
  631. super().__init__()
  632. self.config = config
  633. self.layer = nn.ModuleList(
  634. [BridgeTowerTextLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]
  635. )
  636. self.gradient_checkpointing = False
  637. def forward(
  638. self,
  639. hidden_states: torch.Tensor,
  640. attention_mask: Optional[torch.FloatTensor] = None,
  641. head_mask: Optional[torch.FloatTensor] = None,
  642. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  643. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  644. past_key_values: Optional[Cache] = None,
  645. use_cache: Optional[bool] = None,
  646. output_attentions: Optional[bool] = False,
  647. output_hidden_states: Optional[bool] = False,
  648. return_dict: Optional[bool] = True,
  649. cache_position: Optional[torch.Tensor] = None,
  650. ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
  651. all_hidden_states = () if output_hidden_states else None
  652. all_self_attentions = () if output_attentions else None
  653. all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
  654. if self.gradient_checkpointing and self.training:
  655. if use_cache:
  656. logger.warning_once(
  657. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  658. )
  659. use_cache = False
  660. if use_cache and self.config.is_decoder and past_key_values is None:
  661. past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  662. if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple):
  663. logger.warning_once(
  664. "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
  665. "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
  666. "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
  667. )
  668. past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
  669. for i, layer_module in enumerate(self.layer):
  670. if output_hidden_states:
  671. all_hidden_states = all_hidden_states + (hidden_states,)
  672. layer_head_mask = head_mask[i] if head_mask is not None else None
  673. layer_outputs = layer_module(
  674. hidden_states,
  675. attention_mask,
  676. layer_head_mask,
  677. encoder_hidden_states, # as a positional argument for gradient checkpointing
  678. encoder_attention_mask=encoder_attention_mask,
  679. past_key_values=past_key_values,
  680. output_attentions=output_attentions,
  681. cache_position=cache_position,
  682. )
  683. hidden_states = layer_outputs[0]
  684. if output_attentions:
  685. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  686. if self.config.add_cross_attention:
  687. all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
  688. if output_hidden_states:
  689. all_hidden_states = all_hidden_states + (hidden_states,)
  690. if not return_dict:
  691. return tuple(
  692. v
  693. for v in [
  694. hidden_states,
  695. past_key_values,
  696. all_hidden_states,
  697. all_self_attentions,
  698. all_cross_attentions,
  699. ]
  700. if v is not None
  701. )
  702. return BaseModelOutputWithPastAndCrossAttentions(
  703. last_hidden_state=hidden_states,
  704. past_key_values=past_key_values,
  705. hidden_states=all_hidden_states,
  706. attentions=all_self_attentions,
  707. cross_attentions=all_cross_attentions,
  708. )
  709. # Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->BridgeTowerText
  710. class BridgeTowerTextEmbeddings(nn.Module):
  711. """
  712. Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
  713. """
  714. # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__
  715. def __init__(self, config):
  716. super().__init__()
  717. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  718. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  719. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  720. # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
  721. # any TensorFlow checkpoint file
  722. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  723. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  724. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  725. self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
  726. self.register_buffer(
  727. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  728. )
  729. self.register_buffer(
  730. "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
  731. )
  732. # End copy
  733. self.padding_idx = config.pad_token_id
  734. self.position_embeddings = nn.Embedding(
  735. config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
  736. )
  737. def forward(
  738. self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
  739. ):
  740. if position_ids is None:
  741. if input_ids is not None:
  742. # Create the position ids from the input token ids. Any padded tokens remain padded.
  743. position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
  744. else:
  745. position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
  746. if input_ids is not None:
  747. input_shape = input_ids.size()
  748. else:
  749. input_shape = inputs_embeds.size()[:-1]
  750. seq_length = input_shape[1]
  751. # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
  752. # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
  753. # issue #5664
  754. if token_type_ids is None:
  755. if hasattr(self, "token_type_ids"):
  756. buffered_token_type_ids = self.token_type_ids[:, :seq_length]
  757. buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
  758. token_type_ids = buffered_token_type_ids_expanded
  759. else:
  760. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  761. if inputs_embeds is None:
  762. inputs_embeds = self.word_embeddings(input_ids)
  763. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  764. embeddings = inputs_embeds + token_type_embeddings
  765. if self.position_embedding_type == "absolute":
  766. position_embeddings = self.position_embeddings(position_ids)
  767. embeddings += position_embeddings
  768. embeddings = self.LayerNorm(embeddings)
  769. embeddings = self.dropout(embeddings)
  770. return embeddings
  771. def create_position_ids_from_inputs_embeds(self, inputs_embeds):
  772. """
  773. We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
  774. Args:
  775. inputs_embeds: torch.Tensor
  776. Returns: torch.Tensor
  777. """
  778. input_shape = inputs_embeds.size()[:-1]
  779. sequence_length = input_shape[1]
  780. position_ids = torch.arange(
  781. self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
  782. )
  783. return position_ids.unsqueeze(0).expand(input_shape)
  784. # Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids
  785. def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
  786. """
  787. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
  788. are ignored. This is modified from fairseq's `utils.make_positions`.
  789. Args:
  790. x: torch.Tensor x:
  791. Returns: torch.Tensor
  792. """
  793. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  794. mask = input_ids.ne(padding_idx).int()
  795. incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
  796. return incremental_indices.long() + padding_idx
  797. @auto_docstring
  798. class BridgeTowerPreTrainedModel(PreTrainedModel):
  799. config: BridgeTowerConfig
  800. base_model_prefix = "bridgetower"
  801. supports_gradient_checkpointing = False
  802. _no_split_modules = ["BridgeTowerSelfAttention", "BridgeTowerResidualAttention"]
  803. _skip_keys_device_placement = "past_key_values"
  804. def _init_weights(self, module: nn.Module):
  805. std = self.config.initializer_factor
  806. if isinstance(module, BridgeTowerVisionTransformer):
  807. proj_std = (self.config.hidden_size**-0.5) * ((2 * self.config.num_hidden_layers) ** -0.5)
  808. attn_std = self.config.hidden_size**-0.5
  809. fc_std = (2 * self.config.hidden_size) ** -0.5
  810. for block in module.transformer.resblocks:
  811. nn.init.normal_(block.attn.in_proj_weight, std=attn_std * std)
  812. block.attn.in_proj_bias.data.zero_()
  813. nn.init.normal_(block.attn.out_proj.weight, std=proj_std * std)
  814. nn.init.normal_(block.mlp.c_fc.weight, std=fc_std * std)
  815. nn.init.normal_(block.mlp.c_proj.weight, std=proj_std * std)
  816. nn.init.normal_(module.embeddings.class_embedding, std=attn_std * std)
  817. nn.init.normal_(module.embeddings.position_embedding.weight, std=attn_std * std)
  818. elif isinstance(module, (nn.Linear, nn.Conv2d, nn.Embedding)):
  819. module.weight.data.normal_(mean=0.0, std=0.05 * std)
  820. elif isinstance(module, nn.LayerNorm):
  821. module.bias.data.zero_()
  822. module.weight.data.fill_(1.0)
  823. elif isinstance(module, BridgeTowerForContrastiveLearning):
  824. module.logit_scale.data.fill_(self.config.logit_scale_init_value)
  825. if isinstance(module, (nn.Linear, BridgeTowerMLMHead)) and module.bias is not None:
  826. module.bias.data.zero_()
  827. class BridgeTowerVisionModel(BridgeTowerPreTrainedModel):
  828. config: BridgeTowerVisionConfig
  829. def __init__(self, config):
  830. super().__init__(config)
  831. self.visual = BridgeTowerVisionTransformer(config)
  832. @property
  833. def dtype(self):
  834. return self.visual.embeddings.patch_embedding.weight.dtype
  835. def forward(self, image, image_mask=None, interpolate_pos_encoding=False):
  836. return self.visual(image.type(self.dtype), image_mask, interpolate_pos_encoding)
  837. @auto_docstring(
  838. custom_intro="""
  839. The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
  840. cross-attention is added between the self-attention layers, following the architecture described in *Attention is
  841. all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
  842. Kaiser and Illia Polosukhin.
  843. To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
  844. to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
  845. `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
  846. .. _*Attention is all you need*: https://huggingface.co/papers/1706.03762
  847. """
  848. )
  849. class BridgeTowerTextModel(BridgeTowerPreTrainedModel):
  850. config: BridgeTowerTextConfig
  851. def __init__(self, config, add_pooling_layer=True):
  852. r"""
  853. add_pooling_layer (bool, *optional*, defaults to `True`):
  854. Whether to add a pooling layer
  855. """
  856. super().__init__(config)
  857. self.config = config
  858. self.embeddings = BridgeTowerTextEmbeddings(config)
  859. self.encoder = BridgeTowerTextEncoder(config)
  860. self.pooler = BridgeTowerPooler(config) if add_pooling_layer else None
  861. # Initialize weights and apply final processing
  862. self.post_init()
  863. def get_input_embeddings(self):
  864. return self.embeddings.word_embeddings
  865. def set_input_embeddings(self, value):
  866. self.embeddings.word_embeddings = value
  867. def _prune_heads(self, heads_to_prune):
  868. """
  869. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  870. class PreTrainedModel
  871. """
  872. for layer, heads in heads_to_prune.items():
  873. self.encoder.layer[layer].attention.prune_heads(heads)
  874. @auto_docstring
  875. def forward(
  876. self,
  877. input_ids: Optional[torch.Tensor] = None,
  878. attention_mask: Optional[torch.Tensor] = None,
  879. token_type_ids: Optional[torch.Tensor] = None,
  880. position_ids: Optional[torch.Tensor] = None,
  881. head_mask: Optional[torch.Tensor] = None,
  882. inputs_embeds: Optional[torch.Tensor] = None,
  883. encoder_hidden_states: Optional[torch.Tensor] = None,
  884. encoder_attention_mask: Optional[torch.Tensor] = None,
  885. past_key_values: Optional[Cache] = None,
  886. use_cache: Optional[bool] = None,
  887. output_attentions: Optional[bool] = None,
  888. output_hidden_states: Optional[bool] = None,
  889. return_dict: Optional[bool] = None,
  890. cache_position: Optional[torch.Tensor] = None,
  891. ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
  892. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  893. output_hidden_states = (
  894. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  895. )
  896. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  897. if self.config.is_decoder:
  898. use_cache = use_cache if use_cache is not None else self.config.use_cache
  899. else:
  900. use_cache = False
  901. if input_ids is not None and inputs_embeds is not None:
  902. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  903. elif input_ids is not None:
  904. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  905. input_shape = input_ids.size()
  906. elif inputs_embeds is not None:
  907. input_shape = inputs_embeds.size()[:-1]
  908. else:
  909. raise ValueError("You have to specify either input_ids or inputs_embeds")
  910. batch_size, seq_length = input_shape
  911. device = input_ids.device if input_ids is not None else inputs_embeds.device
  912. past_key_values_length = 0
  913. if past_key_values is not None:
  914. past_key_values_length = (
  915. past_key_values[0][0].shape[-2]
  916. if not isinstance(past_key_values, Cache)
  917. else past_key_values.get_seq_length()
  918. )
  919. if attention_mask is None:
  920. attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
  921. if token_type_ids is None:
  922. if hasattr(self.embeddings, "token_type_ids"):
  923. buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
  924. buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
  925. token_type_ids = buffered_token_type_ids_expanded
  926. else:
  927. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  928. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  929. # ourselves in which case we just need to make it broadcastable to all heads.
  930. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
  931. # If a 2D or 3D attention mask is provided for the cross-attention
  932. # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
  933. if self.config.is_decoder and encoder_hidden_states is not None:
  934. encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
  935. encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
  936. if encoder_attention_mask is None:
  937. encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
  938. encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
  939. else:
  940. encoder_extended_attention_mask = None
  941. # Prepare head mask if needed
  942. # 1.0 in head_mask indicate we keep the head
  943. # attention_probs has shape bsz x n_heads x N x N
  944. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  945. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  946. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  947. embedding_output = self.embeddings(
  948. input_ids=input_ids,
  949. position_ids=position_ids,
  950. token_type_ids=token_type_ids,
  951. inputs_embeds=inputs_embeds,
  952. past_key_values_length=past_key_values_length,
  953. )
  954. encoder_outputs = self.encoder(
  955. embedding_output,
  956. attention_mask=extended_attention_mask,
  957. head_mask=head_mask,
  958. encoder_hidden_states=encoder_hidden_states,
  959. encoder_attention_mask=encoder_extended_attention_mask,
  960. past_key_values=past_key_values,
  961. use_cache=use_cache,
  962. output_attentions=output_attentions,
  963. output_hidden_states=output_hidden_states,
  964. return_dict=return_dict,
  965. cache_position=cache_position,
  966. )
  967. sequence_output = encoder_outputs[0]
  968. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  969. if not return_dict:
  970. return (sequence_output, pooled_output) + encoder_outputs[1:]
  971. return BaseModelOutputWithPoolingAndCrossAttentions(
  972. last_hidden_state=sequence_output,
  973. pooler_output=pooled_output,
  974. past_key_values=encoder_outputs.past_key_values,
  975. hidden_states=encoder_outputs.hidden_states,
  976. attentions=encoder_outputs.attentions,
  977. cross_attentions=encoder_outputs.cross_attentions,
  978. )
  979. @auto_docstring(
  980. custom_intro="""
  981. The bare BridgeTower Model transformer outputting BridgeTowerModelOutput object without any specific head on
  982. """
  983. )
  984. class BridgeTowerModel(BridgeTowerPreTrainedModel):
  985. def __init__(self, config):
  986. super().__init__(config)
  987. self.config = config
  988. vision_config = config.vision_config
  989. text_config = config.text_config
  990. if config.share_cross_modal_transformer_layers:
  991. self.cross_modal_text_transform = nn.Linear(text_config.hidden_size, config.hidden_size)
  992. self.cross_modal_image_transform = nn.Linear(vision_config.hidden_size, config.hidden_size)
  993. else:
  994. self.cross_modal_text_transform = nn.ModuleList(
  995. [nn.Linear(text_config.hidden_size, config.hidden_size) for _ in range(config.num_hidden_layers)]
  996. )
  997. self.cross_modal_image_transform = nn.ModuleList(
  998. [nn.Linear(vision_config.hidden_size, config.hidden_size) for _ in range(config.num_hidden_layers)]
  999. )
  1000. self.token_type_embeddings = nn.Embedding(2, config.hidden_size)
  1001. self.vision_model = BridgeTowerVisionModel(vision_config)
  1002. self.text_model = BridgeTowerTextModel(text_config)
  1003. if not vision_config.share_layernorm and config.init_layernorm_from_vision_encoder:
  1004. for ln in self.vision_model.visual.cross_modal_ln_separate:
  1005. ln.weight.data = self.vision_model.visual.ln_post.weight.data
  1006. ln.bias.data = self.vision_model.visual.ln_post.bias.data
  1007. self.cross_modal_image_layers = nn.ModuleList(
  1008. [BridgeTowerBertCrossLayer(text_config, layer_idx=i) for i in range(config.num_hidden_layers)]
  1009. )
  1010. self.cross_modal_text_layers = nn.ModuleList(
  1011. [BridgeTowerBertCrossLayer(text_config, layer_idx=i) for i in range(config.num_hidden_layers)]
  1012. )
  1013. # Class token => Linear => Tanh
  1014. self.cross_modal_image_pooler = BridgeTowerPooler(config)
  1015. self.cross_modal_text_pooler = BridgeTowerPooler(config)
  1016. # Initialize BridgeTower Components
  1017. self.cross_modal_text_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  1018. self.cross_modal_image_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  1019. if config.share_link_tower_layers:
  1020. self.cross_modal_text_link_tower = BridgeTowerLinkTower(config)
  1021. self.cross_modal_image_link_tower = BridgeTowerLinkTower(config)
  1022. else:
  1023. self.cross_modal_text_link_tower = nn.ModuleList(
  1024. [BridgeTowerLinkTower(config) for _ in range(config.num_hidden_layers - 1)]
  1025. )
  1026. self.cross_modal_image_link_tower = nn.ModuleList(
  1027. [BridgeTowerLinkTower(config) for _ in range(config.num_hidden_layers - 1)]
  1028. )
  1029. self.post_init()
  1030. def get_input_embeddings(self):
  1031. return self.text_model.get_input_embeddings()
  1032. def set_input_embeddings(self, value):
  1033. self.text_model.set_input_embeddings(value)
  1034. @auto_docstring
  1035. def forward(
  1036. self,
  1037. input_ids: Optional[torch.LongTensor] = None,
  1038. attention_mask: Optional[torch.FloatTensor] = None,
  1039. token_type_ids: Optional[torch.LongTensor] = None,
  1040. pixel_values: Optional[torch.FloatTensor] = None,
  1041. pixel_mask: Optional[torch.LongTensor] = None,
  1042. head_mask: Optional[torch.FloatTensor] = None,
  1043. inputs_embeds: Optional[torch.FloatTensor] = None,
  1044. image_embeds: Optional[torch.FloatTensor] = None,
  1045. image_token_type_idx: Optional[int] = None,
  1046. output_attentions: Optional[bool] = None,
  1047. output_hidden_states: Optional[bool] = None,
  1048. return_dict: Optional[bool] = None,
  1049. labels: Optional[torch.LongTensor] = None,
  1050. interpolate_pos_encoding: bool = False,
  1051. ) -> Union[tuple[torch.Tensor], BridgeTowerModelOutput]:
  1052. r"""
  1053. image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
  1054. Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
  1055. This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
  1056. image_token_type_idx (`int`, *optional*):
  1057. - The token type ids for images.
  1058. output_hidden_states (`bool`, *optional*):
  1059. If set to `True`, hidden states are returned as a list containing the hidden states of text, image, and
  1060. cross-modal components respectively. i.e. `(hidden_states_text, hidden_states_image,
  1061. hidden_states_cross_modal)` where each element is a list of the hidden states of the corresponding
  1062. modality. `hidden_states_txt/img` are a list of tensors corresponding to unimodal hidden states and
  1063. `hidden_states_cross_modal` is a list of tuples containing `cross_modal_text_hidden_states` and
  1064. `cross_modal_image_hidden_states` of each brdige layer.
  1065. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1066. Labels are currently not supported.
  1067. Examples:
  1068. ```python
  1069. >>> from transformers import BridgeTowerProcessor, BridgeTowerModel
  1070. >>> from PIL import Image
  1071. >>> import requests
  1072. >>> # prepare image and text
  1073. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1074. >>> image = Image.open(requests.get(url, stream=True).raw)
  1075. >>> text = "hello world"
  1076. >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-base")
  1077. >>> model = BridgeTowerModel.from_pretrained("BridgeTower/bridgetower-base")
  1078. >>> inputs = processor(image, text, return_tensors="pt")
  1079. >>> outputs = model(**inputs)
  1080. >>> outputs.keys()
  1081. odict_keys(['text_features', 'image_features', 'pooler_output'])
  1082. ```"""
  1083. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1084. output_hidden_states = (
  1085. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1086. )
  1087. all_hidden_states_text = () if output_hidden_states else None
  1088. all_hidden_states_image = () if output_hidden_states else None
  1089. all_hidden_states_cross = () if output_hidden_states else None
  1090. all_hidden_states = () if output_hidden_states else None
  1091. all_self_attentions = () if output_attentions else None
  1092. if inputs_embeds is not None and input_ids is None:
  1093. raise NotImplementedError(
  1094. "BridgeTowerModel does not use `inputs_embeds`. Make sure to pass in `input_ids` instead."
  1095. )
  1096. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1097. image_token_type_idx = image_token_type_idx if image_token_type_idx else 1
  1098. input_shape = input_ids.size()
  1099. text_embeds = self.text_model.embeddings(input_ids=input_ids)
  1100. if output_hidden_states:
  1101. all_hidden_states_text += (text_embeds,)
  1102. if attention_mask is None:
  1103. attention_mask = torch.ones(input_shape, dtype=torch.long, device=input_ids.device)
  1104. extend_text_masks = self.text_model.get_extended_attention_mask(attention_mask, input_shape).to(
  1105. input_ids.device
  1106. )
  1107. # The split_index determines how many layers of the uni-modal encoder are applied before the cross-modal encoder
  1108. split_index = len(self.text_model.encoder.layer) - self.config.num_hidden_layers + 1
  1109. # Run the first 'split_index' layers of the textual encoder
  1110. for layer in self.text_model.encoder.layer[:split_index]:
  1111. text_embeds = layer(text_embeds, extend_text_masks)[0]
  1112. if output_hidden_states:
  1113. all_hidden_states_text += (text_embeds,)
  1114. if image_embeds is None:
  1115. image_embeds = self.vision_model.visual.forward_pre(
  1116. pixel_values.type(self.vision_model.dtype), interpolate_pos_encoding=interpolate_pos_encoding
  1117. )
  1118. else:
  1119. # Permute as BridgeTowerResidualAttention has batch_first=True
  1120. image_embeds = image_embeds.permute(1, 0, 2)
  1121. if output_hidden_states:
  1122. all_hidden_states_image += (image_embeds,)
  1123. # Run the first 'split_index' layers of the visual encoder
  1124. for block in self.vision_model.visual.transformer.resblocks[:split_index]:
  1125. image_embeds = block(image_embeds)
  1126. if output_hidden_states:
  1127. all_hidden_states_image += (image_embeds,)
  1128. image_embeds_with_ln = self.vision_model.visual.forward_post(image_embeds.type(self.vision_model.dtype))
  1129. # first layer is a special case because we don't have the output from the cross-encoder yet
  1130. cross_modal_text = self.cross_modal_text_transform(text_embeds)
  1131. text_token_type_embeddings = self.token_type_embeddings(
  1132. torch.zeros(1, dtype=torch.long, device=input_ids.device)
  1133. ).expand_as(cross_modal_text)
  1134. cross_modal_text = self.cross_modal_text_layernorm(cross_modal_text + text_token_type_embeddings)
  1135. image_embeds_with_ln = self.cross_modal_image_transform(image_embeds_with_ln)
  1136. image_token_type_embeddings = self.token_type_embeddings(
  1137. torch.full((1,), image_token_type_idx, dtype=torch.long, device=input_ids.device)
  1138. ).expand_as(image_embeds_with_ln)
  1139. image_embeds_with_ln = image_embeds_with_ln + image_token_type_embeddings
  1140. cross_modal_image = self.cross_modal_image_layernorm(image_embeds_with_ln)
  1141. pixel_mask = torch.ones(
  1142. (cross_modal_image.size(0), cross_modal_image.size(1)),
  1143. dtype=torch.long,
  1144. device=input_ids.device,
  1145. )
  1146. extend_image_masks = self.text_model.get_extended_attention_mask(pixel_mask, pixel_mask.size()).to(
  1147. input_ids.device
  1148. )
  1149. layer_outputs_text = self.cross_modal_text_layers[0](
  1150. cross_modal_text,
  1151. cross_modal_image,
  1152. attention_mask=extend_text_masks,
  1153. encoder_attention_mask=extend_image_masks,
  1154. output_attentions=output_attentions,
  1155. )
  1156. cross_text_features = layer_outputs_text[0]
  1157. layer_outputs_image = self.cross_modal_image_layers[0](
  1158. cross_modal_image,
  1159. cross_modal_text,
  1160. attention_mask=extend_image_masks,
  1161. encoder_attention_mask=extend_text_masks,
  1162. output_attentions=output_attentions,
  1163. )
  1164. cross_image_features = layer_outputs_image[0]
  1165. if output_hidden_states:
  1166. all_hidden_states_cross += ((cross_text_features, cross_image_features),)
  1167. if output_attentions:
  1168. all_self_attentions += ((layer_outputs_text[1], layer_outputs_image[1]),)
  1169. link_layer_index = 0
  1170. # Each of the top 6 layers of the visual and textual encoders ([split_index:]) is connected to each layer of
  1171. # the cross-modal encoder via bridge layers, which brings bottom-up alignment and fusion to the cross-modal encoder.
  1172. for i in range(split_index, len(self.text_model.encoder.layer)):
  1173. text_embeds = self.text_model.encoder.layer[i](text_embeds, extend_text_masks)[0]
  1174. image_embeds = self.vision_model.visual.transformer.resblocks[i](image_embeds).type(
  1175. self.vision_model.dtype
  1176. )
  1177. image_embeds_with_ln = (
  1178. self.cross_modal_image_transform(self.vision_model.visual.forward_post(image_embeds))
  1179. + image_token_type_embeddings
  1180. )
  1181. text_link_tower = self.cross_modal_text_link_tower[link_layer_index]
  1182. image_link_tower = self.cross_modal_image_link_tower[link_layer_index]
  1183. # Bridge layers for textual and visual encoders
  1184. cross_text_features_ = text_link_tower(
  1185. self.cross_modal_text_transform(text_embeds) + text_token_type_embeddings,
  1186. cross_text_features,
  1187. extend_text_masks,
  1188. )
  1189. cross_image_features_ = image_link_tower(image_embeds_with_ln, cross_image_features, extend_image_masks)
  1190. # Cross-modal encoder via bridge layers of textual and visual encoders
  1191. layer_outputs_text = self.cross_modal_text_layers[link_layer_index + 1](
  1192. cross_text_features_,
  1193. cross_image_features_,
  1194. attention_mask=extend_text_masks,
  1195. encoder_attention_mask=extend_image_masks,
  1196. output_attentions=output_attentions,
  1197. )
  1198. cross_text_features = layer_outputs_text[0]
  1199. layer_outputs_image = self.cross_modal_image_layers[link_layer_index + 1](
  1200. cross_image_features_,
  1201. cross_text_features_,
  1202. attention_mask=extend_image_masks,
  1203. encoder_attention_mask=extend_text_masks,
  1204. output_attentions=output_attentions,
  1205. )
  1206. cross_image_features = layer_outputs_image[0]
  1207. link_layer_index += 1
  1208. if output_hidden_states:
  1209. all_hidden_states_text += (text_embeds,)
  1210. all_hidden_states_image += (image_embeds,)
  1211. all_hidden_states_cross += ((cross_text_features, cross_image_features),)
  1212. if output_attentions:
  1213. all_self_attentions += ((layer_outputs_text[1], layer_outputs_image[1]),)
  1214. # Concatenate the cls token of the text and image features to get the final represtation
  1215. text_features, image_features = cross_text_features, cross_image_features
  1216. cls_features = self.get_cls_features(text_features, image_features)
  1217. if output_hidden_states:
  1218. all_hidden_states = (all_hidden_states_text, all_hidden_states_image, all_hidden_states_cross)
  1219. if not return_dict:
  1220. return tuple(
  1221. v
  1222. for v in [text_features, image_features, cls_features, all_hidden_states, all_self_attentions]
  1223. if v is not None
  1224. )
  1225. return BridgeTowerModelOutput(
  1226. text_features=text_features,
  1227. image_features=image_features,
  1228. pooler_output=cls_features,
  1229. hidden_states=all_hidden_states,
  1230. attentions=all_self_attentions,
  1231. )
  1232. def get_cls_features(self, text_features, image_features):
  1233. cls_features_text = self.cross_modal_text_pooler(text_features)
  1234. cls_features_image = self.cross_modal_image_pooler(image_features)
  1235. return torch.cat([cls_features_text, cls_features_image], dim=-1)
# Copied from transformers.models.vilt.modeling_vilt.ViltPredictionHeadTransform with Vilt->BridgeTower
class BridgeTowerPredictionHeadTransform(nn.Module):
    """Dense -> activation -> LayerNorm transform applied to hidden states before the MLM decoder.

    Reads `config.hidden_size`, `config.hidden_act` and `config.layer_norm_eps`; the output keeps the
    input's last dimension (`hidden_size`).
    """

    def __init__(self, config):
        super().__init__()
        # Square projection: hidden_size -> hidden_size.
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # `hidden_act` may be an activation name (looked up in ACT2FN) or already a callable.
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        # Apply projection, activation, then layer normalization, in that order.
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states
  1251. class BridgeTowerMLMHead(nn.Module):
  1252. def __init__(self, config, weight=None):
  1253. super().__init__()
  1254. self.config = config
  1255. self.transform = BridgeTowerPredictionHeadTransform(config)
  1256. self.decoder = nn.Linear(config.hidden_size, config.text_config.vocab_size, bias=False)
  1257. self.bias = nn.Parameter(torch.zeros(config.text_config.vocab_size))
  1258. if weight is not None:
  1259. self.decoder.weight = weight
  1260. def forward(self, x):
  1261. mlm_score = self.transform(x)
  1262. mlm_score = self.decoder(mlm_score) + self.bias
  1263. return mlm_score
  1264. class BridgeTowerITMHead(nn.Module):
  1265. def __init__(self, hidden_size):
  1266. super().__init__()
  1267. self.fc = nn.Linear(hidden_size, 2)
  1268. def forward(self, x):
  1269. itm_score = self.fc(x)
  1270. return itm_score
  1271. @auto_docstring(
  1272. custom_intro="""
  1273. BridgeTower Model with a language modeling head on top as done during pretraining.
  1274. """
  1275. )
  1276. class BridgeTowerForMaskedLM(BridgeTowerPreTrainedModel):
  1277. _tied_weights_keys = ["mlm_score.decoder.weight"]
  1278. def __init__(self, config):
  1279. super().__init__(config)
  1280. self.bridgetower = BridgeTowerModel(config)
  1281. self.mlm_score = BridgeTowerMLMHead(config)
  1282. # Initialize weights and apply final processing
  1283. self.post_init()
  1284. def get_output_embeddings(self):
  1285. return self.mlm_score.decoder
  1286. def set_output_embeddings(self, new_embeddings):
  1287. self.mlm_score.decoder = new_embeddings
  1288. @auto_docstring
  1289. def forward(
  1290. self,
  1291. input_ids: Optional[torch.LongTensor] = None,
  1292. attention_mask: Optional[torch.FloatTensor] = None,
  1293. token_type_ids: Optional[torch.LongTensor] = None,
  1294. pixel_values: Optional[torch.FloatTensor] = None,
  1295. pixel_mask: Optional[torch.LongTensor] = None,
  1296. head_mask: Optional[torch.FloatTensor] = None,
  1297. inputs_embeds: Optional[torch.FloatTensor] = None,
  1298. image_embeds: Optional[torch.FloatTensor] = None,
  1299. output_attentions: Optional[bool] = None,
  1300. output_hidden_states: Optional[bool] = None,
  1301. return_dict: Optional[bool] = None,
  1302. labels: Optional[torch.LongTensor] = None,
  1303. ) -> Union[MaskedLMOutput, tuple[torch.FloatTensor]]:
  1304. r"""
  1305. image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
  1306. Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
  1307. This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
  1308. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1309. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  1310. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  1311. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  1312. Examples:
  1313. ```python
  1314. >>> from transformers import BridgeTowerProcessor, BridgeTowerForMaskedLM
  1315. >>> from PIL import Image
  1316. >>> import requests
  1317. >>> url = "http://images.cocodataset.org/val2017/000000360943.jpg"
  1318. >>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
  1319. >>> text = "a <mask> looking out of the window"
  1320. >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-base-itm-mlm")
  1321. >>> model = BridgeTowerForMaskedLM.from_pretrained("BridgeTower/bridgetower-base-itm-mlm")
  1322. >>> # prepare inputs
  1323. >>> encoding = processor(image, text, return_tensors="pt")
  1324. >>> # forward pass
  1325. >>> outputs = model(**encoding)
  1326. >>> results = processor.decode(outputs.logits.argmax(dim=-1).squeeze(0).tolist())
  1327. >>> print(results)
  1328. .a cat looking out of the window.
  1329. ```"""
  1330. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1331. outputs = self.bridgetower(
  1332. input_ids,
  1333. attention_mask=attention_mask,
  1334. token_type_ids=token_type_ids,
  1335. pixel_values=pixel_values,
  1336. pixel_mask=pixel_mask,
  1337. head_mask=head_mask,
  1338. inputs_embeds=inputs_embeds,
  1339. image_embeds=image_embeds,
  1340. output_attentions=output_attentions,
  1341. output_hidden_states=output_hidden_states,
  1342. return_dict=return_dict,
  1343. )
  1344. mlm_logits = self.mlm_score(outputs.text_features if return_dict else outputs[0])
  1345. masked_lm_loss = None
  1346. if labels is not None:
  1347. loss_fct = CrossEntropyLoss() # -100 index = padding token
  1348. labels = labels.to(mlm_logits.device)
  1349. masked_lm_loss = loss_fct(mlm_logits.view(-1, self.config.text_config.vocab_size), labels.view(-1))
  1350. if not return_dict:
  1351. output = tuple(mlm_logits)
  1352. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  1353. return MaskedLMOutput(
  1354. loss=masked_lm_loss,
  1355. logits=mlm_logits,
  1356. hidden_states=outputs.hidden_states,
  1357. attentions=outputs.attentions,
  1358. )
  1359. @auto_docstring(
  1360. custom_intro="""
  1361. BridgeTower Model transformer with a classifier head on top (a linear layer on top of the final hidden state of the
  1362. [CLS] token) for image-to-text matching.
  1363. """
  1364. )
  1365. class BridgeTowerForImageAndTextRetrieval(BridgeTowerPreTrainedModel):
  1366. def __init__(self, config):
  1367. super().__init__(config)
  1368. self.bridgetower = BridgeTowerModel(config)
  1369. self.itm_score = BridgeTowerITMHead(config.hidden_size * 2)
  1370. # Initialize weights and apply final processing
  1371. self.post_init()
  1372. @auto_docstring
  1373. def forward(
  1374. self,
  1375. input_ids: Optional[torch.LongTensor] = None,
  1376. attention_mask: Optional[torch.FloatTensor] = None,
  1377. token_type_ids: Optional[torch.LongTensor] = None,
  1378. pixel_values: Optional[torch.FloatTensor] = None,
  1379. pixel_mask: Optional[torch.LongTensor] = None,
  1380. head_mask: Optional[torch.FloatTensor] = None,
  1381. inputs_embeds: Optional[torch.FloatTensor] = None,
  1382. image_embeds: Optional[torch.FloatTensor] = None,
  1383. output_attentions: Optional[bool] = None,
  1384. output_hidden_states: Optional[bool] = None,
  1385. return_dict: Optional[bool] = None,
  1386. labels: Optional[torch.LongTensor] = None,
  1387. ) -> Union[SequenceClassifierOutput, tuple[torch.FloatTensor]]:
  1388. r"""
  1389. image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
  1390. Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
  1391. This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
  1392. labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
  1393. Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they match.
  1394. The pairs with 0 will be skipped for calculation.
  1395. Examples:
  1396. ```python
  1397. >>> from transformers import BridgeTowerProcessor, BridgeTowerForImageAndTextRetrieval
  1398. >>> import requests
  1399. >>> from PIL import Image
  1400. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1401. >>> image = Image.open(requests.get(url, stream=True).raw)
  1402. >>> texts = ["An image of two cats chilling on a couch", "A football player scoring a goal"]
  1403. >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-base-itm-mlm")
  1404. >>> model = BridgeTowerForImageAndTextRetrieval.from_pretrained("BridgeTower/bridgetower-base-itm-mlm")
  1405. >>> # forward pass
  1406. >>> scores = dict()
  1407. >>> for text in texts:
  1408. ... # prepare inputs
  1409. ... encoding = processor(image, text, return_tensors="pt")
  1410. ... outputs = model(**encoding)
  1411. ... scores[text] = outputs.logits[0, 1].item()
  1412. ```"""
  1413. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1414. outputs = self.bridgetower(
  1415. input_ids,
  1416. attention_mask=attention_mask,
  1417. token_type_ids=token_type_ids,
  1418. pixel_values=pixel_values,
  1419. pixel_mask=pixel_mask,
  1420. head_mask=head_mask,
  1421. inputs_embeds=inputs_embeds,
  1422. image_embeds=image_embeds,
  1423. output_attentions=output_attentions,
  1424. output_hidden_states=output_hidden_states,
  1425. return_dict=return_dict,
  1426. )
  1427. pooler_output = outputs.pooler_output if return_dict else outputs[2]
  1428. logits = self.itm_score(pooler_output)
  1429. itm_loss = None
  1430. if labels is not None:
  1431. loss_fct = CrossEntropyLoss()
  1432. labels = labels.to(logits.device)
  1433. itm_loss = loss_fct(logits, labels)
  1434. if not return_dict:
  1435. output = tuple(logits)
  1436. return ((itm_loss,) + output) if itm_loss is not None else output
  1437. return SequenceClassifierOutput(
  1438. loss=itm_loss,
  1439. logits=logits,
  1440. hidden_states=outputs.hidden_states,
  1441. attentions=outputs.attentions,
  1442. )
  1443. class BridgeTowerContrastiveHead(nn.Module):
  1444. def __init__(self, hidden_size, embed_size):
  1445. super().__init__()
  1446. self.fc = nn.Linear(hidden_size, embed_size)
  1447. def forward(self, x):
  1448. x = self.fc(x)
  1449. return x
  1450. @auto_docstring(
  1451. custom_intro="""
  1452. BridgeTower Model with a image-text contrastive head on top computing image-text contrastive loss.
  1453. """
  1454. )
  1455. class BridgeTowerForContrastiveLearning(BridgeTowerPreTrainedModel):
  1456. def __init__(self, config):
  1457. super().__init__(config)
  1458. self.bridgetower = BridgeTowerModel(config)
  1459. self.itc_text_head = BridgeTowerContrastiveHead(config.hidden_size, config.contrastive_hidden_size)
  1460. self.itc_image_head = BridgeTowerContrastiveHead(config.hidden_size, config.contrastive_hidden_size)
  1461. self.itc_cross_modal_head = BridgeTowerContrastiveHead(config.hidden_size * 2, config.contrastive_hidden_size)
  1462. self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
  1463. # Initialize weights and apply final processing
  1464. self.post_init()
  1465. @auto_docstring
  1466. def forward(
  1467. self,
  1468. input_ids: Optional[torch.LongTensor] = None,
  1469. attention_mask: Optional[torch.FloatTensor] = None,
  1470. token_type_ids: Optional[torch.LongTensor] = None,
  1471. pixel_values: Optional[torch.FloatTensor] = None,
  1472. pixel_mask: Optional[torch.LongTensor] = None,
  1473. head_mask: Optional[torch.FloatTensor] = None,
  1474. inputs_embeds: Optional[torch.FloatTensor] = None,
  1475. image_embeds: Optional[torch.FloatTensor] = None,
  1476. output_attentions: Optional[bool] = None,
  1477. output_hidden_states: Optional[bool] = True,
  1478. return_dict: Optional[bool] = None,
  1479. return_loss: Optional[bool] = None,
  1480. ) -> Union[BridgeTowerContrastiveOutput, tuple[torch.FloatTensor]]:
  1481. r"""
  1482. image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
  1483. Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
  1484. This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
  1485. return_loss (`bool`, *optional*):
  1486. Whether or not to return the contrastive loss.
  1487. Examples:
  1488. ```python
  1489. >>> from transformers import BridgeTowerProcessor, BridgeTowerForContrastiveLearning
  1490. >>> import requests
  1491. >>> from PIL import Image
  1492. >>> import torch
  1493. >>> image_urls = [
  1494. ... "https://farm4.staticflickr.com/3395/3428278415_81c3e27f15_z.jpg",
  1495. ... "http://images.cocodataset.org/val2017/000000039769.jpg",
  1496. ... ]
  1497. >>> texts = ["two dogs in a car", "two cats sleeping on a couch"]
  1498. >>> images = [Image.open(requests.get(url, stream=True).raw) for url in image_urls]
  1499. >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
  1500. >>> model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
  1501. >>> inputs = processor(images, texts, padding=True, return_tensors="pt")
  1502. >>> loss = model(**inputs, return_loss=True).loss
  1503. >>> inputs = processor(images, texts[::-1], padding=True, return_tensors="pt")
  1504. >>> loss_swapped = model(**inputs, return_loss=True).loss
  1505. >>> print("Loss", round(loss.item(), 4))
  1506. Loss 0.0019
  1507. >>> print("Loss with swapped images", round(loss_swapped.item(), 4))
  1508. Loss with swapped images 2.126
  1509. ```"""
  1510. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1511. outputs = self.bridgetower(
  1512. input_ids,
  1513. attention_mask=attention_mask,
  1514. token_type_ids=token_type_ids,
  1515. pixel_values=pixel_values,
  1516. pixel_mask=pixel_mask,
  1517. head_mask=head_mask,
  1518. inputs_embeds=inputs_embeds,
  1519. image_embeds=image_embeds,
  1520. output_attentions=output_attentions,
  1521. output_hidden_states=True,
  1522. return_dict=return_dict,
  1523. )
  1524. pooler_output = outputs.pooler_output if return_dict else outputs[2]
  1525. hidden_states_txt, hidden_states_img, hidden_states_cross_modal = (
  1526. outputs.hidden_states if return_dict else outputs[3]
  1527. )
  1528. text_embeds = hidden_states_txt[-1]
  1529. image_embeds = hidden_states_img[-1]
  1530. image_embeds_with_ln = self.bridgetower.vision_model.visual.forward_post(image_embeds)
  1531. image_token_type_embeddings = self.bridgetower.token_type_embeddings(
  1532. torch.full((1,), 1, dtype=torch.long, device=self.bridgetower.token_type_embeddings.weight.device)
  1533. ).expand_as(image_embeds_with_ln)
  1534. image_embeds = self.bridgetower.cross_modal_image_transform(image_embeds_with_ln) + image_token_type_embeddings
  1535. # normalized features
  1536. text_embeds = nn.functional.normalize(self.itc_text_head(text_embeds[:, 0, :]), dim=-1, p=2)
  1537. image_embeds = nn.functional.normalize(self.itc_image_head(image_embeds[:, 0, :]), dim=-1, p=2).to(
  1538. device=text_embeds.device
  1539. )
  1540. cross_embeds = nn.functional.normalize(self.itc_cross_modal_head(pooler_output), dim=-1, p=2).to(
  1541. device=text_embeds.device
  1542. )
  1543. logits = torch.stack([text_embeds, image_embeds, cross_embeds], dim=-2)
  1544. logit_scale = self.logit_scale.exp().to(device=text_embeds.device)
  1545. logits_text_to_image = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
  1546. logits_text_to_cross = torch.matmul(text_embeds, cross_embeds.t()) * logit_scale
  1547. logits_image_to_cross = torch.matmul(image_embeds, cross_embeds.t()) * logit_scale
  1548. itc_loss = None
  1549. if return_loss:
  1550. labels = torch.arange(len(logits), device=logits.device)
  1551. text_to_image_loss = nn.functional.cross_entropy(logits_text_to_image, labels)
  1552. text_to_cross_loss = nn.functional.cross_entropy(logits_text_to_cross, labels)
  1553. image_to_cross_loss = nn.functional.cross_entropy(logits_image_to_cross, labels)
  1554. itc_loss = (text_to_image_loss + text_to_cross_loss + image_to_cross_loss) / 3.0
  1555. if not return_dict:
  1556. output = (logits, text_embeds, image_embeds, cross_embeds) + outputs[3:]
  1557. return ((itc_loss,) + output) if itc_loss is not None else output
  1558. return BridgeTowerContrastiveOutput(
  1559. loss=itc_loss,
  1560. logits=logits,
  1561. text_embeds=text_embeds,
  1562. image_embeds=image_embeds,
  1563. cross_embeds=cross_embeds,
  1564. hidden_states=outputs.hidden_states,
  1565. attentions=outputs.attentions,
  1566. )
# Public names exported by this module (what `from ... import *` exposes).
__all__ = [
    "BridgeTowerForContrastiveLearning",
    "BridgeTowerForImageAndTextRetrieval",
    "BridgeTowerForMaskedLM",
    "BridgeTowerModel",
    "BridgeTowerPreTrainedModel",
]