modeling_altclip.py 58 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388
  1. # coding=utf-8
  2. # Copyright 2022 The BAAI Teams Authors and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch AltCLIP model."""
  16. import math
  17. from dataclasses import dataclass
  18. from typing import Any, Callable, Optional, Union
  19. import torch
  20. import torch.nn as nn
  21. from ...activations import ACT2FN
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_outputs import (
  24. BaseModelOutput,
  25. BaseModelOutputWithPooling,
  26. BaseModelOutputWithPoolingAndCrossAttentions,
  27. BaseModelOutputWithPoolingAndProjection,
  28. )
  29. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  30. from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
  31. from ...utils import ModelOutput, auto_docstring, can_return_tuple, filter_out_non_signature_kwargs, logging, torch_int
  32. from .configuration_altclip import AltCLIPConfig, AltCLIPTextConfig, AltCLIPVisionConfig
# Module-level logger, named after this module via `__name__`.
logger = logging.get_logger(__name__)
  34. # contrastive loss function, adapted from
  35. # https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
  36. def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
  37. return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
  38. def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
  39. caption_loss = contrastive_loss(similarity)
  40. image_loss = contrastive_loss(similarity.t())
  41. return (caption_loss + image_loss) / 2.0
  42. @dataclass
  43. @auto_docstring
  44. # Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->AltCLIP
  45. class AltCLIPOutput(ModelOutput):
  46. r"""
  47. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
  48. Contrastive loss for image-text similarity.
  49. logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
  50. The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
  51. similarity scores.
  52. logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
  53. The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
  54. similarity scores.
  55. text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  56. The text embeddings obtained by applying the projection layer to the pooled output of [`AltCLIPTextModel`].
  57. image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  58. The image embeddings obtained by applying the projection layer to the pooled output of [`AltCLIPVisionModel`].
  59. text_model_output (`BaseModelOutputWithPooling`):
  60. The output of the [`AltCLIPTextModel`].
  61. vision_model_output (`BaseModelOutputWithPooling`):
  62. The output of the [`AltCLIPVisionModel`].
  63. """
  64. loss: Optional[torch.FloatTensor] = None
  65. logits_per_image: Optional[torch.FloatTensor] = None
  66. logits_per_text: Optional[torch.FloatTensor] = None
  67. text_embeds: Optional[torch.FloatTensor] = None
  68. image_embeds: Optional[torch.FloatTensor] = None
  69. text_model_output: BaseModelOutputWithPooling = None
  70. vision_model_output: BaseModelOutputWithPooling = None
  71. def to_tuple(self) -> tuple[Any]:
  72. return tuple(
  73. self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
  74. for k in self.keys()
  75. )
  76. # Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->AltRoberta
  77. class AltRobertaEmbeddings(nn.Module):
  78. """
  79. Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
  80. """
  81. # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__
  82. def __init__(self, config):
  83. super().__init__()
  84. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  85. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  86. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  87. # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
  88. # any TensorFlow checkpoint file
  89. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  90. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  91. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  92. self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
  93. self.register_buffer(
  94. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  95. )
  96. self.register_buffer(
  97. "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
  98. )
  99. # End copy
  100. self.padding_idx = config.pad_token_id
  101. self.position_embeddings = nn.Embedding(
  102. config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
  103. )
  104. def forward(
  105. self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
  106. ):
  107. if position_ids is None:
  108. if input_ids is not None:
  109. # Create the position ids from the input token ids. Any padded tokens remain padded.
  110. position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
  111. else:
  112. position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
  113. if input_ids is not None:
  114. input_shape = input_ids.size()
  115. else:
  116. input_shape = inputs_embeds.size()[:-1]
  117. seq_length = input_shape[1]
  118. # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
  119. # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
  120. # issue #5664
  121. if token_type_ids is None:
  122. if hasattr(self, "token_type_ids"):
  123. buffered_token_type_ids = self.token_type_ids[:, :seq_length]
  124. buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
  125. token_type_ids = buffered_token_type_ids_expanded
  126. else:
  127. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  128. if inputs_embeds is None:
  129. inputs_embeds = self.word_embeddings(input_ids)
  130. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  131. embeddings = inputs_embeds + token_type_embeddings
  132. if self.position_embedding_type == "absolute":
  133. position_embeddings = self.position_embeddings(position_ids)
  134. embeddings += position_embeddings
  135. embeddings = self.LayerNorm(embeddings)
  136. embeddings = self.dropout(embeddings)
  137. return embeddings
  138. def create_position_ids_from_inputs_embeds(self, inputs_embeds):
  139. """
  140. We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
  141. Args:
  142. inputs_embeds: torch.Tensor
  143. Returns: torch.Tensor
  144. """
  145. input_shape = inputs_embeds.size()[:-1]
  146. sequence_length = input_shape[1]
  147. position_ids = torch.arange(
  148. self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
  149. )
  150. return position_ids.unsqueeze(0).expand(input_shape)
  151. class AltRobertaSelfAttention(nn.Module):
  152. def __init__(self, config, position_embedding_type=None):
  153. super().__init__()
  154. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  155. raise ValueError(
  156. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  157. f"heads ({config.num_attention_heads})"
  158. )
  159. self.num_attention_heads = config.num_attention_heads
  160. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  161. self.all_head_size = self.num_attention_heads * self.attention_head_size
  162. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  163. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  164. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  165. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  166. self.position_embedding_type = position_embedding_type or getattr(
  167. config, "position_embedding_type", "absolute"
  168. )
  169. if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
  170. self.max_position_embeddings = config.max_position_embeddings
  171. self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
  172. def forward(
  173. self,
  174. hidden_states: torch.Tensor,
  175. attention_mask: Optional[torch.FloatTensor] = None,
  176. head_mask: Optional[torch.FloatTensor] = None,
  177. output_attentions: Optional[bool] = False,
  178. ) -> tuple[torch.Tensor]:
  179. input_shape = hidden_states.shape[:-1]
  180. hidden_shape = (*input_shape, -1, self.attention_head_size)
  181. query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  182. key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
  183. value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
  184. # Take the dot product between "query" and "key" to get the raw attention scores.
  185. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  186. if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
  187. query_length, key_length = query_layer.shape[2], key_layer.shape[2]
  188. position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
  189. position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
  190. distance = position_ids_l - position_ids_r
  191. positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
  192. positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
  193. if self.position_embedding_type == "relative_key":
  194. relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
  195. attention_scores = attention_scores + relative_position_scores
  196. elif self.position_embedding_type == "relative_key_query":
  197. relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
  198. relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
  199. attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
  200. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  201. if attention_mask is not None:
  202. # Apply the attention mask is (precomputed for all layers in AltRobertaModel forward() function)
  203. attention_scores = attention_scores + attention_mask
  204. # Normalize the attention scores to probabilities.
  205. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  206. # This is actually dropping out entire tokens to attend to, which might
  207. # seem a bit unusual, but is taken from the original Transformer paper.
  208. attention_probs = self.dropout(attention_probs)
  209. # Mask heads if we want to
  210. if head_mask is not None:
  211. attention_probs = attention_probs * head_mask
  212. context_layer = torch.matmul(attention_probs, value_layer)
  213. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  214. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  215. context_layer = context_layer.view(new_context_layer_shape)
  216. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  217. return outputs
  218. # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfOutput
  219. class AltRobertaSelfOutput(nn.Module):
  220. def __init__(self, config):
  221. super().__init__()
  222. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  223. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  224. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  225. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  226. hidden_states = self.dense(hidden_states)
  227. hidden_states = self.dropout(hidden_states)
  228. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  229. return hidden_states
  230. ALT_ROBERTA_SELF_ATTENTION_CLASSES = {
  231. "eager": AltRobertaSelfAttention,
  232. }
  233. class AltRobertaAttention(nn.Module):
  234. def __init__(self, config, position_embedding_type=None):
  235. super().__init__()
  236. self.self = ALT_ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation](
  237. config, position_embedding_type=position_embedding_type
  238. )
  239. self.output = AltRobertaSelfOutput(config)
  240. self.pruned_heads = set()
  241. def prune_heads(self, heads):
  242. if len(heads) == 0:
  243. return
  244. heads, index = find_pruneable_heads_and_indices(
  245. heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
  246. )
  247. # Prune linear layers
  248. self.self.query = prune_linear_layer(self.self.query, index)
  249. self.self.key = prune_linear_layer(self.self.key, index)
  250. self.self.value = prune_linear_layer(self.self.value, index)
  251. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  252. # Update hyper params and store pruned heads
  253. self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
  254. self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
  255. self.pruned_heads = self.pruned_heads.union(heads)
  256. def forward(
  257. self,
  258. hidden_states: torch.Tensor,
  259. attention_mask: Optional[torch.FloatTensor] = None,
  260. head_mask: Optional[torch.FloatTensor] = None,
  261. output_attentions: Optional[bool] = False,
  262. ) -> tuple[torch.Tensor]:
  263. self_outputs = self.self(
  264. hidden_states,
  265. attention_mask=attention_mask,
  266. head_mask=head_mask,
  267. output_attentions=output_attentions,
  268. )
  269. attention_output = self.output(self_outputs[0], hidden_states)
  270. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  271. return outputs
  272. # Copied from transformers.models.roberta.modeling_roberta.RobertaIntermediate with Roberta->AltRoberta
  273. class AltRobertaIntermediate(nn.Module):
  274. def __init__(self, config):
  275. super().__init__()
  276. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  277. if isinstance(config.hidden_act, str):
  278. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  279. else:
  280. self.intermediate_act_fn = config.hidden_act
  281. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  282. hidden_states = self.dense(hidden_states)
  283. hidden_states = self.intermediate_act_fn(hidden_states)
  284. return hidden_states
  285. # Copied from transformers.models.roberta.modeling_roberta.RobertaOutput
  286. class AltRobertaOutput(nn.Module):
  287. def __init__(self, config):
  288. super().__init__()
  289. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  290. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  291. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  292. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  293. hidden_states = self.dense(hidden_states)
  294. hidden_states = self.dropout(hidden_states)
  295. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  296. return hidden_states
  297. # Copied from transformers.models.align.modeling_align.AlignTextLayer with AlignText->AltRoberta
  298. class AltRobertaLayer(GradientCheckpointingLayer):
  299. def __init__(self, config):
  300. super().__init__()
  301. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  302. self.seq_len_dim = 1
  303. self.attention = AltRobertaAttention(config)
  304. self.intermediate = AltRobertaIntermediate(config)
  305. self.output = AltRobertaOutput(config)
  306. def forward(
  307. self,
  308. hidden_states: torch.Tensor,
  309. attention_mask: Optional[torch.FloatTensor] = None,
  310. head_mask: Optional[torch.FloatTensor] = None,
  311. output_attentions: Optional[bool] = False,
  312. **kwargs,
  313. ) -> tuple[torch.Tensor]:
  314. self_attention_outputs = self.attention(
  315. hidden_states,
  316. attention_mask=attention_mask,
  317. head_mask=head_mask,
  318. output_attentions=output_attentions,
  319. **kwargs,
  320. )
  321. attention_output = self_attention_outputs[0]
  322. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  323. layer_output = apply_chunking_to_forward(
  324. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  325. )
  326. outputs = (layer_output,) + outputs
  327. return outputs
  328. def feed_forward_chunk(self, attention_output):
  329. intermediate_output = self.intermediate(attention_output)
  330. layer_output = self.output(intermediate_output, attention_output)
  331. return layer_output
  332. # Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->AltRoberta
  333. class AltRobertaEncoder(nn.Module):
  334. def __init__(self, config):
  335. super().__init__()
  336. self.config = config
  337. self.layer = nn.ModuleList([AltRobertaLayer(config) for i in range(config.num_hidden_layers)])
  338. self.gradient_checkpointing = False
  339. @can_return_tuple
  340. def forward(
  341. self,
  342. hidden_states: torch.Tensor,
  343. attention_mask: Optional[torch.FloatTensor] = None,
  344. head_mask: Optional[torch.FloatTensor] = None,
  345. output_attentions: Optional[bool] = False,
  346. output_hidden_states: Optional[bool] = False,
  347. return_dict: Optional[bool] = True,
  348. **kwargs,
  349. ) -> Union[tuple[torch.Tensor], BaseModelOutput]:
  350. all_hidden_states = () if output_hidden_states else None
  351. all_self_attentions = () if output_attentions else None
  352. for i, layer_module in enumerate(self.layer):
  353. if output_hidden_states:
  354. all_hidden_states = all_hidden_states + (hidden_states,)
  355. layer_head_mask = head_mask[i] if head_mask is not None else None
  356. layer_outputs = layer_module(
  357. hidden_states=hidden_states,
  358. attention_mask=attention_mask,
  359. head_mask=layer_head_mask,
  360. output_attentions=output_attentions,
  361. **kwargs,
  362. )
  363. hidden_states = layer_outputs[0]
  364. if output_attentions:
  365. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  366. if output_hidden_states:
  367. all_hidden_states = all_hidden_states + (hidden_states,)
  368. return BaseModelOutput(
  369. last_hidden_state=hidden_states,
  370. hidden_states=all_hidden_states,
  371. attentions=all_self_attentions,
  372. )
  373. # Copied from transformers.models.roberta.modeling_roberta.RobertaPooler
  374. class AltRobertaPooler(nn.Module):
  375. def __init__(self, config):
  376. super().__init__()
  377. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  378. self.activation = nn.Tanh()
  379. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  380. # We "pool" the model by simply taking the hidden state corresponding
  381. # to the first token.
  382. first_token_tensor = hidden_states[:, 0]
  383. pooled_output = self.dense(first_token_tensor)
  384. pooled_output = self.activation(pooled_output)
  385. return pooled_output
  386. # Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
  387. def eager_attention_forward(
  388. module: nn.Module,
  389. query: torch.Tensor,
  390. key: torch.Tensor,
  391. value: torch.Tensor,
  392. attention_mask: Optional[torch.Tensor],
  393. scaling: float,
  394. dropout: float = 0.0,
  395. **kwargs,
  396. ):
  397. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  398. if attention_mask is not None:
  399. attn_weights = attn_weights + attention_mask
  400. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  401. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  402. attn_output = torch.matmul(attn_weights, value)
  403. attn_output = attn_output.transpose(1, 2).contiguous()
  404. return attn_output, attn_weights
  405. class AltCLIPAttention(nn.Module):
  406. """Multi-headed attention from 'Attention Is All You Need' paper"""
  407. def __init__(self, config):
  408. super().__init__()
  409. self.config = config
  410. self.embed_dim = config.hidden_size
  411. self.num_heads = config.num_attention_heads
  412. self.head_dim = self.embed_dim // self.num_heads
  413. if self.head_dim * self.num_heads != self.embed_dim:
  414. raise ValueError(
  415. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  416. f" {self.num_heads})."
  417. )
  418. self.scale = self.head_dim**-0.5
  419. self.dropout = config.attention_dropout
  420. self.is_causal = False
  421. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  422. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  423. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  424. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  425. def forward(
  426. self,
  427. hidden_states: torch.Tensor,
  428. attention_mask: Optional[torch.Tensor] = None,
  429. causal_attention_mask: Optional[torch.Tensor] = None,
  430. output_attentions: Optional[bool] = False,
  431. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  432. """Input shape: Batch x Time x Channel"""
  433. batch_size, seq_length, embed_dim = hidden_states.shape
  434. queries = self.q_proj(hidden_states)
  435. keys = self.k_proj(hidden_states)
  436. values = self.v_proj(hidden_states)
  437. queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  438. keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  439. values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  440. # CLIP text model uses both `causal_attention_mask` and `attention_mask`
  441. # in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
  442. if self.config._attn_implementation != "flash_attention_2":
  443. if attention_mask is not None and causal_attention_mask is not None:
  444. attention_mask = attention_mask + causal_attention_mask
  445. elif causal_attention_mask is not None:
  446. attention_mask = causal_attention_mask
  447. else:
  448. self.is_causal = causal_attention_mask is not None
  449. attention_interface: Callable = eager_attention_forward
  450. if self.config._attn_implementation != "eager":
  451. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  452. attn_output, attn_weights = attention_interface(
  453. self,
  454. queries,
  455. keys,
  456. values,
  457. attention_mask,
  458. is_causal=self.is_causal,
  459. scaling=self.scale,
  460. dropout=0.0 if not self.training else self.dropout,
  461. )
  462. attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
  463. attn_output = self.out_proj(attn_output)
  464. if not output_attentions:
  465. attn_weights = None
  466. return attn_output, attn_weights
  467. # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->AltCLIP
  468. class AltCLIPMLP(nn.Module):
  469. def __init__(self, config):
  470. super().__init__()
  471. self.config = config
  472. self.activation_fn = ACT2FN[config.hidden_act]
  473. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  474. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  475. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  476. hidden_states = self.fc1(hidden_states)
  477. hidden_states = self.activation_fn(hidden_states)
  478. hidden_states = self.fc2(hidden_states)
  479. return hidden_states
  480. class AltCLIPEncoderLayer(GradientCheckpointingLayer):
  481. def __init__(self, config: AltCLIPConfig):
  482. super().__init__()
  483. self.embed_dim = config.hidden_size
  484. self.self_attn = AltCLIPAttention(config)
  485. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  486. self.mlp = AltCLIPMLP(config)
  487. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  488. def forward(
  489. self,
  490. hidden_states: torch.Tensor,
  491. attention_mask: torch.Tensor,
  492. causal_attention_mask: torch.Tensor,
  493. output_attentions: Optional[bool] = False,
  494. ) -> tuple[torch.FloatTensor]:
  495. """
  496. Args:
  497. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  498. attention_mask (`torch.FloatTensor`): attention mask of size
  499. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  500. `(config.encoder_attention_heads,)`.
  501. output_attentions (`bool`, *optional*):
  502. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  503. returned tensors for more detail.
  504. """
  505. residual = hidden_states
  506. hidden_states = self.layer_norm1(hidden_states)
  507. hidden_states, attn_weights = self.self_attn(
  508. hidden_states=hidden_states,
  509. attention_mask=attention_mask,
  510. causal_attention_mask=causal_attention_mask,
  511. output_attentions=output_attentions,
  512. )
  513. hidden_states = residual + hidden_states
  514. residual = hidden_states
  515. hidden_states = self.layer_norm2(hidden_states)
  516. hidden_states = self.mlp(hidden_states)
  517. hidden_states = residual + hidden_states
  518. outputs = (hidden_states,)
  519. if output_attentions:
  520. outputs += (attn_weights,)
  521. return outputs
  522. class AltCLIPEncoder(nn.Module):
  523. """
  524. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  525. [`AltCLIPEncoderLayer`].
  526. Args:
  527. config: AltCLIPConfig
  528. """
  529. def __init__(self, config: AltCLIPConfig):
  530. super().__init__()
  531. self.config = config
  532. self.layers = nn.ModuleList([AltCLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  533. self.gradient_checkpointing = False
  534. @can_return_tuple
  535. def forward(
  536. self,
  537. inputs_embeds,
  538. attention_mask: Optional[torch.Tensor] = None,
  539. causal_attention_mask: Optional[torch.Tensor] = None,
  540. output_attentions: Optional[bool] = None,
  541. output_hidden_states: Optional[bool] = None,
  542. return_dict: Optional[bool] = None,
  543. ) -> Union[tuple, BaseModelOutput]:
  544. r"""
  545. Args:
  546. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  547. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  548. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  549. than the model's internal embedding lookup matrix.
  550. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  551. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  552. - 1 for tokens that are **not masked**,
  553. - 0 for tokens that are **masked**.
  554. [What are attention masks?](../glossary#attention-mask)
  555. causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  556. Causal mask for the text model. Mask values selected in `[0, 1]`:
  557. - 1 for tokens that are **not masked**,
  558. - 0 for tokens that are **masked**.
  559. [What are attention masks?](../glossary#attention-mask)
  560. output_attentions (`bool`, *optional*):
  561. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  562. returned tensors for more detail.
  563. output_hidden_states (`bool`, *optional*):
  564. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  565. for more detail.
  566. return_dict (`bool`, *optional*):
  567. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  568. """
  569. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  570. output_hidden_states = (
  571. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  572. )
  573. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  574. encoder_states = () if output_hidden_states else None
  575. all_attentions = () if output_attentions else None
  576. hidden_states = inputs_embeds
  577. for idx, encoder_layer in enumerate(self.layers):
  578. if output_hidden_states:
  579. encoder_states = encoder_states + (hidden_states,)
  580. layer_outputs = encoder_layer(
  581. hidden_states,
  582. attention_mask,
  583. causal_attention_mask,
  584. output_attentions=output_attentions,
  585. )
  586. hidden_states = layer_outputs[0]
  587. if output_attentions:
  588. all_attentions = all_attentions + (layer_outputs[1],)
  589. if output_hidden_states:
  590. encoder_states = encoder_states + (hidden_states,)
  591. return BaseModelOutput(
  592. last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
  593. )
  594. # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->AltCLIP
  595. class AltCLIPVisionEmbeddings(nn.Module):
  596. def __init__(self, config: AltCLIPVisionConfig):
  597. super().__init__()
  598. self.config = config
  599. self.embed_dim = config.hidden_size
  600. self.image_size = config.image_size
  601. self.patch_size = config.patch_size
  602. self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
  603. self.patch_embedding = nn.Conv2d(
  604. in_channels=config.num_channels,
  605. out_channels=self.embed_dim,
  606. kernel_size=self.patch_size,
  607. stride=self.patch_size,
  608. bias=False,
  609. )
  610. self.num_patches = (self.image_size // self.patch_size) ** 2
  611. self.num_positions = self.num_patches + 1
  612. self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
  613. self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
  614. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  615. """
  616. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  617. images. This method is also adapted to support torch.jit tracing.
  618. Adapted from:
  619. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  620. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  621. """
  622. num_patches = embeddings.shape[1] - 1
  623. position_embedding = self.position_embedding.weight.unsqueeze(0)
  624. num_positions = position_embedding.shape[1] - 1
  625. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  626. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  627. return self.position_embedding(self.position_ids)
  628. class_pos_embed = position_embedding[:, :1]
  629. patch_pos_embed = position_embedding[:, 1:]
  630. dim = embeddings.shape[-1]
  631. new_height = height // self.patch_size
  632. new_width = width // self.patch_size
  633. sqrt_num_positions = torch_int(num_positions**0.5)
  634. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  635. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  636. patch_pos_embed = nn.functional.interpolate(
  637. patch_pos_embed,
  638. size=(new_height, new_width),
  639. mode="bicubic",
  640. align_corners=False,
  641. )
  642. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  643. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  644. def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
  645. batch_size, _, height, width = pixel_values.shape
  646. if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
  647. raise ValueError(
  648. f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
  649. )
  650. target_dtype = self.patch_embedding.weight.dtype
  651. patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
  652. patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
  653. class_embeds = self.class_embedding.expand(batch_size, 1, -1)
  654. embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
  655. if interpolate_pos_encoding:
  656. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  657. else:
  658. embeddings = embeddings + self.position_embedding(self.position_ids)
  659. return embeddings
  660. @auto_docstring
  661. class AltCLIPPreTrainedModel(PreTrainedModel):
  662. config: AltCLIPConfig
  663. base_model_prefix = "altclip"
  664. supports_gradient_checkpointing = True
  665. _no_split_module = []
  666. def _init_weights(self, module):
  667. """Initialize the weights"""
  668. factor = self.config.initializer_factor
  669. if isinstance(module, AltCLIPVisionEmbeddings):
  670. factor = self.config.initializer_factor
  671. nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
  672. nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
  673. nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
  674. elif isinstance(module, AltCLIPAttention):
  675. factor = self.config.initializer_factor
  676. in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
  677. out_proj_std = (module.embed_dim**-0.5) * factor
  678. nn.init.normal_(module.q_proj.weight, std=in_proj_std)
  679. nn.init.normal_(module.k_proj.weight, std=in_proj_std)
  680. nn.init.normal_(module.v_proj.weight, std=in_proj_std)
  681. nn.init.normal_(module.out_proj.weight, std=out_proj_std)
  682. elif isinstance(module, AltCLIPMLP):
  683. factor = self.config.initializer_factor
  684. in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
  685. fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
  686. nn.init.normal_(module.fc1.weight, std=fc_std)
  687. nn.init.normal_(module.fc2.weight, std=in_proj_std)
  688. elif isinstance(module, AltCLIPModel):
  689. nn.init.normal_(
  690. module.text_projection.weight,
  691. std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
  692. )
  693. module.text_projection._is_hf_initialized = True
  694. nn.init.normal_(
  695. module.visual_projection.weight,
  696. std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
  697. )
  698. module.visual_projection._is_hf_initialized = True
  699. elif isinstance(module, nn.LayerNorm):
  700. module.bias.data.zero_()
  701. module.weight.data.fill_(1.0)
  702. elif isinstance(module, nn.Linear):
  703. module.weight.data.normal_(mean=0.0, std=self.config.initializer_factor)
  704. if module.bias is not None:
  705. module.bias.data.zero_()
  706. elif isinstance(module, nn.Embedding):
  707. module.weight.data.normal_(mean=0.0, std=self.config.initializer_factor)
  708. if module.padding_idx is not None:
  709. module.weight.data[module.padding_idx].zero_()
  710. class AltCLIPVisionTransformer(nn.Module):
  711. def __init__(self, config: AltCLIPVisionConfig):
  712. super().__init__()
  713. self.config = config
  714. embed_dim = config.hidden_size
  715. self.embeddings = AltCLIPVisionEmbeddings(config)
  716. self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  717. self.encoder = AltCLIPEncoder(config)
  718. self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  719. @can_return_tuple
  720. @auto_docstring
  721. def forward(
  722. self,
  723. pixel_values: Optional[torch.FloatTensor] = None,
  724. output_attentions: Optional[bool] = None,
  725. output_hidden_states: Optional[bool] = None,
  726. return_dict: Optional[bool] = None,
  727. interpolate_pos_encoding: Optional[bool] = False,
  728. ) -> Union[tuple, BaseModelOutputWithPooling]:
  729. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  730. output_hidden_states = (
  731. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  732. )
  733. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  734. if pixel_values is None:
  735. raise ValueError("You have to specify pixel_values")
  736. hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  737. hidden_states = self.pre_layrnorm(hidden_states)
  738. encoder_outputs = self.encoder(
  739. inputs_embeds=hidden_states,
  740. output_attentions=output_attentions,
  741. output_hidden_states=output_hidden_states,
  742. return_dict=True,
  743. )
  744. last_hidden_state = encoder_outputs[0]
  745. pooled_output = last_hidden_state[:, 0, :]
  746. pooled_output = self.post_layernorm(pooled_output)
  747. return BaseModelOutputWithPooling(
  748. last_hidden_state=last_hidden_state,
  749. pooler_output=pooled_output,
  750. hidden_states=encoder_outputs.hidden_states,
  751. attentions=encoder_outputs.attentions,
  752. )
  753. class AltCLIPVisionModel(AltCLIPPreTrainedModel):
  754. config: AltCLIPVisionConfig
  755. main_input_name = "pixel_values"
  756. def __init__(self, config: AltCLIPVisionConfig):
  757. super().__init__(config)
  758. self.vision_model = AltCLIPVisionTransformer(config)
  759. # Initialize weights and apply final processing
  760. self.post_init()
  761. def get_input_embeddings(self) -> nn.Module:
  762. return self.vision_model.embeddings.patch_embedding
  763. @auto_docstring
  764. def forward(
  765. self,
  766. pixel_values: Optional[torch.FloatTensor] = None,
  767. output_attentions: Optional[bool] = None,
  768. output_hidden_states: Optional[bool] = None,
  769. interpolate_pos_encoding: bool = False,
  770. return_dict: Optional[bool] = None,
  771. ) -> Union[tuple, BaseModelOutputWithPooling]:
  772. r"""
  773. Examples:
  774. ```python
  775. >>> from PIL import Image
  776. >>> import requests
  777. >>> from transformers import AutoProcessor, AltCLIPVisionModel
  778. >>> model = AltCLIPVisionModel.from_pretrained("BAAI/AltCLIP")
  779. >>> processor = AutoProcessor.from_pretrained("BAAI/AltCLIP")
  780. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  781. >>> image = Image.open(requests.get(url, stream=True).raw)
  782. >>> inputs = processor(images=image, return_tensors="pt")
  783. >>> outputs = model(**inputs)
  784. >>> last_hidden_state = outputs.last_hidden_state
  785. >>> pooled_output = outputs.pooler_output # pooled CLS states
  786. ```"""
  787. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  788. return self.vision_model(
  789. pixel_values=pixel_values,
  790. output_attentions=output_attentions,
  791. output_hidden_states=output_hidden_states,
  792. interpolate_pos_encoding=interpolate_pos_encoding,
  793. return_dict=return_dict,
  794. )
  795. @auto_docstring(
  796. custom_intro="""
  797. The model behaves as an encoder following the architecture described in *Attention is
  798. all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
  799. Kaiser and Illia Polosukhin.
  800. .. _*Attention is all you need*: https://huggingface.co/papers/1706.03762
  801. """
  802. )
  803. class AltRobertaModel(AltCLIPPreTrainedModel):
  804. config: AltCLIPTextConfig
  805. # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->AltRoberta
  806. def __init__(self, config, add_pooling_layer=True):
  807. r"""
  808. add_pooling_layer (bool, *optional*, defaults to `True`):
  809. Whether to add a pooling layer
  810. """
  811. super().__init__(config)
  812. self.config = config
  813. self.embeddings = AltRobertaEmbeddings(config)
  814. self.encoder = AltRobertaEncoder(config)
  815. self.pooler = AltRobertaPooler(config) if add_pooling_layer else None
  816. # Initialize weights and apply final processing
  817. self.post_init()
  818. def get_input_embeddings(self):
  819. return self.embeddings.word_embeddings
  820. def set_input_embeddings(self, value):
  821. self.embeddings.word_embeddings = value
  822. def _prune_heads(self, heads_to_prune):
  823. """
  824. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  825. class PreTrainedModel
  826. """
  827. for layer, heads in heads_to_prune.items():
  828. self.encoder.layer[layer].attention.prune_heads(heads)
  829. @auto_docstring
  830. # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
  831. def forward(
  832. self,
  833. input_ids: Optional[torch.Tensor] = None,
  834. attention_mask: Optional[torch.Tensor] = None,
  835. token_type_ids: Optional[torch.Tensor] = None,
  836. position_ids: Optional[torch.Tensor] = None,
  837. head_mask: Optional[torch.Tensor] = None,
  838. inputs_embeds: Optional[torch.Tensor] = None,
  839. output_attentions: Optional[bool] = None,
  840. output_hidden_states: Optional[bool] = None,
  841. return_dict: Optional[bool] = None,
  842. ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
  843. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  844. output_hidden_states = (
  845. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  846. )
  847. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  848. if input_ids is not None and inputs_embeds is not None:
  849. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  850. elif input_ids is not None:
  851. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  852. input_shape = input_ids.size()
  853. elif inputs_embeds is not None:
  854. input_shape = inputs_embeds.size()[:-1]
  855. else:
  856. raise ValueError("You have to specify either input_ids or inputs_embeds")
  857. batch_size, seq_length = input_shape
  858. device = input_ids.device if input_ids is not None else inputs_embeds.device
  859. if attention_mask is None:
  860. attention_mask = torch.ones(((batch_size, seq_length)), device=device)
  861. if token_type_ids is None:
  862. if hasattr(self.embeddings, "token_type_ids"):
  863. buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
  864. buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
  865. token_type_ids = buffered_token_type_ids_expanded
  866. else:
  867. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  868. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  869. # ourselves in which case we just need to make it broadcastable to all heads.
  870. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
  871. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  872. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  873. embedding_output = self.embeddings(
  874. input_ids=input_ids,
  875. position_ids=position_ids,
  876. token_type_ids=token_type_ids,
  877. inputs_embeds=inputs_embeds,
  878. )
  879. encoder_outputs = self.encoder(
  880. embedding_output,
  881. attention_mask=extended_attention_mask,
  882. head_mask=head_mask,
  883. output_attentions=output_attentions,
  884. output_hidden_states=output_hidden_states,
  885. return_dict=True,
  886. )
  887. sequence_output = encoder_outputs[0]
  888. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  889. return BaseModelOutputWithPooling(
  890. last_hidden_state=sequence_output,
  891. pooler_output=pooled_output,
  892. hidden_states=encoder_outputs.hidden_states,
  893. attentions=encoder_outputs.attentions,
  894. )
  895. class AltCLIPTextModel(AltCLIPPreTrainedModel):
  896. config: AltCLIPTextConfig
  897. def __init__(self, config):
  898. super().__init__(config)
  899. self.roberta = AltRobertaModel(config, add_pooling_layer=False)
  900. self.transformation = nn.Linear(config.hidden_size, config.project_dim)
  901. self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  902. self.post_init()
  903. def get_input_embeddings(self) -> nn.Module:
  904. return self.roberta.embeddings.word_embeddings
  905. def set_input_embeddings(self, value: nn.Embedding) -> None:
  906. self.roberta.embeddings.word_embeddings = value
  907. def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
  908. return super().resize_token_embeddings(new_num_tokens)
  909. @can_return_tuple
  910. @auto_docstring
  911. def forward(
  912. self,
  913. input_ids: Optional[torch.Tensor] = None,
  914. attention_mask: Optional[torch.Tensor] = None,
  915. token_type_ids: Optional[torch.Tensor] = None,
  916. position_ids: Optional[torch.Tensor] = None,
  917. head_mask: Optional[torch.Tensor] = None,
  918. inputs_embeds: Optional[torch.Tensor] = None,
  919. output_attentions: Optional[bool] = None,
  920. return_dict: Optional[bool] = None,
  921. output_hidden_states: Optional[bool] = None,
  922. ) -> Union[tuple, BaseModelOutputWithPoolingAndProjection]:
  923. r"""
  924. Examples:
  925. ```python
  926. >>> from transformers import AutoProcessor, AltCLIPTextModel
  927. >>> model = AltCLIPTextModel.from_pretrained("BAAI/AltCLIP")
  928. >>> processor = AutoProcessor.from_pretrained("BAAI/AltCLIP")
  929. >>> texts = ["it's a cat", "it's a dog"]
  930. >>> inputs = processor(text=texts, padding=True, return_tensors="pt")
  931. >>> outputs = model(**inputs)
  932. >>> last_hidden_state = outputs.last_hidden_state
  933. >>> pooled_output = outputs.pooler_output # pooled CLS states
  934. ```"""
  935. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  936. outputs = self.roberta(
  937. input_ids=input_ids,
  938. attention_mask=attention_mask,
  939. token_type_ids=token_type_ids,
  940. position_ids=position_ids,
  941. head_mask=head_mask,
  942. inputs_embeds=inputs_embeds,
  943. output_attentions=output_attentions,
  944. output_hidden_states=output_hidden_states,
  945. return_dict=True,
  946. )
  947. # last module outputs
  948. sequence_output = outputs[0]
  949. # project every module
  950. sequence_output = self.pre_LN(sequence_output)
  951. # pooler
  952. projection_state = self.transformation(sequence_output)
  953. pooler_output = projection_state[:, 0]
  954. return BaseModelOutputWithPoolingAndProjection(
  955. last_hidden_state=projection_state,
  956. pooler_output=pooler_output,
  957. hidden_states=outputs.hidden_states,
  958. attentions=outputs.attentions,
  959. )
  960. class AltCLIPModel(AltCLIPPreTrainedModel):
  961. config: AltCLIPConfig
  962. def __init__(self, config: AltCLIPConfig):
  963. super().__init__(config)
  964. if not isinstance(config.vision_config, AltCLIPVisionConfig):
  965. raise TypeError(
  966. "config.vision_config is expected to be of type AltCLIPVisionConfig but is of type"
  967. f" {type(config.vision_config)}."
  968. )
  969. if not isinstance(config.text_config, AltCLIPTextConfig):
  970. raise TypeError(
  971. "config.text_config is expected to be of type AltCLIPTextConfig but is of type"
  972. f" {type(config.text_config)}."
  973. )
  974. text_config = config.text_config
  975. vision_config = config.vision_config
  976. # The module using it is not a PreTrainedModel subclass so we need this
  977. vision_config._attn_implementation = config._attn_implementation
  978. self.projection_dim = config.projection_dim
  979. self.text_embed_dim = text_config.project_dim
  980. self.vision_embed_dim = vision_config.hidden_size
  981. self.text_model = AltCLIPTextModel(text_config)
  982. self.vision_model = AltCLIPVisionTransformer(vision_config)
  983. self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
  984. self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
  985. self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
  986. # Initialize weights and apply final processing
  987. self.post_init()
  988. @filter_out_non_signature_kwargs()
  989. @auto_docstring
  990. def get_text_features(
  991. self,
  992. input_ids: torch.Tensor,
  993. attention_mask: Optional[torch.Tensor] = None,
  994. position_ids: Optional[torch.Tensor] = None,
  995. token_type_ids: Optional[torch.Tensor] = None,
  996. ) -> torch.FloatTensor:
  997. r"""
  998. Returns:
  999. text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
  1000. applying the projection layer to the pooled output of [`AltCLIPTextModel`].
  1001. Examples:
  1002. ```python
  1003. >>> import torch
  1004. >>> from transformers import AutoProcessor, AltCLIPModel
  1005. >>> model = AltCLIPModel.from_pretrained("BAAI/AltCLIP")
  1006. >>> processor = AutoProcessor.from_pretrained("BAAI/AltCLIP")
  1007. >>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
  1008. >>> with torch.inference_mode():
  1009. ... text_features = model.get_text_features(**inputs)
  1010. ```"""
  1011. text_outputs = self.text_model(
  1012. input_ids=input_ids,
  1013. attention_mask=attention_mask,
  1014. position_ids=position_ids,
  1015. token_type_ids=token_type_ids,
  1016. )
  1017. pooled_output = text_outputs.pooler_output
  1018. text_features = self.text_projection(pooled_output)
  1019. return text_features
  1020. @filter_out_non_signature_kwargs()
  1021. @auto_docstring
  1022. def get_image_features(
  1023. self,
  1024. pixel_values: torch.FloatTensor,
  1025. interpolate_pos_encoding: bool = False,
  1026. ) -> torch.FloatTensor:
  1027. r"""
  1028. Returns:
  1029. image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
  1030. applying the projection layer to the pooled output of [`AltCLIPVisionModel`].
  1031. Examples:
  1032. ```python
  1033. >>> import torch
  1034. >>> from transformers import AutoProcessor, AltCLIPModel
  1035. >>> from transformers.image_utils import load_image
  1036. >>> model = AltCLIPModel.from_pretrained("BAAI/AltCLIP")
  1037. >>> processor = AutoProcessor.from_pretrained("BAAI/AltCLIP")
  1038. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1039. >>> image = load_image(url)
  1040. >>> inputs = processor(images=image, return_tensors="pt")
  1041. >>> with torch.inference_mode():
  1042. ... image_features = model.get_image_features(**inputs)
  1043. ```"""
  1044. vision_outputs = self.vision_model(
  1045. pixel_values=pixel_values,
  1046. interpolate_pos_encoding=interpolate_pos_encoding,
  1047. )
  1048. pooled_output = vision_outputs.pooler_output
  1049. image_features = self.visual_projection(pooled_output)
  1050. return image_features
  1051. @auto_docstring
  1052. def forward(
  1053. self,
  1054. input_ids: Optional[torch.LongTensor] = None,
  1055. pixel_values: Optional[torch.FloatTensor] = None,
  1056. attention_mask: Optional[torch.Tensor] = None,
  1057. position_ids: Optional[torch.LongTensor] = None,
  1058. token_type_ids: Optional[torch.Tensor] = None,
  1059. return_loss: Optional[bool] = None,
  1060. output_attentions: Optional[bool] = None,
  1061. output_hidden_states: Optional[bool] = None,
  1062. interpolate_pos_encoding: bool = False,
  1063. return_dict: Optional[bool] = None,
  1064. ) -> Union[tuple, AltCLIPOutput]:
  1065. r"""
  1066. return_loss (`bool`, *optional*):
  1067. Whether or not to return the contrastive loss.
  1068. Examples:
  1069. ```python
  1070. >>> from PIL import Image
  1071. >>> import requests
  1072. >>> from transformers import AutoProcessor, AltCLIPModel
  1073. >>> model = AltCLIPModel.from_pretrained("BAAI/AltCLIP")
  1074. >>> processor = AutoProcessor.from_pretrained("BAAI/AltCLIP")
  1075. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1076. >>> image = Image.open(requests.get(url, stream=True).raw)
  1077. >>> inputs = processor(
  1078. ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
  1079. ... )
  1080. >>> outputs = model(**inputs)
  1081. >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
  1082. >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
  1083. ```"""
  1084. # Use AltCLIP model's config for some fields (if specified) instead of those of vision & text components.
  1085. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1086. output_hidden_states = (
  1087. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1088. )
  1089. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1090. text_outputs = self.text_model(
  1091. input_ids=input_ids,
  1092. attention_mask=attention_mask,
  1093. token_type_ids=token_type_ids,
  1094. position_ids=position_ids,
  1095. output_attentions=output_attentions,
  1096. output_hidden_states=output_hidden_states,
  1097. return_dict=return_dict,
  1098. )
  1099. vision_outputs = self.vision_model(
  1100. pixel_values=pixel_values,
  1101. output_attentions=output_attentions,
  1102. output_hidden_states=output_hidden_states,
  1103. interpolate_pos_encoding=interpolate_pos_encoding,
  1104. return_dict=return_dict,
  1105. )
  1106. image_embeds = vision_outputs[1]
  1107. image_embeds = self.visual_projection(image_embeds)
  1108. text_embeds = text_outputs[1]
  1109. text_embeds = self.text_projection(text_embeds)
  1110. # normalized features
  1111. image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
  1112. text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
  1113. # cosine similarity as logits
  1114. logit_scale = self.logit_scale.exp()
  1115. logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
  1116. logits_per_image = logits_per_text.T
  1117. loss = None
  1118. if return_loss:
  1119. loss = clip_loss(logits_per_text)
  1120. if not return_dict:
  1121. output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
  1122. return ((loss,) + output) if loss is not None else output
  1123. return AltCLIPOutput(
  1124. loss=loss,
  1125. logits_per_image=logits_per_image,
  1126. logits_per_text=logits_per_text,
  1127. text_embeds=text_embeds,
  1128. image_embeds=image_embeds,
  1129. text_model_output=text_outputs,
  1130. vision_model_output=vision_outputs,
  1131. )
  1132. # Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids
  1133. def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
  1134. """
  1135. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
  1136. are ignored. This is modified from fairseq's `utils.make_positions`.
  1137. Args:
  1138. x: torch.Tensor x:
  1139. Returns: torch.Tensor
  1140. """
  1141. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  1142. mask = input_ids.ne(padding_idx).int()
  1143. incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
  1144. return incremental_indices.long() + padding_idx
  1145. __all__ = ["AltCLIPPreTrainedModel", "AltCLIPVisionModel", "AltCLIPTextModel", "AltCLIPModel"]