modular_mlcd.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529
  1. # coding=utf-8
  2. # Copyright 2025 The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from typing import Callable, Optional, Union
  16. import torch
  17. import torch.nn as nn
  18. from ...configuration_utils import PretrainedConfig
  19. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  20. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
  21. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  22. from ...processing_utils import Unpack
  23. from ...utils import auto_docstring, logging
  24. from ..clip.modeling_clip import (
  25. CLIPMLP,
  26. CLIPAttention,
  27. CLIPEncoder,
  28. CLIPEncoderLayer,
  29. CLIPVisionEmbeddings,
  30. CLIPVisionModel,
  31. CLIPVisionTransformer,
  32. )
  33. from ..llama.modeling_llama import eager_attention_forward
  34. from ..qwen2_vl.modeling_qwen2_vl import VisionRotaryEmbedding, apply_rotary_pos_emb_vision
# Module-level logger, named after this module.
logger = logging.get_logger(__name__)
  36. class MLCDVisionConfig(PretrainedConfig):
  37. r"""
  38. This is the configuration class to store the configuration of a [`MLCDVisionModel`]. It is used to instantiate a MLCD
  39. vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
  40. with the defaults will yield a similar configuration to that of the vision encoder of the MLCD
  41. [DeepGlint-AI/mlcd-vit-bigG-patch14-336](https://huggingface.co/DeepGlint-AI/mlcd-vit-bigG-patch14-336) architecture.
  42. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  43. documentation from [`PretrainedConfig`] for more information.
  44. Args:
  45. hidden_size (`int`, *optional*, defaults to 1664):
  46. Dimensionality of the encoder layers and the pooler layer.
  47. intermediate_size (`int`, *optional*, defaults to 8192):
  48. Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
  49. projection_dim (`int`, *optional*, defaults to 1024):
  50. Dimensionality of text and vision projection layers.
  51. num_hidden_layers (`int`, *optional*, defaults to 48):
  52. Number of hidden layers in the Transformer encoder.
  53. num_attention_heads (`int`, *optional*, defaults to 16):
  54. Number of attention heads for each attention layer in the Transformer encoder.
  55. num_channels (`int`, *optional*, defaults to 3):
  56. The number of input channels.
  57. image_size (`int`, *optional*, defaults to 336):
  58. The size (resolution) of each image.
  59. patch_size (`int`, *optional*, defaults to 14):
  60. The size (resolution) of each patch.
  61. hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
  62. The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
  63. `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
  64. layer_norm_eps (`float`, *optional*, defaults to 1e-05):
  65. The epsilon used by the layer normalization layers.
  66. attention_dropout (`float`, *optional*, defaults to 0.0):
  67. The dropout ratio for the attention probabilities.
  68. initializer_range (`float`, *optional*, defaults to 0.02):
  69. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  70. initializer_factor (`float`, *optional*, defaults to 1.0):
  71. A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
  72. testing).
  73. Example:
  74. ```python
  75. >>> from transformers import MLCDVisionConfig, MLCDVisionModel
  76. >>> # Initializing a MLCDVisionConfig with DeepGlint-AI/mlcd-vit-bigG-patch14-336 style configuration
  77. >>> configuration = MLCDVisionConfig()
  78. >>> # Initializing a MLCDVisionModel (with random weights) from the DeepGlint-AI/mlcd-vit-bigG-patch14-336 style configuration
  79. >>> model = MLCDVisionModel(configuration)
  80. >>> # Accessing the model configuration
  81. >>> configuration = model.config
  82. ```"""
  83. model_type = "mlcd_vision_model"
  84. base_config_key = "vision_config"
  85. def __init__(
  86. self,
  87. hidden_size=1664,
  88. intermediate_size=8192,
  89. num_hidden_layers=48,
  90. num_attention_heads=16,
  91. num_key_value_groups=1,
  92. num_channels=3,
  93. image_size=336,
  94. patch_size=14,
  95. hidden_act="gelu",
  96. layer_norm_eps=1e-5,
  97. attention_dropout=0.0,
  98. initializer_range=0.02,
  99. initializer_factor=1.0,
  100. **kwargs,
  101. ):
  102. super().__init__(**kwargs)
  103. self.hidden_size = hidden_size
  104. self.intermediate_size = intermediate_size
  105. self.num_hidden_layers = num_hidden_layers
  106. self.num_attention_heads = num_attention_heads
  107. self.num_key_value_groups = num_key_value_groups
  108. self.num_channels = num_channels
  109. self.patch_size = patch_size
  110. self.image_size = image_size
  111. self.initializer_range = initializer_range
  112. self.initializer_factor = initializer_factor
  113. self.attention_dropout = attention_dropout
  114. self.layer_norm_eps = layer_norm_eps
  115. self.hidden_act = hidden_act
# Feed-forward block: a subclass of `CLIPMLP` that adds and overrides nothing.
class MLCDMLP(CLIPMLP):
    pass
  118. class MLCDRotaryEmbedding(VisionRotaryEmbedding):
  119. def forward(self, num_patches_height: int, num_patches_width: int) -> torch.Tensor:
  120. """
  121. Calculate the Rotary Position Embedding (RoPE) for MLCDVisionModel based on the grid size.
  122. Args:
  123. num_patches_height (int): Number of patches in the height dimension.
  124. num_patches_width (int): Number of patches in the width dimension.
  125. Returns:
  126. torch.Tensor: Rotary positional embeddings for the given grid size.
  127. """
  128. # Generate position IDs for height and width dimensions
  129. hpos_ids = (
  130. torch.arange(num_patches_height, device=self.inv_freq.device).unsqueeze(1).expand(-1, num_patches_width)
  131. )
  132. wpos_ids = (
  133. torch.arange(num_patches_width, device=self.inv_freq.device).unsqueeze(0).expand(num_patches_height, -1)
  134. )
  135. # Flatten and stack the position IDs
  136. pos_ids = torch.stack([hpos_ids.flatten(), wpos_ids.flatten()], dim=-1)
  137. # Generate the full rotary positional embeddings for the maximum grid size
  138. max_grid_size = max(num_patches_height, num_patches_width)
  139. seq = torch.arange(max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
  140. rotary_pos_emb_full = torch.outer(seq, self.inv_freq)
  141. # Select and flatten the embeddings based on the position IDs
  142. rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
  143. return rotary_pos_emb
  144. class MLCDVisionEmbeddings(CLIPVisionEmbeddings):
  145. def __init__(self, config: MLCDVisionConfig):
  146. super().__init__(config)
  147. del self.position_embedding
  148. def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
  149. batch_size = pixel_values.shape[0]
  150. target_dtype = self.patch_embedding.weight.dtype
  151. # patch_embeds -> shape = [batch, width, grid, grid]
  152. patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
  153. patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
  154. class_embeds = self.class_embedding.expand(batch_size, 1, -1)
  155. embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
  156. return embeddings
  157. class MLCDAttention(CLIPAttention):
  158. """Multi-headed attention with RoPE. Refer to papers:
  159. - Attention is all you need:
  160. https://huggingface.co/papers/1706.03762
  161. - RoFormer: Enhanced Transformer with Rotary Position Embedding:
  162. https://huggingface.co/papers/2104.09864
  163. """
  164. def __init__(self, config: MLCDVisionConfig):
  165. super().__init__(config)
  166. self.num_key_value_groups = config.num_key_value_groups
  167. self.is_causal = False
  168. def forward(
  169. self,
  170. hidden_states: torch.Tensor,
  171. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  172. attention_mask: Optional[torch.Tensor] = None,
  173. **kwargs: Unpack[FlashAttentionKwargs],
  174. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  175. batch_size, seq_length = hidden_states.shape[:-1]
  176. # Each of shape: [batch_size, seq_length, num_heads, head_dim]
  177. query_states = self.q_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
  178. key_states = self.k_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
  179. value_states = self.v_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
  180. # Apply positional embeddings
  181. cos = position_embeddings[0].unsqueeze(0).float()
  182. sin = position_embeddings[1].unsqueeze(0).float()
  183. query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
  184. # Each of shape: [batch_size, num_heads, seq_length, head_dim]
  185. query_states = query_states.permute(0, 2, 1, 3).contiguous()
  186. key_states = key_states.permute(0, 2, 1, 3).contiguous()
  187. value_states = value_states.permute(0, 2, 1, 3).contiguous()
  188. attention_interface: Callable = eager_attention_forward
  189. if self.config._attn_implementation != "eager":
  190. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  191. attn_output, attn_weights = attention_interface(
  192. self,
  193. query_states,
  194. key_states,
  195. value_states,
  196. attention_mask,
  197. dropout=0.0 if not self.training else self.dropout,
  198. scaling=self.scale,
  199. is_causal=self.is_causal,
  200. **kwargs,
  201. )
  202. attn_output = attn_output.permute(1, 0, 2, 3).contiguous() # [seq_length, batch_size, num_heads, head_dim]
  203. attn_output = attn_output.view(seq_length, batch_size, -1) # [seq_length, batch_size, embedding_dim]
  204. attn_output = self.out_proj(attn_output)
  205. attn_output = attn_output.permute(1, 0, 2).contiguous() # [batch_size, seq_length, embedding_dim]
  206. return attn_output, attn_weights
  207. class MLCDEncoderLayer(CLIPEncoderLayer):
  208. def __init__(self, config: MLCDVisionConfig):
  209. super().__init__(config)
  210. self.self_attn = MLCDAttention(config)
  211. def forward(
  212. self,
  213. hidden_states: torch.Tensor,
  214. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  215. attention_mask: Optional[torch.Tensor] = None,
  216. output_attentions: Optional[bool] = False,
  217. ) -> tuple[torch.FloatTensor]:
  218. """
  219. Args:
  220. hidden_states (`torch.FloatTensor`):
  221. Input to the layer of shape `(batch, seq_len, embed_dim)`.
  222. Represents the hidden states from the previous layer or the input embeddings.
  223. position_embeddings (`tuple[torch.Tensor, torch.Tensor]`):
  224. A tuple of two tensors, each of shape `(batch, seq_len, embed_dim)`.
  225. Represents absolute positional embeddings for the query and key in the attention mechanism.
  226. attention_mask (`torch.FloatTensor`):
  227. Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
  228. output_attentions (`bool`, *optional*, defaults to `False`):
  229. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  230. returned tensors for more detail.
  231. """
  232. residual = hidden_states
  233. hidden_states = self.layer_norm1(hidden_states)
  234. hidden_states, attn_weights = self.self_attn(
  235. hidden_states=hidden_states,
  236. position_embeddings=position_embeddings,
  237. attention_mask=attention_mask,
  238. output_attentions=output_attentions,
  239. )
  240. hidden_states = residual + hidden_states
  241. residual = hidden_states
  242. hidden_states = self.layer_norm2(hidden_states)
  243. hidden_states = self.mlp(hidden_states)
  244. hidden_states = residual + hidden_states
  245. outputs = (hidden_states,)
  246. if output_attentions:
  247. outputs += (attn_weights,)
  248. return outputs
  249. class MLCDEncoder(CLIPEncoder):
  250. """
  251. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  252. [`MLCDEncoderLayer`].
  253. Args:
  254. config: MLCDVisionConfig
  255. """
  256. def __init__(self, config: MLCDVisionConfig):
  257. """Overwrite dummy `MLCDConfig` to `MLCDVisionConfig`."""
  258. super().__init__(config)
  259. def forward(
  260. self,
  261. inputs_embeds: torch.FloatTensor,
  262. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  263. attention_mask: Optional[torch.Tensor] = None,
  264. output_attentions: Optional[bool] = None,
  265. output_hidden_states: Optional[bool] = None,
  266. return_dict: Optional[bool] = None,
  267. ) -> Union[tuple, BaseModelOutput]:
  268. r"""
  269. Args:
  270. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  271. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  272. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  273. than the model's internal embedding lookup matrix.
  274. position_embeddings (`tuple[torch.Tensor, torch.Tensor]`):
  275. A tuple of two tensors, each of shape `(batch, seq_len, embed_dim)`.
  276. Represents absolute positional embeddings for the query and key in the attention mechanism.
  277. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  278. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  279. - 1 for tokens that are **not masked**,
  280. - 0 for tokens that are **masked**.
  281. [What are attention masks?](../glossary#attention-mask)
  282. output_attentions (`bool`, *optional*):
  283. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  284. returned tensors for more detail.
  285. output_hidden_states (`bool`, *optional*):
  286. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  287. for more detail.
  288. return_dict (`bool`, *optional*):
  289. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  290. """
  291. output_hidden_states = (
  292. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  293. )
  294. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  295. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  296. encoder_states = () if output_hidden_states else None
  297. all_attentions = () if output_attentions else None
  298. hidden_states = inputs_embeds
  299. for idx, encoder_layer in enumerate(self.layers):
  300. if output_hidden_states:
  301. encoder_states = encoder_states + (hidden_states,)
  302. layer_outputs = encoder_layer(
  303. hidden_states=hidden_states,
  304. position_embeddings=position_embeddings,
  305. attention_mask=attention_mask,
  306. output_attentions=output_attentions,
  307. )
  308. hidden_states = layer_outputs[0]
  309. if output_attentions:
  310. all_attentions = all_attentions + (layer_outputs[1],)
  311. if output_hidden_states:
  312. encoder_states = encoder_states + (hidden_states,)
  313. if not return_dict:
  314. return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
  315. return BaseModelOutput(
  316. last_hidden_state=hidden_states,
  317. hidden_states=encoder_states,
  318. attentions=all_attentions,
  319. )
  320. class MLCDVisionTransformer(CLIPVisionTransformer):
  321. def __init__(self, config: MLCDVisionConfig):
  322. super().__init__(config)
  323. self.vision_rotary_embedding = MLCDRotaryEmbedding(config.hidden_size // config.num_attention_heads // 2)
  324. self.class_pos_emb = nn.Parameter(torch.randn(1, config.hidden_size // config.num_attention_heads // 2))
  325. @auto_docstring
  326. def forward(
  327. self,
  328. pixel_values: Optional[torch.FloatTensor] = None,
  329. output_attentions: Optional[bool] = None,
  330. output_hidden_states: Optional[bool] = None,
  331. return_dict: Optional[bool] = None,
  332. ) -> Union[tuple, BaseModelOutputWithPooling]:
  333. output_hidden_states = (
  334. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  335. )
  336. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  337. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  338. if pixel_values is None:
  339. raise ValueError("You have to specify pixel_values")
  340. num_patches_height = pixel_values.shape[-2] // self.config.patch_size
  341. num_patches_width = pixel_values.shape[-1] // self.config.patch_size
  342. rotary_pos_emb = self.vision_rotary_embedding(num_patches_height, num_patches_width)
  343. rotary_pos_emb = rotary_pos_emb.to(self.class_pos_emb.device)
  344. rotary_pos_emb = torch.cat([self.class_pos_emb, rotary_pos_emb], dim=0)
  345. emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
  346. position_embeddings = (emb.cos(), emb.sin())
  347. hidden_states = self.embeddings(pixel_values)
  348. hidden_states = self.pre_layrnorm(hidden_states)
  349. encoder_outputs = self.encoder(
  350. inputs_embeds=hidden_states,
  351. position_embeddings=position_embeddings,
  352. output_attentions=output_attentions,
  353. output_hidden_states=output_hidden_states,
  354. return_dict=return_dict,
  355. )
  356. last_hidden_state = encoder_outputs[0]
  357. pooled_output = last_hidden_state[:, 0, :]
  358. pooled_output = self.post_layernorm(pooled_output)
  359. if not return_dict:
  360. return (last_hidden_state, pooled_output) + encoder_outputs[1:]
  361. return BaseModelOutputWithPooling(
  362. last_hidden_state=last_hidden_state,
  363. pooler_output=pooled_output,
  364. hidden_states=encoder_outputs.hidden_states,
  365. attentions=encoder_outputs.attentions,
  366. )
  367. @auto_docstring
  368. class MLCDPreTrainedModel(PreTrainedModel):
  369. config: MLCDVisionConfig
  370. base_model_prefix = "mlcd"
  371. supports_gradient_checkpointing = True
  372. _supports_flash_attn = True
  373. _supports_sdpa = True
  374. def _init_weights(self, module):
  375. """Initialize the weights"""
  376. factor = self.config.initializer_factor
  377. if isinstance(module, MLCDVisionEmbeddings):
  378. factor = self.config.initializer_factor
  379. nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
  380. nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
  381. elif isinstance(module, MLCDAttention):
  382. factor = self.config.initializer_factor
  383. in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
  384. out_proj_std = (module.embed_dim**-0.5) * factor
  385. nn.init.normal_(module.q_proj.weight, std=in_proj_std)
  386. nn.init.normal_(module.k_proj.weight, std=in_proj_std)
  387. nn.init.normal_(module.v_proj.weight, std=in_proj_std)
  388. nn.init.normal_(module.out_proj.weight, std=out_proj_std)
  389. elif isinstance(module, MLCDMLP):
  390. factor = self.config.initializer_factor
  391. in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
  392. fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
  393. nn.init.normal_(module.fc1.weight, std=fc_std)
  394. nn.init.normal_(module.fc2.weight, std=in_proj_std)
  395. elif isinstance(module, MLCDVisionTransformer):
  396. factor = self.config.initializer_factor
  397. pos_emb_std = (module.config.hidden_size // module.config.num_attention_heads // 2) ** -0.5 * factor
  398. nn.init.normal_(module.class_pos_emb, mean=0.0, std=pos_emb_std)
  399. elif isinstance(module, nn.LayerNorm):
  400. module.bias.data.zero_()
  401. module.weight.data.fill_(1.0)
  402. elif isinstance(module, nn.Linear) and module.bias is not None:
  403. module.bias.data.zero_()
  404. class MLCDVisionModel(CLIPVisionModel):
  405. @auto_docstring
  406. def forward(
  407. self,
  408. pixel_values: Optional[torch.FloatTensor] = None,
  409. output_attentions: Optional[bool] = None,
  410. output_hidden_states: Optional[bool] = None,
  411. return_dict: Optional[bool] = None,
  412. ) -> Union[tuple, BaseModelOutputWithPooling]:
  413. r"""
  414. Example:
  415. ```python
  416. >>> import requests
  417. >>> from PIL import Image
  418. >>> from transformers import AutoProcessor, MLCDVisionModel
  419. >>> model = MLCDVisionModel.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448")
  420. >>> processor = AutoProcessor.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448")
  421. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  422. >>> image = Image.open(requests.get(url, stream=True).raw)
  423. >>> inputs = processor(images=image, return_tensors="pt")
  424. >>> with torch.no_grad():
  425. ... outputs = model(**inputs, output_attentions=True)
  426. >>> features = outputs.last_hidden_state
  427. >>> print(f"Extracted features shape: {features.shape}")
  428. >>> print(f"Number of attention layers: {len(outputs.attentions)}")
  429. >>> print(f"Attention shape: {outputs.attentions[0].shape}")
  430. ```"""
  431. output_hidden_states = (
  432. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  433. )
  434. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  435. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  436. return self.vision_model(
  437. pixel_values=pixel_values,
  438. output_attentions=output_attentions,
  439. output_hidden_states=output_hidden_states,
  440. return_dict=return_dict,
  441. )
# Names exported by `from <this module> import *`.
__all__ = [
    "MLCDVisionConfig",
    "MLCDPreTrainedModel",
    "MLCDVisionModel",
]