vision.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476
  1. # coding=utf-8
  2. # Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch IdeficsVision model: a copy of CLIPVisionModel using a simpler config object"""
  16. import math
  17. from dataclasses import dataclass
  18. from typing import Callable, Optional, Union
  19. import torch
  20. from torch import nn
  21. from ...activations import ACT2FN
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
  24. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  25. from ...utils import (
  26. ModelOutput,
  27. can_return_tuple,
  28. logging,
  29. )
  30. from .configuration_idefics import IdeficsVisionConfig
  31. logger = logging.get_logger(__name__)
  32. @dataclass
  33. class IdeficsVisionModelOutput(ModelOutput):
  34. """
  35. Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
  36. Args:
  37. image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
  38. The image embeddings obtained by applying the projection layer to the pooler_output.
  39. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  40. Sequence of hidden-states at the output of the last layer of the model.
  41. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  42. Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
  43. one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
  44. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
  45. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  46. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  47. sequence_length)`.
  48. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
  49. heads.
  50. """
  51. image_embeds: Optional[torch.FloatTensor] = None
  52. last_hidden_state: Optional[torch.FloatTensor] = None
  53. hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  54. attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  55. # Adapted from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings
  56. class IdeficsVisionEmbeddings(nn.Module):
  57. def __init__(self, config: IdeficsVisionConfig):
  58. super().__init__()
  59. self.config = config
  60. self.embed_dim = config.hidden_size
  61. self.image_size = config.image_size
  62. self.patch_size = config.patch_size
  63. self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
  64. self.patch_embedding = nn.Conv2d(
  65. in_channels=config.num_channels,
  66. out_channels=self.embed_dim,
  67. kernel_size=self.patch_size,
  68. stride=self.patch_size,
  69. bias=False,
  70. )
  71. self.num_patches = (self.image_size // self.patch_size) ** 2
  72. self.num_positions = self.num_patches + 1
  73. self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
  74. self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
  75. # Heavily inspired from https://github.com/huggingface/transformers/blob/v4.33.0/src/transformers/models/vit/modeling_vit.py#L82
  76. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  77. """
  78. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
  79. resolution images.
  80. Source:
  81. https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
  82. """
  83. num_patches = embeddings.shape[1] - 1
  84. pos_embed = self.position_embedding(self.position_ids)
  85. num_positions = pos_embed.shape[1] - 1
  86. if num_patches == num_positions and height == width:
  87. return pos_embed
  88. class_pos_embed = pos_embed[:, 0]
  89. patch_pos_embed = pos_embed[:, 1:]
  90. embed_dim = embeddings.shape[-1]
  91. num_h_patches = height // self.config.patch_size
  92. num_w_patches = width // self.config.patch_size
  93. # we add a small number to avoid floating point error in the interpolation
  94. # see discussion at https://github.com/facebookresearch/dino/issues/8
  95. num_h_patches, num_w_patches = num_h_patches + 0.1, num_w_patches + 0.1
  96. sqrt_num_positions = math.sqrt(num_positions)
  97. patch_pos_embed = patch_pos_embed.reshape(1, int(sqrt_num_positions), int(sqrt_num_positions), embed_dim)
  98. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  99. fp32_upcasting = patch_pos_embed.dtype == torch.bfloat16
  100. if fp32_upcasting:
  101. logger.warning_once(
  102. "Upcasting patch_pos_embed to fp32 for interpolation since `upsample_bicubic2d_out_frame` in nn.functional.interpolate "
  103. "is not implemented for 'torch.bfloat16' dtype. This will result in a slight overhead."
  104. )
  105. patch_pos_embed = patch_pos_embed.to(torch.float)
  106. patch_pos_embed = nn.functional.interpolate(
  107. patch_pos_embed,
  108. scale_factor=(num_h_patches / sqrt_num_positions, num_w_patches / sqrt_num_positions),
  109. mode="bicubic",
  110. align_corners=False,
  111. )
  112. if fp32_upcasting:
  113. patch_pos_embed = patch_pos_embed.to(torch.bfloat16)
  114. if int(num_h_patches) != patch_pos_embed.shape[-2] or int(num_w_patches) != patch_pos_embed.shape[-1]:
  115. raise ValueError(
  116. f"Number of patches for images ({int(num_h_patches), int(num_w_patches)}) don't match the "
  117. f"shape of position embedding ({patch_pos_embed.shape[-2], patch_pos_embed.shape[-1]})"
  118. )
  119. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, embed_dim)
  120. return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
  121. def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
  122. batch_size, num_channels, height, width = pixel_values.shape
  123. if not interpolate_pos_encoding:
  124. if height != self.image_size or width != self.image_size:
  125. raise ValueError(
  126. f"Input image size ({height}*{width}) doesn't match model"
  127. f" ({self.image_size}*{self.image_size}). You should try to set `interpolate_pos_encoding=True`"
  128. )
  129. target_dtype = self.patch_embedding.weight.dtype
  130. patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
  131. patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
  132. class_embeds = self.class_embedding.expand(batch_size, 1, -1)
  133. embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
  134. # add positional encoding to each token
  135. if interpolate_pos_encoding:
  136. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  137. else:
  138. embeddings = embeddings + self.position_embedding(self.position_ids)
  139. return embeddings
  140. # Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
  141. def eager_attention_forward(
  142. module: nn.Module,
  143. query: torch.Tensor,
  144. key: torch.Tensor,
  145. value: torch.Tensor,
  146. attention_mask: Optional[torch.Tensor],
  147. scaling: float,
  148. dropout: float = 0.0,
  149. **kwargs,
  150. ):
  151. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  152. if attention_mask is not None:
  153. attn_weights = attn_weights + attention_mask
  154. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  155. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  156. attn_output = torch.matmul(attn_weights, value)
  157. attn_output = attn_output.transpose(1, 2).contiguous()
  158. return attn_output, attn_weights
  159. class IdeficsVisionAttention(nn.Module):
  160. """Multi-headed attention from 'Attention Is All You Need' paper"""
  161. def __init__(self, config: IdeficsVisionConfig):
  162. super().__init__()
  163. self.config = config
  164. self.embed_dim = config.hidden_size
  165. self.num_heads = config.num_attention_heads
  166. self.head_dim = self.embed_dim // self.num_heads
  167. if self.head_dim * self.num_heads != self.embed_dim:
  168. raise ValueError(
  169. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  170. f" {self.num_heads})."
  171. )
  172. self.scale = self.head_dim**-0.5
  173. self.dropout = config.attention_dropout
  174. self.is_causal = False
  175. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  176. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  177. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  178. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  179. def forward(
  180. self,
  181. hidden_states: torch.Tensor,
  182. attention_mask: Optional[torch.Tensor] = None,
  183. causal_attention_mask: Optional[torch.Tensor] = None,
  184. output_attentions: Optional[bool] = False,
  185. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  186. """Input shape: Batch x Time x Channel"""
  187. batch_size, seq_length, embed_dim = hidden_states.shape
  188. queries = self.q_proj(hidden_states)
  189. keys = self.k_proj(hidden_states)
  190. values = self.v_proj(hidden_states)
  191. queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  192. keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  193. values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  194. # CLIP text model uses both `causal_attention_mask` and `attention_mask`
  195. # in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
  196. if self.config._attn_implementation != "flash_attention_2":
  197. if attention_mask is not None and causal_attention_mask is not None:
  198. attention_mask = attention_mask + causal_attention_mask
  199. elif causal_attention_mask is not None:
  200. attention_mask = causal_attention_mask
  201. else:
  202. self.is_causal = causal_attention_mask is not None
  203. attention_interface: Callable = eager_attention_forward
  204. if self.config._attn_implementation != "eager":
  205. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  206. attn_output, attn_weights = attention_interface(
  207. self,
  208. queries,
  209. keys,
  210. values,
  211. attention_mask,
  212. is_causal=self.is_causal,
  213. scaling=self.scale,
  214. dropout=0.0 if not self.training else self.dropout,
  215. )
  216. attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
  217. attn_output = self.out_proj(attn_output)
  218. if not output_attentions:
  219. attn_weights = None
  220. return attn_output, attn_weights
  221. # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->IdeficsVision
  222. class IdeficsVisionMLP(nn.Module):
  223. def __init__(self, config):
  224. super().__init__()
  225. self.config = config
  226. self.activation_fn = ACT2FN[config.hidden_act]
  227. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  228. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  229. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  230. hidden_states = self.fc1(hidden_states)
  231. hidden_states = self.activation_fn(hidden_states)
  232. hidden_states = self.fc2(hidden_states)
  233. return hidden_states
  234. # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->IdeficsVision
  235. class IdeficsVisionEncoderLayer(GradientCheckpointingLayer):
  236. def __init__(self, config: IdeficsVisionConfig):
  237. super().__init__()
  238. self.embed_dim = config.hidden_size
  239. self.self_attn = IdeficsVisionAttention(config)
  240. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  241. self.mlp = IdeficsVisionMLP(config)
  242. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  243. def forward(
  244. self,
  245. hidden_states: torch.Tensor,
  246. attention_mask: torch.Tensor,
  247. causal_attention_mask: torch.Tensor,
  248. output_attentions: Optional[bool] = False,
  249. ) -> tuple[torch.FloatTensor]:
  250. """
  251. Args:
  252. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  253. attention_mask (`torch.FloatTensor`): attention mask of size
  254. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  255. `(config.encoder_attention_heads,)`.
  256. output_attentions (`bool`, *optional*):
  257. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  258. returned tensors for more detail.
  259. """
  260. residual = hidden_states
  261. hidden_states = self.layer_norm1(hidden_states)
  262. hidden_states, attn_weights = self.self_attn(
  263. hidden_states=hidden_states,
  264. attention_mask=attention_mask,
  265. causal_attention_mask=causal_attention_mask,
  266. output_attentions=output_attentions,
  267. )
  268. hidden_states = residual + hidden_states
  269. residual = hidden_states
  270. hidden_states = self.layer_norm2(hidden_states)
  271. hidden_states = self.mlp(hidden_states)
  272. hidden_states = residual + hidden_states
  273. outputs = (hidden_states,)
  274. if output_attentions:
  275. outputs += (attn_weights,)
  276. return outputs
  277. # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->IdeficsVision
  278. class IdeficsVisionEncoder(nn.Module):
  279. """
  280. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  281. [`IdeficsVisionEncoderLayer`].
  282. Args:
  283. config: IdeficsVisionConfig
  284. """
  285. def __init__(self, config: IdeficsVisionConfig):
  286. super().__init__()
  287. self.config = config
  288. self.layers = nn.ModuleList([IdeficsVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  289. self.gradient_checkpointing = False
  290. @can_return_tuple
  291. def forward(
  292. self,
  293. inputs_embeds,
  294. attention_mask: Optional[torch.Tensor] = None,
  295. causal_attention_mask: Optional[torch.Tensor] = None,
  296. output_attentions: Optional[bool] = None,
  297. output_hidden_states: Optional[bool] = None,
  298. return_dict: Optional[bool] = None,
  299. ) -> Union[tuple, BaseModelOutput]:
  300. r"""
  301. Args:
  302. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  303. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  304. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  305. than the model's internal embedding lookup matrix.
  306. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  307. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  308. - 1 for tokens that are **not masked**,
  309. - 0 for tokens that are **masked**.
  310. [What are attention masks?](../glossary#attention-mask)
  311. causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  312. Causal mask for the text model. Mask values selected in `[0, 1]`:
  313. - 1 for tokens that are **not masked**,
  314. - 0 for tokens that are **masked**.
  315. [What are attention masks?](../glossary#attention-mask)
  316. output_attentions (`bool`, *optional*):
  317. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  318. returned tensors for more detail.
  319. output_hidden_states (`bool`, *optional*):
  320. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  321. for more detail.
  322. return_dict (`bool`, *optional*):
  323. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  324. """
  325. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  326. output_hidden_states = (
  327. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  328. )
  329. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  330. encoder_states = () if output_hidden_states else None
  331. all_attentions = () if output_attentions else None
  332. hidden_states = inputs_embeds
  333. for idx, encoder_layer in enumerate(self.layers):
  334. if output_hidden_states:
  335. encoder_states = encoder_states + (hidden_states,)
  336. layer_outputs = encoder_layer(
  337. hidden_states,
  338. attention_mask,
  339. causal_attention_mask,
  340. output_attentions=output_attentions,
  341. )
  342. hidden_states = layer_outputs[0]
  343. if output_attentions:
  344. all_attentions = all_attentions + (layer_outputs[1],)
  345. if output_hidden_states:
  346. encoder_states = encoder_states + (hidden_states,)
  347. return BaseModelOutput(
  348. last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
  349. )
  350. # Adapted from transformers.models.clip.modeling_clip.CLIPVisionTransformer
  351. class IdeficsVisionTransformer(nn.Module):
  352. def __init__(self, config: IdeficsVisionConfig):
  353. super().__init__()
  354. self.config = config
  355. embed_dim = config.hidden_size
  356. self.embeddings = IdeficsVisionEmbeddings(config)
  357. self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  358. self.encoder = IdeficsVisionEncoder(config)
  359. self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  360. # Adapted from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward
  361. def forward(
  362. self,
  363. pixel_values: Optional[torch.FloatTensor] = None,
  364. output_attentions: Optional[bool] = None,
  365. output_hidden_states: Optional[bool] = None,
  366. interpolate_pos_encoding: Optional[bool] = False,
  367. return_dict: Optional[bool] = None,
  368. ) -> Union[tuple, BaseModelOutputWithPooling]:
  369. r"""
  370. Returns:
  371. """
  372. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  373. output_hidden_states = (
  374. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  375. )
  376. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  377. if pixel_values is None:
  378. raise ValueError("You have to specify pixel_values")
  379. hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  380. hidden_states = self.pre_layrnorm(hidden_states)
  381. encoder_outputs = self.encoder(
  382. inputs_embeds=hidden_states,
  383. output_attentions=output_attentions,
  384. output_hidden_states=output_hidden_states,
  385. return_dict=return_dict,
  386. )
  387. last_hidden_state = encoder_outputs[0]
  388. pooled_output = last_hidden_state[:, 0, :]
  389. pooled_output = self.post_layernorm(pooled_output)
  390. if not return_dict:
  391. return (last_hidden_state, pooled_output) + encoder_outputs[1:]
  392. return BaseModelOutputWithPooling(
  393. last_hidden_state=last_hidden_state,
  394. pooler_output=pooled_output,
  395. hidden_states=encoder_outputs.hidden_states,
  396. attentions=encoder_outputs.attentions,
  397. )