modeling_git.py 61 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460
  1. # coding=utf-8
  2. # Copyright 2022 Microsoft Research and The HuggingFace Inc. team.
  3. # All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """PyTorch GIT model."""
  17. import math
  18. from dataclasses import dataclass
  19. from typing import Callable, Optional, Union
  20. import torch
  21. from torch import nn
  22. from ...activations import ACT2FN
  23. from ...cache_utils import Cache, DynamicCache
  24. from ...generation import GenerationMixin
  25. from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
  26. from ...modeling_layers import GradientCheckpointingLayer
  27. from ...modeling_outputs import (
  28. BaseModelOutput,
  29. BaseModelOutputWithPast,
  30. BaseModelOutputWithPooling,
  31. CausalLMOutputWithPast,
  32. )
  33. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  34. from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
  35. from ...utils import (
  36. ModelOutput,
  37. auto_docstring,
  38. can_return_tuple,
  39. logging,
  40. torch_int,
  41. )
  42. from ...utils.deprecation import deprecate_kwarg
  43. from .configuration_git import GitConfig, GitVisionConfig
# Module-level logger; used in this file for one-time warnings (missing `layer_idx` in GitSelfAttention,
# `use_cache=True` combined with gradient checkpointing in GitEncoder).
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
    """
)
# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Git
class GitVisionModelOutput(ModelOutput):
    r"""
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
        The image embeddings obtained by applying the projection layer to the pooler_output.
    """

    # NOTE(review): the `with_projection=True` wording above is inherited from CLIP; nothing in this file shows a
    # GIT `with_projection` option — confirm before relying on that description.
    # All fields default to None so producers can fill only what they computed. The declaration order below is
    # presumably the order ModelOutput uses for tuple conversion — verify in ModelOutput before reordering.
    image_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    # Per-layer tensors, populated only when the producing model is asked for them.
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  61. class GitEmbeddings(nn.Module):
  62. """Construct the embeddings from word and position embeddings."""
  63. def __init__(self, config):
  64. super().__init__()
  65. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  66. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  67. # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
  68. # any TensorFlow checkpoint file
  69. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  70. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  71. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  72. self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
  73. self.register_buffer(
  74. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  75. )
  76. def forward(
  77. self,
  78. input_ids: Optional[torch.LongTensor] = None,
  79. position_ids: Optional[torch.LongTensor] = None,
  80. inputs_embeds: Optional[torch.FloatTensor] = None,
  81. past_key_values_length: int = 0,
  82. ) -> torch.Tensor:
  83. if input_ids is not None:
  84. input_shape = input_ids.size()
  85. else:
  86. input_shape = inputs_embeds.size()[:-1]
  87. seq_length = input_shape[1]
  88. if position_ids is None:
  89. position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
  90. if inputs_embeds is None:
  91. embeddings = self.word_embeddings(input_ids)
  92. else:
  93. embeddings = inputs_embeds
  94. if self.position_embedding_type == "absolute":
  95. position_embeddings = self.position_embeddings(position_ids)
  96. embeddings += position_embeddings
  97. embeddings = self.LayerNorm(embeddings)
  98. embeddings = self.dropout(embeddings)
  99. return embeddings
class GitSelfAttention(nn.Module):
    """Eager multi-head self-attention for the GIT text layers.

    When ``pixel_values_present`` is true, the first ``image_patch_tokens`` positions of the sequence are
    treated as image tokens. Only the remaining (text) keys/values are handed to the cache. The image
    keys/values are recomputed from ``hidden_states`` on every call and prepended to the cached text
    keys/values.
    """

    def __init__(self, config, position_embedding_type=None, layer_idx=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        # Index of this layer inside the cache; required whenever a cache is passed to forward().
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        # Number of image tokens preceding the text: (image_size / patch_size) ** 2 patches plus one extra token
        # (presumably the vision class token, cf. GitVisionEmbeddings.num_positions), multiplied by
        # `num_image_with_embedding` when that is set.
        self.image_patch_tokens = int((config.vision_config.image_size / config.vision_config.patch_size) ** 2 + 1)
        if config.num_image_with_embedding is not None:
            self.image_patch_tokens *= config.num_image_with_embedding

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            # One row per signed distance in [-(max_position_embeddings - 1), max_position_embeddings - 1].
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        pixel_values_present: Optional[bool] = False,
    ) -> tuple[torch.Tensor]:
        """Compute self-attention over ``hidden_states``.

        Args:
            hidden_states: unpacked as ``(batch_size, seq_length, hidden_size)``.
            attention_mask: added to the raw (pre-softmax) attention scores.
            head_mask: multiplied into the attention probabilities after dropout.
            past_key_values: if given, its ``update`` is called with the text-part keys/values for
                ``self.layer_idx``, and the returned keys/values are used for attention.
            output_attentions: accepted for interface compatibility but not read here; the attention
                probabilities are always returned.
            pixel_values_present: when true, the first ``self.image_patch_tokens`` positions are image tokens
                and are excluded from the cache update.

        Returns:
            ``(context_layer, attention_probs)``. ``context_layer`` has shape
            ``(batch_size, query_length, all_head_size)``.
        """
        batch_size, seq_length, _ = hidden_states.shape
        query_layer = (
            self.query(hidden_states)
            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )

        # Position along the sequence where text tokens begin (image tokens, if any, come first).
        cutoff = self.image_patch_tokens if pixel_values_present else 0
        key_layer = (
            self.key(hidden_states)
            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )
        value_layer = (
            self.value(hidden_states)
            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )
        if past_key_values is not None:
            # NOTE: like in other caches, we store the text component. In GIT it means we discard the image component.
            key_layer_past, value_layer_past = past_key_values.update(
                key_layer[:, :, cutoff:, :], value_layer[:, :, cutoff:, :], self.layer_idx
            )
            # Prepend this call's image keys/values to the (cached + new) text keys/values.
            key_layer = torch.cat([key_layer[:, :, :cutoff, :], key_layer_past], dim=2)
            value_layer = torch.cat([value_layer[:, :, :cutoff, :], value_layer_past], dim=2)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            # NOTE(review): with a cache, only one query position (key_length - 1) is used, which assumes a single
            # new query token. The cache can be non-None on multi-token calls (GitEncoder creates a DynamicCache
            # when use_cache is set and none was passed), and image tokens are prepended to the keys. Confirm the
            # relative distances are correct in those cases.
            if past_key_values is not None:
                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
                    -1, 1
                )
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r

            # Offset by max_position_embeddings - 1 so signed distances map to rows of distance_embedding.
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in GitModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        # Merge heads: (batch, heads, query_len, head_size) -> (batch, query_len, heads * head_size).
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        return context_layer, attention_probs
  203. # Copied from transformers.models.bert.modeling_bert.BertSelfOutput
  204. class GitSelfOutput(nn.Module):
  205. def __init__(self, config):
  206. super().__init__()
  207. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  208. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  209. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  210. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  211. hidden_states = self.dense(hidden_states)
  212. hidden_states = self.dropout(hidden_states)
  213. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  214. return hidden_states
  215. GIT_SELF_ATTENTION_CLASSES = {
  216. "eager": GitSelfAttention,
  217. }
  218. class GitAttention(nn.Module):
  219. def __init__(self, config, position_embedding_type=None, layer_idx=None):
  220. super().__init__()
  221. self.self = GIT_SELF_ATTENTION_CLASSES[config._attn_implementation](
  222. config, position_embedding_type=position_embedding_type, layer_idx=layer_idx
  223. )
  224. self.output = GitSelfOutput(config)
  225. self.pruned_heads = set()
  226. # Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads
  227. def prune_heads(self, heads):
  228. if len(heads) == 0:
  229. return
  230. heads, index = find_pruneable_heads_and_indices(
  231. heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
  232. )
  233. # Prune linear layers
  234. self.self.query = prune_linear_layer(self.self.query, index)
  235. self.self.key = prune_linear_layer(self.self.key, index)
  236. self.self.value = prune_linear_layer(self.self.value, index)
  237. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  238. # Update hyper params and store pruned heads
  239. self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
  240. self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
  241. self.pruned_heads = self.pruned_heads.union(heads)
  242. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  243. def forward(
  244. self,
  245. hidden_states: torch.Tensor,
  246. attention_mask: Optional[torch.FloatTensor] = None,
  247. head_mask: Optional[torch.FloatTensor] = None,
  248. past_key_values: Optional[Cache] = None,
  249. output_attentions: Optional[bool] = False,
  250. pixel_values_present: Optional[bool] = False,
  251. ) -> tuple[torch.Tensor]:
  252. attn_output, self_attn_weights = self.self(
  253. hidden_states,
  254. attention_mask,
  255. head_mask,
  256. past_key_values,
  257. output_attentions,
  258. pixel_values_present,
  259. )
  260. attention_output = self.output(attn_output, hidden_states)
  261. return attention_output, self_attn_weights
  262. # Copied from transformers.models.bert.modeling_bert.BertIntermediate
  263. class GitIntermediate(nn.Module):
  264. def __init__(self, config):
  265. super().__init__()
  266. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  267. if isinstance(config.hidden_act, str):
  268. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  269. else:
  270. self.intermediate_act_fn = config.hidden_act
  271. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  272. hidden_states = self.dense(hidden_states)
  273. hidden_states = self.intermediate_act_fn(hidden_states)
  274. return hidden_states
  275. # Copied from transformers.models.bert.modeling_bert.BertOutput
  276. class GitOutput(nn.Module):
  277. def __init__(self, config):
  278. super().__init__()
  279. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  280. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  281. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  282. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  283. hidden_states = self.dense(hidden_states)
  284. hidden_states = self.dropout(hidden_states)
  285. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  286. return hidden_states
  287. class GitLayer(GradientCheckpointingLayer):
  288. def __init__(self, config, layer_idx=None):
  289. super().__init__()
  290. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  291. self.seq_len_dim = 1
  292. self.attention = GitAttention(config, layer_idx=layer_idx)
  293. self.intermediate = GitIntermediate(config)
  294. self.output = GitOutput(config)
  295. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  296. def forward(
  297. self,
  298. hidden_states: torch.Tensor,
  299. attention_mask: Optional[torch.FloatTensor] = None,
  300. head_mask: Optional[torch.FloatTensor] = None,
  301. past_key_values: Optional[Cache] = None,
  302. output_attentions: Optional[bool] = False,
  303. pixel_values_present: Optional[bool] = False,
  304. ) -> tuple[torch.Tensor]:
  305. # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
  306. attention_output, self_attention_weights = self.attention(
  307. hidden_states,
  308. attention_mask,
  309. head_mask,
  310. output_attentions=output_attentions,
  311. past_key_values=past_key_values,
  312. pixel_values_present=pixel_values_present,
  313. )
  314. layer_output = apply_chunking_to_forward(
  315. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  316. )
  317. return layer_output, self_attention_weights
  318. def feed_forward_chunk(self, attention_output):
  319. intermediate_output = self.intermediate(attention_output)
  320. layer_output = self.output(intermediate_output, attention_output)
  321. return layer_output
  322. class GitEncoder(nn.Module):
  323. def __init__(self, config):
  324. super().__init__()
  325. self.config = config
  326. self.layer = nn.ModuleList([GitLayer(config, i) for i in range(config.num_hidden_layers)])
  327. self.gradient_checkpointing = False
  328. def forward(
  329. self,
  330. hidden_states: torch.Tensor,
  331. attention_mask: Optional[torch.FloatTensor] = None,
  332. head_mask: Optional[torch.FloatTensor] = None,
  333. past_key_values: Optional[Union[Cache, tuple[tuple[torch.FloatTensor]]]] = None,
  334. use_cache: Optional[bool] = None,
  335. output_attentions: Optional[bool] = False,
  336. output_hidden_states: Optional[bool] = False,
  337. pixel_values_present: Optional[bool] = False,
  338. return_dict: Optional[bool] = True,
  339. ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPast]:
  340. if self.gradient_checkpointing and self.training:
  341. if use_cache:
  342. logger.warning_once(
  343. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  344. )
  345. use_cache = False
  346. if use_cache and past_key_values is None:
  347. past_key_values = DynamicCache(config=self.config)
  348. all_hidden_states = () if output_hidden_states else None
  349. all_self_attentions = () if output_attentions else None
  350. for i, layer_module in enumerate(self.layer):
  351. if output_hidden_states:
  352. all_hidden_states = all_hidden_states + (hidden_states,)
  353. layer_head_mask = head_mask[i] if head_mask is not None else None
  354. layer_outputs = layer_module(
  355. hidden_states,
  356. attention_mask,
  357. layer_head_mask,
  358. past_key_values,
  359. output_attentions,
  360. pixel_values_present,
  361. )
  362. hidden_states = layer_outputs[0]
  363. if output_attentions:
  364. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  365. if output_hidden_states:
  366. all_hidden_states = all_hidden_states + (hidden_states,)
  367. if not return_dict:
  368. return tuple(
  369. v
  370. for v in [
  371. hidden_states,
  372. past_key_values,
  373. all_hidden_states,
  374. all_self_attentions,
  375. ]
  376. if v is not None
  377. )
  378. return BaseModelOutputWithPast(
  379. last_hidden_state=hidden_states,
  380. past_key_values=past_key_values,
  381. hidden_states=all_hidden_states,
  382. attentions=all_self_attentions,
  383. )
  384. @auto_docstring
  385. class GitPreTrainedModel(PreTrainedModel):
  386. config: GitConfig
  387. base_model_prefix = "git"
  388. supports_gradient_checkpointing = True
  389. def _init_weights(self, module):
  390. """Initialize the weights"""
  391. if isinstance(module, GitVisionEmbeddings):
  392. nn.init.normal_(module.class_embedding, mean=0.0, std=self.config.initializer_range)
  393. nn.init.normal_(module.patch_embedding.weight, std=self.config.initializer_range)
  394. nn.init.normal_(module.position_embedding.weight, std=self.config.initializer_range)
  395. if isinstance(module, nn.Linear):
  396. # Slightly different from the TF version which uses truncated_normal for initialization
  397. # cf https://github.com/pytorch/pytorch/pull/5617
  398. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  399. if module.bias is not None:
  400. module.bias.data.zero_()
  401. elif isinstance(module, nn.Embedding):
  402. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  403. if module.padding_idx is not None:
  404. module.weight.data[module.padding_idx].zero_()
  405. elif isinstance(module, nn.LayerNorm):
  406. module.bias.data.zero_()
  407. module.weight.data.fill_(1.0)
  408. # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Git
  409. class GitVisionEmbeddings(nn.Module):
  410. def __init__(self, config: GitVisionConfig):
  411. super().__init__()
  412. self.config = config
  413. self.embed_dim = config.hidden_size
  414. self.image_size = config.image_size
  415. self.patch_size = config.patch_size
  416. self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
  417. self.patch_embedding = nn.Conv2d(
  418. in_channels=config.num_channels,
  419. out_channels=self.embed_dim,
  420. kernel_size=self.patch_size,
  421. stride=self.patch_size,
  422. bias=False,
  423. )
  424. self.num_patches = (self.image_size // self.patch_size) ** 2
  425. self.num_positions = self.num_patches + 1
  426. self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
  427. self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
  428. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  429. """
  430. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  431. images. This method is also adapted to support torch.jit tracing.
  432. Adapted from:
  433. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  434. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  435. """
  436. num_patches = embeddings.shape[1] - 1
  437. position_embedding = self.position_embedding.weight.unsqueeze(0)
  438. num_positions = position_embedding.shape[1] - 1
  439. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  440. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  441. return self.position_embedding(self.position_ids)
  442. class_pos_embed = position_embedding[:, :1]
  443. patch_pos_embed = position_embedding[:, 1:]
  444. dim = embeddings.shape[-1]
  445. new_height = height // self.patch_size
  446. new_width = width // self.patch_size
  447. sqrt_num_positions = torch_int(num_positions**0.5)
  448. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  449. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  450. patch_pos_embed = nn.functional.interpolate(
  451. patch_pos_embed,
  452. size=(new_height, new_width),
  453. mode="bicubic",
  454. align_corners=False,
  455. )
  456. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  457. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  458. def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
  459. batch_size, _, height, width = pixel_values.shape
  460. if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
  461. raise ValueError(
  462. f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
  463. )
  464. target_dtype = self.patch_embedding.weight.dtype
  465. patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
  466. patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
  467. class_embeds = self.class_embedding.expand(batch_size, 1, -1)
  468. embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
  469. if interpolate_pos_encoding:
  470. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  471. else:
  472. embeddings = embeddings + self.position_embedding(self.position_ids)
  473. return embeddings
  474. class GitVisionMLP(nn.Module):
  475. def __init__(self, config):
  476. super().__init__()
  477. self.config = config
  478. self.activation_fn = ACT2FN[config.hidden_act]
  479. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  480. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  481. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  482. hidden_states = self.fc1(hidden_states)
  483. hidden_states = self.activation_fn(hidden_states)
  484. hidden_states = self.fc2(hidden_states)
  485. return hidden_states
  486. # Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
  487. def eager_attention_forward(
  488. module: nn.Module,
  489. query: torch.Tensor,
  490. key: torch.Tensor,
  491. value: torch.Tensor,
  492. attention_mask: Optional[torch.Tensor],
  493. scaling: float,
  494. dropout: float = 0.0,
  495. **kwargs,
  496. ):
  497. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  498. if attention_mask is not None:
  499. attn_weights = attn_weights + attention_mask
  500. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  501. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  502. attn_output = torch.matmul(attn_weights, value)
  503. attn_output = attn_output.transpose(1, 2).contiguous()
  504. return attn_output, attn_weights
  505. class GitVisionAttention(nn.Module):
  506. """Multi-headed attention from 'Attention Is All You Need' paper"""
  507. def __init__(self, config):
  508. super().__init__()
  509. self.config = config
  510. self.embed_dim = config.hidden_size
  511. self.num_heads = config.num_attention_heads
  512. self.head_dim = self.embed_dim // self.num_heads
  513. if self.head_dim * self.num_heads != self.embed_dim:
  514. raise ValueError(
  515. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  516. f" {self.num_heads})."
  517. )
  518. self.scale = self.head_dim**-0.5
  519. self.dropout = config.attention_dropout
  520. self.is_causal = False
  521. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  522. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  523. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  524. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  525. def forward(
  526. self,
  527. hidden_states: torch.Tensor,
  528. attention_mask: Optional[torch.Tensor] = None,
  529. causal_attention_mask: Optional[torch.Tensor] = None,
  530. output_attentions: Optional[bool] = False,
  531. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  532. """Input shape: Batch x Time x Channel"""
  533. batch_size, seq_length, embed_dim = hidden_states.shape
  534. queries = self.q_proj(hidden_states)
  535. keys = self.k_proj(hidden_states)
  536. values = self.v_proj(hidden_states)
  537. queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  538. keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  539. values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  540. # CLIP text model uses both `causal_attention_mask` and `attention_mask`
  541. # in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
  542. if self.config._attn_implementation != "flash_attention_2":
  543. if attention_mask is not None and causal_attention_mask is not None:
  544. attention_mask = attention_mask + causal_attention_mask
  545. elif causal_attention_mask is not None:
  546. attention_mask = causal_attention_mask
  547. else:
  548. self.is_causal = causal_attention_mask is not None
  549. attention_interface: Callable = eager_attention_forward
  550. if self.config._attn_implementation != "eager":
  551. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  552. attn_output, attn_weights = attention_interface(
  553. self,
  554. queries,
  555. keys,
  556. values,
  557. attention_mask,
  558. is_causal=self.is_causal,
  559. scaling=self.scale,
  560. dropout=0.0 if not self.training else self.dropout,
  561. )
  562. attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
  563. attn_output = self.out_proj(attn_output)
  564. if not output_attentions:
  565. attn_weights = None
  566. return attn_output, attn_weights
  567. # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->GitVision
  568. class GitVisionEncoderLayer(GradientCheckpointingLayer):
  569. def __init__(self, config: GitVisionConfig):
  570. super().__init__()
  571. self.embed_dim = config.hidden_size
  572. self.self_attn = GitVisionAttention(config)
  573. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  574. self.mlp = GitVisionMLP(config)
  575. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  576. def forward(
  577. self,
  578. hidden_states: torch.Tensor,
  579. attention_mask: torch.Tensor,
  580. causal_attention_mask: torch.Tensor,
  581. output_attentions: Optional[bool] = False,
  582. ) -> tuple[torch.FloatTensor]:
  583. """
  584. Args:
  585. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  586. attention_mask (`torch.FloatTensor`): attention mask of size
  587. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  588. `(config.encoder_attention_heads,)`.
  589. output_attentions (`bool`, *optional*):
  590. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  591. returned tensors for more detail.
  592. """
  593. residual = hidden_states
  594. hidden_states = self.layer_norm1(hidden_states)
  595. hidden_states, attn_weights = self.self_attn(
  596. hidden_states=hidden_states,
  597. attention_mask=attention_mask,
  598. causal_attention_mask=causal_attention_mask,
  599. output_attentions=output_attentions,
  600. )
  601. hidden_states = residual + hidden_states
  602. residual = hidden_states
  603. hidden_states = self.layer_norm2(hidden_states)
  604. hidden_states = self.mlp(hidden_states)
  605. hidden_states = residual + hidden_states
  606. outputs = (hidden_states,)
  607. if output_attentions:
  608. outputs += (attn_weights,)
  609. return outputs
  610. # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->GitVision, CLIPConfig
  611. class GitVisionEncoder(nn.Module):
  612. """
  613. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  614. [`GitVisionEncoderLayer`].
  615. Args:
  616. config: GitVisionConfig
  617. """
  618. def __init__(self, config: GitVisionConfig):
  619. super().__init__()
  620. self.config = config
  621. self.layers = nn.ModuleList([GitVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  622. self.gradient_checkpointing = False
  623. @can_return_tuple
  624. def forward(
  625. self,
  626. inputs_embeds,
  627. attention_mask: Optional[torch.Tensor] = None,
  628. causal_attention_mask: Optional[torch.Tensor] = None,
  629. output_attentions: Optional[bool] = None,
  630. output_hidden_states: Optional[bool] = None,
  631. return_dict: Optional[bool] = None,
  632. ) -> Union[tuple, BaseModelOutput]:
  633. r"""
  634. Args:
  635. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  636. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  637. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  638. than the model's internal embedding lookup matrix.
  639. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  640. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  641. - 1 for tokens that are **not masked**,
  642. - 0 for tokens that are **masked**.
  643. [What are attention masks?](../glossary#attention-mask)
  644. causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  645. Causal mask for the text model. Mask values selected in `[0, 1]`:
  646. - 1 for tokens that are **not masked**,
  647. - 0 for tokens that are **masked**.
  648. [What are attention masks?](../glossary#attention-mask)
  649. output_attentions (`bool`, *optional*):
  650. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  651. returned tensors for more detail.
  652. output_hidden_states (`bool`, *optional*):
  653. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  654. for more detail.
  655. return_dict (`bool`, *optional*):
  656. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  657. """
  658. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  659. output_hidden_states = (
  660. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  661. )
  662. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  663. encoder_states = () if output_hidden_states else None
  664. all_attentions = () if output_attentions else None
  665. hidden_states = inputs_embeds
  666. for idx, encoder_layer in enumerate(self.layers):
  667. if output_hidden_states:
  668. encoder_states = encoder_states + (hidden_states,)
  669. layer_outputs = encoder_layer(
  670. hidden_states,
  671. attention_mask,
  672. causal_attention_mask,
  673. output_attentions=output_attentions,
  674. )
  675. hidden_states = layer_outputs[0]
  676. if output_attentions:
  677. all_attentions = all_attentions + (layer_outputs[1],)
  678. if output_hidden_states:
  679. encoder_states = encoder_states + (hidden_states,)
  680. return BaseModelOutput(
  681. last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
  682. )
  683. class GitVisionTransformer(nn.Module):
  684. # Copied from transformers.models.altclip.modeling_altclip.AltCLIPVisionTransformer.__init__ with AltCLIPEncoder->GitVisionEncoder, AltCLIP->Git
  685. def __init__(self, config: GitVisionConfig):
  686. super().__init__()
  687. self.config = config
  688. embed_dim = config.hidden_size
  689. self.embeddings = GitVisionEmbeddings(config)
  690. self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  691. self.encoder = GitVisionEncoder(config)
  692. self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  693. @auto_docstring
  694. def forward(
  695. self,
  696. pixel_values: Optional[torch.FloatTensor] = None,
  697. output_attentions: Optional[bool] = None,
  698. output_hidden_states: Optional[bool] = None,
  699. interpolate_pos_encoding: Optional[bool] = False,
  700. return_dict: Optional[bool] = None,
  701. ) -> Union[tuple, BaseModelOutput]:
  702. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  703. output_hidden_states = (
  704. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  705. )
  706. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  707. if pixel_values is None:
  708. raise ValueError("You have to specify pixel_values")
  709. hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  710. hidden_states = self.pre_layrnorm(hidden_states)
  711. encoder_outputs = self.encoder(
  712. inputs_embeds=hidden_states,
  713. output_attentions=output_attentions,
  714. output_hidden_states=output_hidden_states,
  715. return_dict=return_dict,
  716. )
  717. last_hidden_state = encoder_outputs[0]
  718. last_hidden_state = self.post_layernorm(last_hidden_state)
  719. if not return_dict:
  720. return (last_hidden_state,) + encoder_outputs[1:]
  721. return BaseModelOutput(
  722. last_hidden_state=last_hidden_state,
  723. hidden_states=encoder_outputs.hidden_states,
  724. attentions=encoder_outputs.attentions,
  725. )
  726. @auto_docstring(
  727. custom_intro="""
  728. The vision model from CLIP, used in GIT, without any head or projection on top.
  729. """
  730. )
  731. class GitVisionModel(GitPreTrainedModel):
  732. config: GitVisionConfig
  733. main_input_name = "pixel_values"
  734. # Copied from transformers.models.clip.modeling_clip.CLIPVisionModel.__init__ with CLIP->Git
  735. def __init__(self, config: GitVisionConfig):
  736. super().__init__(config)
  737. self.vision_model = GitVisionTransformer(config)
  738. # Initialize weights and apply final processing
  739. self.post_init()
  740. def get_input_embeddings(self) -> nn.Module:
  741. return self.vision_model.embeddings.patch_embedding
  742. @auto_docstring
  743. def forward(
  744. self,
  745. pixel_values: Optional[torch.FloatTensor] = None,
  746. output_attentions: Optional[bool] = None,
  747. output_hidden_states: Optional[bool] = None,
  748. interpolate_pos_encoding: bool = False,
  749. return_dict: Optional[bool] = None,
  750. ) -> Union[tuple, BaseModelOutput]:
  751. r"""
  752. Examples:
  753. ```python
  754. >>> from PIL import Image
  755. >>> import requests
  756. >>> from transformers import AutoProcessor, GitVisionModel
  757. >>> processor = AutoProcessor.from_pretrained("microsoft/git-base")
  758. >>> model = GitVisionModel.from_pretrained("microsoft/git-base")
  759. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  760. >>> image = Image.open(requests.get(url, stream=True).raw)
  761. >>> inputs = processor(images=image, return_tensors="pt")
  762. >>> outputs = model(**inputs)
  763. >>> last_hidden_state = outputs.last_hidden_state
  764. ```"""
  765. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  766. return self.vision_model(
  767. pixel_values=pixel_values,
  768. output_attentions=output_attentions,
  769. output_hidden_states=output_hidden_states,
  770. interpolate_pos_encoding=interpolate_pos_encoding,
  771. return_dict=return_dict,
  772. )
  773. class GitProjection(nn.Module):
  774. def __init__(self, config: GitConfig):
  775. super().__init__()
  776. self.config = config
  777. self.visual_projection = nn.Sequential(
  778. nn.Linear(config.vision_config.hidden_size, config.hidden_size),
  779. nn.LayerNorm(config.hidden_size, eps=config.vision_config.layer_norm_eps),
  780. )
  781. def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
  782. return self.visual_projection(embeddings)
  783. @auto_docstring(
  784. custom_intro="""
  785. The bare GIT Model transformer consisting of a CLIP image encoder and text decoder outputting raw hidden-states
  786. """
  787. )
  788. class GitModel(GitPreTrainedModel):
  789. def __init__(self, config):
  790. super().__init__(config)
  791. self.config = config
  792. self.embeddings = GitEmbeddings(config)
  793. self.image_encoder = GitVisionModel(config.vision_config)
  794. self.encoder = GitEncoder(config)
  795. self.visual_projection = GitProjection(config)
  796. if config.num_image_with_embedding is not None:
  797. self.img_temporal_embedding = nn.ParameterList(
  798. nn.Parameter(torch.zeros(1, 1, config.vision_config.hidden_size))
  799. for _ in range(config.num_image_with_embedding)
  800. )
  801. # Initialize weights and apply final processing
  802. self.post_init()
  803. def get_input_embeddings(self):
  804. return self.embeddings.word_embeddings
  805. def set_input_embeddings(self, value):
  806. self.embeddings.word_embeddings = value
  807. def _prune_heads(self, heads_to_prune):
  808. """
  809. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  810. class PreTrainedModel
  811. """
  812. for layer, heads in heads_to_prune.items():
  813. self.encoder.layer[layer].attention.prune_heads(heads)
  814. def _generate_future_mask(self, size: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
  815. # Default mask is for forward direction. Flip for backward direction.
  816. mask = torch.triu(torch.ones(size, size, device=device, dtype=dtype), diagonal=1)
  817. mask = mask.masked_fill(mask == 1, float("-inf"))
  818. return mask
  819. def create_attention_mask(self, tgt, memory, tgt_mask, past_key_values_length, memory_key_padding_mask=None):
  820. num_tgt = tgt.shape[1]
  821. num_memory = memory.shape[1]
  822. device = tgt.device
  823. dtype = tgt.dtype
  824. top_left = torch.zeros((num_memory, num_memory), device=device, dtype=dtype)
  825. top_right = torch.full(
  826. (num_memory, num_tgt + past_key_values_length),
  827. float("-inf"),
  828. device=tgt.device,
  829. dtype=dtype,
  830. )
  831. bottom_left = torch.zeros(
  832. (num_tgt, num_memory),
  833. dtype=dtype,
  834. device=tgt_mask.device,
  835. )
  836. if past_key_values_length > 0:
  837. tgt_mask = torch.zeros(
  838. (tgt_mask.shape[0], tgt_mask.shape[0] + past_key_values_length),
  839. dtype=dtype,
  840. device=tgt_mask.device,
  841. )
  842. left = torch.cat((top_left, bottom_left), dim=0)
  843. right = torch.cat((top_right, tgt_mask.to(dtype)), dim=0)
  844. full_attention_mask = torch.cat((left, right), dim=1)[None, :]
  845. if memory_key_padding_mask is None:
  846. memory_key_padding_mask = torch.full((memory.shape[0], memory.shape[1]), fill_value=False, device=device)
  847. # if it is False, it means valid. That is, it is not a padding
  848. if memory_key_padding_mask.dtype != torch.bool:
  849. raise ValueError("Memory key padding mask must be a boolean tensor.")
  850. zero_negative_infinity = torch.zeros_like(memory_key_padding_mask, dtype=tgt.dtype)
  851. zero_negative_infinity[memory_key_padding_mask] = float("-inf")
  852. full_attention_mask = full_attention_mask.expand(
  853. (memory_key_padding_mask.shape[0], num_memory + num_tgt, num_memory + past_key_values_length + num_tgt)
  854. )
  855. full_attention_mask = full_attention_mask.clone()
  856. origin_left = full_attention_mask[:, :, :num_memory]
  857. update = zero_negative_infinity[:, None, :]
  858. full_attention_mask[:, :, :num_memory] = origin_left + update
  859. # add axis for multi-head
  860. full_attention_mask = full_attention_mask[:, None, :, :]
  861. return full_attention_mask
  862. @auto_docstring
  863. def forward(
  864. self,
  865. input_ids: Optional[torch.Tensor] = None,
  866. attention_mask: Optional[torch.Tensor] = None,
  867. position_ids: Optional[torch.Tensor] = None,
  868. pixel_values: Optional[torch.Tensor] = None,
  869. head_mask: Optional[torch.Tensor] = None,
  870. inputs_embeds: Optional[torch.Tensor] = None,
  871. past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
  872. use_cache: Optional[bool] = None,
  873. output_attentions: Optional[bool] = None,
  874. output_hidden_states: Optional[bool] = None,
  875. interpolate_pos_encoding: bool = False,
  876. return_dict: Optional[bool] = None,
  877. ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPooling]:
  878. r"""
  879. Examples:
  880. ```python
  881. >>> from transformers import AutoProcessor, AutoModel
  882. >>> import requests
  883. >>> from PIL import Image
  884. >>> processor = AutoProcessor.from_pretrained("microsoft/git-base")
  885. >>> model = AutoModel.from_pretrained("microsoft/git-base")
  886. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  887. >>> image = Image.open(requests.get(url, stream=True).raw)
  888. >>> text = "this is an image of two cats"
  889. >>> inputs = processor(images=image, text=text, return_tensors="pt")
  890. >>> outputs = model(**inputs)
  891. >>> last_hidden_state = outputs.last_hidden_state
  892. ```"""
  893. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  894. output_hidden_states = (
  895. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  896. )
  897. use_cache = use_cache if use_cache is not None else self.config.use_cache
  898. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  899. if input_ids is not None and inputs_embeds is not None:
  900. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  901. elif input_ids is not None:
  902. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  903. input_shape = input_ids.size()
  904. elif inputs_embeds is not None:
  905. input_shape = inputs_embeds.size()[:-1]
  906. else:
  907. raise ValueError("You have to specify either input_ids or inputs_embeds")
  908. seq_length = input_shape[1]
  909. # past_key_values_length
  910. past_key_values_length = 0
  911. if past_key_values is not None:
  912. past_key_values_length = (
  913. past_key_values.get_seq_length()
  914. if not isinstance(past_key_values, Cache)
  915. else past_key_values.get_seq_length()
  916. )
  917. # Prepare head mask if needed
  918. # 1.0 in head_mask indicate we keep the head
  919. # attention_probs has shape bsz x n_heads x N x N
  920. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  921. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  922. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  923. projected_visual_features = None
  924. if pixel_values is not None:
  925. if pixel_values.ndim == 4:
  926. # here we assume pixel_values is of shape (batch_size, num_channels, height, width)
  927. visual_features = self.image_encoder(
  928. pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
  929. ).last_hidden_state
  930. elif pixel_values.ndim == 5:
  931. # here we assume pixel_values is of shape (batch_size, num_frames, num_channels, height, width)
  932. visual_features = []
  933. for frame_idx in range(pixel_values.shape[1]):
  934. visual_features_frame = self.image_encoder(
  935. pixel_values[:, frame_idx, :, :], interpolate_pos_encoding=interpolate_pos_encoding
  936. ).last_hidden_state
  937. visual_features_frame += self.img_temporal_embedding[frame_idx]
  938. visual_features.append(visual_features_frame)
  939. # finally, concatenate all features along sequence dimension
  940. visual_features = torch.cat(visual_features, dim=1)
  941. else:
  942. raise ValueError("pixel_values must be of rank 4 or 5")
  943. projected_visual_features = self.visual_projection(visual_features)
  944. embedding_output = self.embeddings(
  945. input_ids=input_ids,
  946. position_ids=position_ids,
  947. inputs_embeds=inputs_embeds,
  948. past_key_values_length=past_key_values_length,
  949. )
  950. if projected_visual_features is None:
  951. projected_visual_features = torch.zeros(
  952. (embedding_output.shape[0], 0, embedding_output.shape[2]),
  953. dtype=embedding_output.dtype,
  954. device=embedding_output.device,
  955. )
  956. # Repeat visual features to match embedding batch size.
  957. projected_visual_features = projected_visual_features.repeat(
  958. embedding_output.size(0) // projected_visual_features.size(0), 1, 1
  959. )
  960. # concatenate patch token and text token embeddings
  961. hidden_states = torch.cat((projected_visual_features, embedding_output), dim=1)
  962. # By default, an additive causal mask is created
  963. # for masking the future (one direction).
  964. tgt_mask = self._generate_future_mask(seq_length, embedding_output.dtype, embedding_output.device)
  965. # Create an attention mask of shape (batch_size, 1, tgt_seq_len, src_seq_len)
  966. combined_attention_mask = self.create_attention_mask(
  967. tgt=embedding_output,
  968. memory=projected_visual_features,
  969. tgt_mask=tgt_mask,
  970. past_key_values_length=past_key_values_length,
  971. )
  972. if attention_mask is not None:
  973. # if the user provides an attention mask, we add it to the default one
  974. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  975. expanded_attn_mask = _prepare_4d_attention_mask(
  976. attention_mask, embedding_output.dtype, tgt_len=input_shape[-1]
  977. ).to(embedding_output.device)
  978. if past_key_values_length > 0:
  979. expanded_attn_mask = expanded_attn_mask[:, :, -past_key_values_length:, :]
  980. else:
  981. combined_attention_mask[:, :, -input_shape[1] :, -input_shape[1] :] += expanded_attn_mask
  982. encoder_outputs = self.encoder(
  983. hidden_states,
  984. attention_mask=combined_attention_mask,
  985. head_mask=head_mask,
  986. past_key_values=past_key_values,
  987. use_cache=use_cache,
  988. output_attentions=output_attentions,
  989. output_hidden_states=output_hidden_states,
  990. return_dict=return_dict,
  991. pixel_values_present=pixel_values is not None,
  992. )
  993. sequence_output = encoder_outputs[0]
  994. if not return_dict:
  995. return (sequence_output,) + encoder_outputs[1:]
  996. return BaseModelOutputWithPast(
  997. last_hidden_state=sequence_output,
  998. past_key_values=encoder_outputs.past_key_values,
  999. hidden_states=encoder_outputs.hidden_states,
  1000. attentions=encoder_outputs.attentions,
  1001. )
  1002. @auto_docstring(
  1003. custom_intro="""
  1004. GIT Model with a `language modeling` head on top for autoregressive language modeling.
  1005. """
  1006. )
  1007. class GitForCausalLM(GitPreTrainedModel, GenerationMixin):
  1008. _tied_weights_keys = ["output.weight"]
  1009. def __init__(self, config):
  1010. super().__init__(config)
  1011. self.git = GitModel(config)
  1012. self.output = nn.Linear(config.hidden_size, config.vocab_size)
  1013. # Initialize weights and apply final processing
  1014. self.post_init()
  1015. def get_output_embeddings(self):
  1016. return self.output
  1017. def set_output_embeddings(self, new_embeddings):
  1018. self.output = new_embeddings
  1019. @auto_docstring
  1020. def forward(
  1021. self,
  1022. input_ids: Optional[torch.Tensor] = None,
  1023. attention_mask: Optional[torch.Tensor] = None,
  1024. position_ids: Optional[torch.Tensor] = None,
  1025. pixel_values: Optional[torch.Tensor] = None,
  1026. head_mask: Optional[torch.Tensor] = None,
  1027. inputs_embeds: Optional[torch.Tensor] = None,
  1028. labels: Optional[torch.Tensor] = None,
  1029. past_key_values: Optional[Union[Cache, list[torch.Tensor]]] = None,
  1030. use_cache: Optional[bool] = None,
  1031. output_attentions: Optional[bool] = None,
  1032. output_hidden_states: Optional[bool] = None,
  1033. interpolate_pos_encoding: bool = False,
  1034. return_dict: Optional[bool] = None,
  1035. **kwargs,
  1036. ) -> Union[tuple[torch.Tensor], CausalLMOutputWithPast]:
  1037. r"""
  1038. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1039. Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
  1040. `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
  1041. ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`
  1042. Examples:
  1043. Image captioning example:
  1044. ```python
  1045. >>> from transformers import AutoProcessor, AutoModelForCausalLM
  1046. >>> import requests
  1047. >>> from PIL import Image
  1048. >>> processor = AutoProcessor.from_pretrained("microsoft/git-base-coco")
  1049. >>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")
  1050. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1051. >>> image = Image.open(requests.get(url, stream=True).raw)
  1052. >>> pixel_values = processor(images=image, return_tensors="pt").pixel_values
  1053. >>> generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
  1054. >>> generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
  1055. >>> print(generated_caption)
  1056. two cats sleeping on a pink blanket next to remotes.
  1057. ```
  1058. Visual question answering (VQA) example:
  1059. ```python
  1060. >>> from transformers import AutoProcessor, AutoModelForCausalLM
  1061. >>> from huggingface_hub import hf_hub_download
  1062. >>> from PIL import Image
  1063. >>> processor = AutoProcessor.from_pretrained("microsoft/git-base-textvqa")
  1064. >>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-textvqa")
  1065. >>> file_path = hf_hub_download(repo_id="nielsr/textvqa-sample", filename="bus.png", repo_type="dataset")
  1066. >>> image = Image.open(file_path).convert("RGB")
  1067. >>> pixel_values = processor(images=image, return_tensors="pt").pixel_values
  1068. >>> question = "what does the front of the bus say at the top?"
  1069. >>> input_ids = processor(text=question, add_special_tokens=False).input_ids
  1070. >>> input_ids = [processor.tokenizer.cls_token_id] + input_ids
  1071. >>> input_ids = torch.tensor(input_ids).unsqueeze(0)
  1072. >>> generated_ids = model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50)
  1073. >>> print(processor.batch_decode(generated_ids, skip_special_tokens=True))
  1074. ['what does the front of the bus say at the top? special']
  1075. ```
  1076. Video captioning example:
  1077. ```python
  1078. >>> import av
  1079. >>> import numpy as np
  1080. >>> from PIL import Image
  1081. >>> from huggingface_hub import hf_hub_download
  1082. >>> from transformers import AutoProcessor, AutoModelForCausalLM
  1083. >>> processor = AutoProcessor.from_pretrained("microsoft/git-base-vatex")
  1084. >>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-vatex")
  1085. >>> # set seed for reproducibility
  1086. >>> np.random.seed(45)
  1087. >>> def read_video_pyav(container, indices):
  1088. ... '''
  1089. ... Decode the video with PyAV decoder.
  1090. ... Args:
  1091. ... container (`av.container.input.InputContainer`): PyAV container.
  1092. ... indices (`list[int]`): List of frame indices to decode.
  1093. ... Returns:
  1094. ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
  1095. ... '''
  1096. ... frames = []
  1097. ... container.seek(0)
  1098. ... start_index = indices[0]
  1099. ... end_index = indices[-1]
  1100. ... for i, frame in enumerate(container.decode(video=0)):
  1101. ... if i > end_index:
  1102. ... break
  1103. ... if i >= start_index and i in indices:
  1104. ... frames.append(frame)
  1105. ... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
  1106. >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
  1107. ... '''
  1108. ... Sample a given number of frame indices from the video.
  1109. ... Args:
  1110. ... clip_len (`int`): Total number of frames to sample.
  1111. ... frame_sample_rate (`int`): Sample every n-th frame.
  1112. ... seg_len (`int`): Maximum allowed index of sample's last frame.
  1113. ... Returns:
  1114. ... indices (`list[int]`): List of sampled frame indices
  1115. ... '''
  1116. ... converted_len = int(clip_len * frame_sample_rate)
  1117. ... end_idx = np.random.randint(converted_len, seg_len)
  1118. ... start_idx = end_idx - converted_len
  1119. ... indices = np.linspace(start_idx, end_idx, num=clip_len)
  1120. ... indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
  1121. ... return indices
  1122. >>> # load video
  1123. >>> file_path = hf_hub_download(
  1124. ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
  1125. ... )
  1126. >>> container = av.open(file_path)
  1127. >>> # sample frames
  1128. >>> num_frames = model.config.num_image_with_embedding
  1129. >>> indices = sample_frame_indices(
  1130. ... clip_len=num_frames, frame_sample_rate=4, seg_len=container.streams.video[0].frames
  1131. ... )
  1132. >>> frames = read_video_pyav(container, indices)
  1133. >>> pixel_values = processor(images=list(frames), return_tensors="pt").pixel_values
  1134. >>> generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
  1135. >>> print("Generated caption:", processor.batch_decode(generated_ids, skip_special_tokens=True))
  1136. Generated caption: ['a woman is sitting at a table and she is talking about the food she is holding.']
  1137. ```
  1138. """
  1139. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1140. if labels is not None:
  1141. use_cache = False
  1142. outputs = self.git(
  1143. input_ids,
  1144. attention_mask=attention_mask,
  1145. position_ids=position_ids,
  1146. pixel_values=pixel_values,
  1147. head_mask=head_mask,
  1148. inputs_embeds=inputs_embeds,
  1149. past_key_values=past_key_values,
  1150. use_cache=use_cache,
  1151. output_attentions=output_attentions,
  1152. output_hidden_states=output_hidden_states,
  1153. interpolate_pos_encoding=interpolate_pos_encoding,
  1154. return_dict=return_dict,
  1155. )
  1156. sequence_output = outputs[0]
  1157. logits = self.output(sequence_output)
  1158. loss = None
  1159. if labels is not None:
  1160. # we are doing next-token prediction; shift prediction scores and input ids by one
  1161. num_image_tokens = self.git.encoder.layer[0].attention.self.image_patch_tokens
  1162. shifted_logits = logits[:, num_image_tokens:-1, :].contiguous()
  1163. labels = labels[:, 1:].contiguous()
  1164. loss = self.loss_function(
  1165. shifted_logits.view(-1, self.config.vocab_size),
  1166. labels.view(-1),
  1167. vocab_size=self.config.vocab_size,
  1168. **kwargs,
  1169. )
  1170. if not return_dict:
  1171. output = (logits,) + outputs[1:]
  1172. return ((loss,) + output) if loss is not None else output
  1173. return CausalLMOutputWithPast(
  1174. loss=loss,
  1175. logits=logits,
  1176. past_key_values=outputs.past_key_values,
  1177. hidden_states=outputs.hidden_states,
  1178. attentions=outputs.attentions,
  1179. )
  1180. def prepare_inputs_for_generation(
  1181. self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
  1182. ):
  1183. # Overwritten -- `git` has special cache handling and doesn't support generating from `inputs_embeds` atm
  1184. # cut decoder_input_ids if past_key_values is used
  1185. if past_key_values is not None:
  1186. past_length = past_key_values.get_seq_length()
  1187. # Some generation methods already pass only the last input ID
  1188. if input_ids.shape[1] > past_length:
  1189. remove_prefix_length = past_length
  1190. else:
  1191. # Default to old behavior: keep only final ID
  1192. remove_prefix_length = input_ids.shape[1] - 1
  1193. input_ids = input_ids[:, remove_prefix_length:]
  1194. # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
  1195. input_shape = input_ids.shape
  1196. if attention_mask is None:
  1197. attention_mask = input_ids.new_ones(input_shape)
  1198. model_inputs = {
  1199. "input_ids": input_ids,
  1200. "attention_mask": attention_mask,
  1201. "pixel_values": kwargs.get("pixel_values"),
  1202. "past_key_values": past_key_values,
  1203. "use_cache": use_cache,
  1204. }
  1205. # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
  1206. for key, value in kwargs.items():
  1207. if key not in model_inputs:
  1208. model_inputs[key] = value
  1209. return model_inputs
  1210. __all__ = ["GitForCausalLM", "GitModel", "GitPreTrainedModel", "GitVisionModel"]