modeling_mlcd.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/mlcd/modular_mlcd.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_mlcd.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2025 The HuggingFace Inc. team.
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. from typing import Callable, Optional, Union
  22. import torch
  23. import torch.nn as nn
  24. from ...activations import ACT2FN
  25. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  26. from ...modeling_layers import GradientCheckpointingLayer
  27. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
  28. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  29. from ...processing_utils import Unpack
  30. from ...utils import TransformersKwargs, auto_docstring, torch_int
  31. from .configuration_mlcd import MLCDVisionConfig
  32. class MLCDMLP(nn.Module):
  33. def __init__(self, config):
  34. super().__init__()
  35. self.config = config
  36. self.activation_fn = ACT2FN[config.hidden_act]
  37. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  38. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  39. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  40. hidden_states = self.fc1(hidden_states)
  41. hidden_states = self.activation_fn(hidden_states)
  42. hidden_states = self.fc2(hidden_states)
  43. return hidden_states
  44. class MLCDRotaryEmbedding(nn.Module):
  45. inv_freq: torch.Tensor # fix linting for `register_buffer`
  46. def __init__(self, dim: int, theta: float = 10000.0) -> None:
  47. super().__init__()
  48. inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
  49. self.register_buffer("inv_freq", inv_freq, persistent=False)
  50. def forward(self, num_patches_height: int, num_patches_width: int) -> torch.Tensor:
  51. """
  52. Calculate the Rotary Position Embedding (RoPE) for MLCDVisionModel based on the grid size.
  53. Args:
  54. num_patches_height (int): Number of patches in the height dimension.
  55. num_patches_width (int): Number of patches in the width dimension.
  56. Returns:
  57. torch.Tensor: Rotary positional embeddings for the given grid size.
  58. """
  59. # Generate position IDs for height and width dimensions
  60. hpos_ids = (
  61. torch.arange(num_patches_height, device=self.inv_freq.device).unsqueeze(1).expand(-1, num_patches_width)
  62. )
  63. wpos_ids = (
  64. torch.arange(num_patches_width, device=self.inv_freq.device).unsqueeze(0).expand(num_patches_height, -1)
  65. )
  66. # Flatten and stack the position IDs
  67. pos_ids = torch.stack([hpos_ids.flatten(), wpos_ids.flatten()], dim=-1)
  68. # Generate the full rotary positional embeddings for the maximum grid size
  69. max_grid_size = max(num_patches_height, num_patches_width)
  70. seq = torch.arange(max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
  71. rotary_pos_emb_full = torch.outer(seq, self.inv_freq)
  72. # Select and flatten the embeddings based on the position IDs
  73. rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
  74. return rotary_pos_emb
  75. class MLCDVisionEmbeddings(nn.Module):
  76. def __init__(self, config: MLCDVisionConfig):
  77. super().__init__()
  78. self.config = config
  79. self.embed_dim = config.hidden_size
  80. self.image_size = config.image_size
  81. self.patch_size = config.patch_size
  82. self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
  83. self.patch_embedding = nn.Conv2d(
  84. in_channels=config.num_channels,
  85. out_channels=self.embed_dim,
  86. kernel_size=self.patch_size,
  87. stride=self.patch_size,
  88. bias=False,
  89. )
  90. self.num_patches = (self.image_size // self.patch_size) ** 2
  91. self.num_positions = self.num_patches + 1
  92. self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
  93. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  94. """
  95. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  96. images. This method is also adapted to support torch.jit tracing.
  97. Adapted from:
  98. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  99. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  100. """
  101. num_patches = embeddings.shape[1] - 1
  102. position_embedding = self.position_embedding.weight.unsqueeze(0)
  103. num_positions = position_embedding.shape[1] - 1
  104. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  105. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  106. return self.position_embedding(self.position_ids)
  107. class_pos_embed = position_embedding[:, :1]
  108. patch_pos_embed = position_embedding[:, 1:]
  109. dim = embeddings.shape[-1]
  110. new_height = height // self.patch_size
  111. new_width = width // self.patch_size
  112. sqrt_num_positions = torch_int(num_positions**0.5)
  113. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  114. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  115. patch_pos_embed = nn.functional.interpolate(
  116. patch_pos_embed,
  117. size=(new_height, new_width),
  118. mode="bicubic",
  119. align_corners=False,
  120. )
  121. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  122. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  123. def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
  124. batch_size = pixel_values.shape[0]
  125. target_dtype = self.patch_embedding.weight.dtype
  126. # patch_embeds -> shape = [batch, width, grid, grid]
  127. patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
  128. patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
  129. class_embeds = self.class_embedding.expand(batch_size, 1, -1)
  130. embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
  131. return embeddings
  132. def eager_attention_forward(
  133. module: nn.Module,
  134. query: torch.Tensor,
  135. key: torch.Tensor,
  136. value: torch.Tensor,
  137. attention_mask: Optional[torch.Tensor],
  138. scaling: float,
  139. dropout: float = 0.0,
  140. **kwargs: Unpack[TransformersKwargs],
  141. ):
  142. key_states = repeat_kv(key, module.num_key_value_groups)
  143. value_states = repeat_kv(value, module.num_key_value_groups)
  144. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  145. if attention_mask is not None:
  146. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  147. attn_weights = attn_weights + causal_mask
  148. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  149. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  150. attn_output = torch.matmul(attn_weights, value_states)
  151. attn_output = attn_output.transpose(1, 2).contiguous()
  152. return attn_output, attn_weights
  153. def rotate_half(x):
  154. """Rotates half the hidden dims of the input."""
  155. x1 = x[..., : x.shape[-1] // 2]
  156. x2 = x[..., x.shape[-1] // 2 :]
  157. return torch.cat((-x2, x1), dim=-1)
  158. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  159. """
  160. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  161. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  162. """
  163. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  164. if n_rep == 1:
  165. return hidden_states
  166. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  167. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  168. def apply_rotary_pos_emb_vision(
  169. q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
  170. ) -> tuple[torch.Tensor, torch.Tensor]:
  171. orig_q_dtype = q.dtype
  172. orig_k_dtype = k.dtype
  173. q, k = q.float(), k.float()
  174. cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
  175. q_embed = (q * cos) + (rotate_half(q) * sin)
  176. k_embed = (k * cos) + (rotate_half(k) * sin)
  177. q_embed = q_embed.to(orig_q_dtype)
  178. k_embed = k_embed.to(orig_k_dtype)
  179. return q_embed, k_embed
  180. class MLCDAttention(nn.Module):
  181. """Multi-headed attention with RoPE. Refer to papers:
  182. - Attention is all you need:
  183. https://huggingface.co/papers/1706.03762
  184. - RoFormer: Enhanced Transformer with Rotary Position Embedding:
  185. https://huggingface.co/papers/2104.09864
  186. """
  187. def __init__(self, config: MLCDVisionConfig):
  188. super().__init__()
  189. self.config = config
  190. self.embed_dim = config.hidden_size
  191. self.num_heads = config.num_attention_heads
  192. self.head_dim = self.embed_dim // self.num_heads
  193. if self.head_dim * self.num_heads != self.embed_dim:
  194. raise ValueError(
  195. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  196. f" {self.num_heads})."
  197. )
  198. self.scale = self.head_dim**-0.5
  199. self.dropout = config.attention_dropout
  200. self.is_causal = False
  201. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  202. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  203. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  204. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  205. self.num_key_value_groups = config.num_key_value_groups
  206. def forward(
  207. self,
  208. hidden_states: torch.Tensor,
  209. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  210. attention_mask: Optional[torch.Tensor] = None,
  211. **kwargs: Unpack[FlashAttentionKwargs],
  212. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  213. """Input shape: Batch x Time x Channel"""
  214. batch_size, seq_length = hidden_states.shape[:-1]
  215. # Each of shape: [batch_size, seq_length, num_heads, head_dim]
  216. query_states = self.q_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
  217. key_states = self.k_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
  218. value_states = self.v_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
  219. # Apply positional embeddings
  220. cos = position_embeddings[0].unsqueeze(0).float()
  221. sin = position_embeddings[1].unsqueeze(0).float()
  222. query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
  223. # Each of shape: [batch_size, num_heads, seq_length, head_dim]
  224. query_states = query_states.permute(0, 2, 1, 3).contiguous()
  225. key_states = key_states.permute(0, 2, 1, 3).contiguous()
  226. value_states = value_states.permute(0, 2, 1, 3).contiguous()
  227. attention_interface: Callable = eager_attention_forward
  228. if self.config._attn_implementation != "eager":
  229. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  230. attn_output, attn_weights = attention_interface(
  231. self,
  232. query_states,
  233. key_states,
  234. value_states,
  235. attention_mask,
  236. dropout=0.0 if not self.training else self.dropout,
  237. scaling=self.scale,
  238. is_causal=self.is_causal,
  239. **kwargs,
  240. )
  241. attn_output = attn_output.permute(1, 0, 2, 3).contiguous() # [seq_length, batch_size, num_heads, head_dim]
  242. attn_output = attn_output.view(seq_length, batch_size, -1) # [seq_length, batch_size, embedding_dim]
  243. attn_output = self.out_proj(attn_output)
  244. attn_output = attn_output.permute(1, 0, 2).contiguous() # [batch_size, seq_length, embedding_dim]
  245. return attn_output, attn_weights
  246. class MLCDEncoderLayer(GradientCheckpointingLayer):
  247. def __init__(self, config: MLCDVisionConfig):
  248. super().__init__()
  249. self.embed_dim = config.hidden_size
  250. self.self_attn = MLCDAttention(config)
  251. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  252. self.mlp = MLCDMLP(config)
  253. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  254. def forward(
  255. self,
  256. hidden_states: torch.Tensor,
  257. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  258. attention_mask: Optional[torch.Tensor] = None,
  259. output_attentions: Optional[bool] = False,
  260. ) -> tuple[torch.FloatTensor]:
  261. """
  262. Args:
  263. hidden_states (`torch.FloatTensor`):
  264. Input to the layer of shape `(batch, seq_len, embed_dim)`.
  265. Represents the hidden states from the previous layer or the input embeddings.
  266. position_embeddings (`tuple[torch.Tensor, torch.Tensor]`):
  267. A tuple of two tensors, each of shape `(batch, seq_len, embed_dim)`.
  268. Represents absolute positional embeddings for the query and key in the attention mechanism.
  269. attention_mask (`torch.FloatTensor`):
  270. Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
  271. output_attentions (`bool`, *optional*, defaults to `False`):
  272. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  273. returned tensors for more detail.
  274. """
  275. residual = hidden_states
  276. hidden_states = self.layer_norm1(hidden_states)
  277. hidden_states, attn_weights = self.self_attn(
  278. hidden_states=hidden_states,
  279. position_embeddings=position_embeddings,
  280. attention_mask=attention_mask,
  281. output_attentions=output_attentions,
  282. )
  283. hidden_states = residual + hidden_states
  284. residual = hidden_states
  285. hidden_states = self.layer_norm2(hidden_states)
  286. hidden_states = self.mlp(hidden_states)
  287. hidden_states = residual + hidden_states
  288. outputs = (hidden_states,)
  289. if output_attentions:
  290. outputs += (attn_weights,)
  291. return outputs
  292. class MLCDEncoder(nn.Module):
  293. """
  294. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  295. [`MLCDEncoderLayer`].
  296. Args:
  297. config: MLCDVisionConfig
  298. """
  299. def __init__(self, config: MLCDVisionConfig):
  300. """Overwrite dummy `MLCDConfig` to `MLCDVisionConfig`."""
  301. super().__init__()
  302. self.config = config
  303. self.layers = nn.ModuleList([MLCDEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  304. self.gradient_checkpointing = False
  305. def forward(
  306. self,
  307. inputs_embeds: torch.FloatTensor,
  308. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  309. attention_mask: Optional[torch.Tensor] = None,
  310. output_attentions: Optional[bool] = None,
  311. output_hidden_states: Optional[bool] = None,
  312. return_dict: Optional[bool] = None,
  313. ) -> Union[tuple, BaseModelOutput]:
  314. r"""
  315. Args:
  316. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  317. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  318. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  319. than the model's internal embedding lookup matrix.
  320. position_embeddings (`tuple[torch.Tensor, torch.Tensor]`):
  321. A tuple of two tensors, each of shape `(batch, seq_len, embed_dim)`.
  322. Represents absolute positional embeddings for the query and key in the attention mechanism.
  323. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  324. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  325. - 1 for tokens that are **not masked**,
  326. - 0 for tokens that are **masked**.
  327. [What are attention masks?](../glossary#attention-mask)
  328. output_attentions (`bool`, *optional*):
  329. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  330. returned tensors for more detail.
  331. output_hidden_states (`bool`, *optional*):
  332. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  333. for more detail.
  334. return_dict (`bool`, *optional*):
  335. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  336. """
  337. output_hidden_states = (
  338. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  339. )
  340. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  341. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  342. encoder_states = () if output_hidden_states else None
  343. all_attentions = () if output_attentions else None
  344. hidden_states = inputs_embeds
  345. for idx, encoder_layer in enumerate(self.layers):
  346. if output_hidden_states:
  347. encoder_states = encoder_states + (hidden_states,)
  348. layer_outputs = encoder_layer(
  349. hidden_states=hidden_states,
  350. position_embeddings=position_embeddings,
  351. attention_mask=attention_mask,
  352. output_attentions=output_attentions,
  353. )
  354. hidden_states = layer_outputs[0]
  355. if output_attentions:
  356. all_attentions = all_attentions + (layer_outputs[1],)
  357. if output_hidden_states:
  358. encoder_states = encoder_states + (hidden_states,)
  359. if not return_dict:
  360. return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
  361. return BaseModelOutput(
  362. last_hidden_state=hidden_states,
  363. hidden_states=encoder_states,
  364. attentions=all_attentions,
  365. )
  366. class MLCDVisionTransformer(nn.Module):
  367. def __init__(self, config: MLCDVisionConfig):
  368. super().__init__()
  369. self.config = config
  370. embed_dim = config.hidden_size
  371. self.embeddings = MLCDVisionEmbeddings(config)
  372. self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  373. self.encoder = MLCDEncoder(config)
  374. self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  375. self.vision_rotary_embedding = MLCDRotaryEmbedding(config.hidden_size // config.num_attention_heads // 2)
  376. self.class_pos_emb = nn.Parameter(torch.randn(1, config.hidden_size // config.num_attention_heads // 2))
  377. @auto_docstring
  378. def forward(
  379. self,
  380. pixel_values: Optional[torch.FloatTensor] = None,
  381. output_attentions: Optional[bool] = None,
  382. output_hidden_states: Optional[bool] = None,
  383. return_dict: Optional[bool] = None,
  384. ) -> Union[tuple, BaseModelOutputWithPooling]:
  385. output_hidden_states = (
  386. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  387. )
  388. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  389. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  390. if pixel_values is None:
  391. raise ValueError("You have to specify pixel_values")
  392. num_patches_height = pixel_values.shape[-2] // self.config.patch_size
  393. num_patches_width = pixel_values.shape[-1] // self.config.patch_size
  394. rotary_pos_emb = self.vision_rotary_embedding(num_patches_height, num_patches_width)
  395. rotary_pos_emb = rotary_pos_emb.to(self.class_pos_emb.device)
  396. rotary_pos_emb = torch.cat([self.class_pos_emb, rotary_pos_emb], dim=0)
  397. emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
  398. position_embeddings = (emb.cos(), emb.sin())
  399. hidden_states = self.embeddings(pixel_values)
  400. hidden_states = self.pre_layrnorm(hidden_states)
  401. encoder_outputs = self.encoder(
  402. inputs_embeds=hidden_states,
  403. position_embeddings=position_embeddings,
  404. output_attentions=output_attentions,
  405. output_hidden_states=output_hidden_states,
  406. return_dict=return_dict,
  407. )
  408. last_hidden_state = encoder_outputs[0]
  409. pooled_output = last_hidden_state[:, 0, :]
  410. pooled_output = self.post_layernorm(pooled_output)
  411. if not return_dict:
  412. return (last_hidden_state, pooled_output) + encoder_outputs[1:]
  413. return BaseModelOutputWithPooling(
  414. last_hidden_state=last_hidden_state,
  415. pooler_output=pooled_output,
  416. hidden_states=encoder_outputs.hidden_states,
  417. attentions=encoder_outputs.attentions,
  418. )
  419. @auto_docstring
  420. class MLCDPreTrainedModel(PreTrainedModel):
  421. config: MLCDVisionConfig
  422. base_model_prefix = "mlcd"
  423. supports_gradient_checkpointing = True
  424. _supports_flash_attn = True
  425. _supports_sdpa = True
  426. def _init_weights(self, module):
  427. """Initialize the weights"""
  428. factor = self.config.initializer_factor
  429. if isinstance(module, MLCDVisionEmbeddings):
  430. factor = self.config.initializer_factor
  431. nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
  432. nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
  433. elif isinstance(module, MLCDAttention):
  434. factor = self.config.initializer_factor
  435. in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
  436. out_proj_std = (module.embed_dim**-0.5) * factor
  437. nn.init.normal_(module.q_proj.weight, std=in_proj_std)
  438. nn.init.normal_(module.k_proj.weight, std=in_proj_std)
  439. nn.init.normal_(module.v_proj.weight, std=in_proj_std)
  440. nn.init.normal_(module.out_proj.weight, std=out_proj_std)
  441. elif isinstance(module, MLCDMLP):
  442. factor = self.config.initializer_factor
  443. in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
  444. fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
  445. nn.init.normal_(module.fc1.weight, std=fc_std)
  446. nn.init.normal_(module.fc2.weight, std=in_proj_std)
  447. elif isinstance(module, MLCDVisionTransformer):
  448. factor = self.config.initializer_factor
  449. pos_emb_std = (module.config.hidden_size // module.config.num_attention_heads // 2) ** -0.5 * factor
  450. nn.init.normal_(module.class_pos_emb, mean=0.0, std=pos_emb_std)
  451. elif isinstance(module, nn.LayerNorm):
  452. module.bias.data.zero_()
  453. module.weight.data.fill_(1.0)
  454. elif isinstance(module, nn.Linear) and module.bias is not None:
  455. module.bias.data.zero_()
  456. @auto_docstring(
  457. custom_intro="""
  458. The vision model from M_L_C_D without any head or projection on top.
  459. """
  460. )
  461. class MLCDVisionModel(MLCDPreTrainedModel):
  462. config: MLCDVisionConfig
  463. main_input_name = "pixel_values"
  464. _no_split_modules = ["MLCDEncoderLayer"]
  465. def __init__(self, config: MLCDVisionConfig):
  466. super().__init__(config)
  467. self.vision_model = MLCDVisionTransformer(config)
  468. # Initialize weights and apply final processing
  469. self.post_init()
  470. def get_input_embeddings(self) -> nn.Module:
  471. return self.vision_model.embeddings.patch_embedding
  472. @auto_docstring
  473. def forward(
  474. self,
  475. pixel_values: Optional[torch.FloatTensor] = None,
  476. output_attentions: Optional[bool] = None,
  477. output_hidden_states: Optional[bool] = None,
  478. return_dict: Optional[bool] = None,
  479. ) -> Union[tuple, BaseModelOutputWithPooling]:
  480. r"""
  481. Example:
  482. ```python
  483. >>> import requests
  484. >>> from PIL import Image
  485. >>> from transformers import AutoProcessor, MLCDVisionModel
  486. >>> model = MLCDVisionModel.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448")
  487. >>> processor = AutoProcessor.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448")
  488. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  489. >>> image = Image.open(requests.get(url, stream=True).raw)
  490. >>> inputs = processor(images=image, return_tensors="pt")
  491. >>> with torch.no_grad():
  492. ... outputs = model(**inputs, output_attentions=True)
  493. >>> features = outputs.last_hidden_state
  494. >>> print(f"Extracted features shape: {features.shape}")
  495. >>> print(f"Number of attention layers: {len(outputs.attentions)}")
  496. >>> print(f"Attention shape: {outputs.attentions[0].shape}")
  497. ```"""
  498. output_hidden_states = (
  499. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  500. )
  501. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  502. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  503. return self.vision_model(
  504. pixel_values=pixel_values,
  505. output_attentions=output_attentions,
  506. output_hidden_states=output_hidden_states,
  507. return_dict=return_dict,
  508. )
# Public names exported by this module (what `from <module> import *` brings in).
__all__ = ["MLCDPreTrainedModel", "MLCDVisionModel"]