modeling_ovis2.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/ovis2/modular_ovis2.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_ovis2.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. import math
  22. from dataclasses import dataclass
  23. from typing import Callable, Optional, Union
  24. import torch
  25. from torch import nn
  26. from ...activations import ACT2FN
  27. from ...cache_utils import Cache
  28. from ...generation import GenerationMixin
  29. from ...integrations import use_kernel_forward_from_hub
  30. from ...modeling_layers import GradientCheckpointingLayer
  31. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast
  32. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  33. from ...processing_utils import Unpack
  34. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple
  35. from ..auto import AutoModel
  36. from .configuration_ovis2 import Ovis2Config, Ovis2VisionConfig
  37. @dataclass
  38. @auto_docstring(
  39. custom_intro="""
  40. Base class for Llava outputs, with hidden states and attentions.
  41. """
  42. )
  43. class Ovis2ModelOutputWithPast(BaseModelOutputWithPast):
  44. r"""
  45. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  46. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  47. Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
  48. `past_key_values` input) to speed up sequential decoding.
  49. image_hidden_states (`torch.FloatTensor`, *optional*):
  50. A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
  51. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
  52. """
  53. image_hidden_states: Optional[torch.FloatTensor] = None
  54. @dataclass
  55. @auto_docstring(
  56. custom_intro="""
  57. Base class for Ovis2 causal language model (or autoregressive) outputs.
  58. """
  59. )
  60. class Ovis2CausalLMOutputWithPast(ModelOutput):
  61. r"""
  62. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  63. Language modeling loss (for next-token prediction).
  64. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  65. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  66. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  67. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  68. Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
  69. `past_key_values` input) to speed up sequential decoding.
  70. image_hidden_states (`torch.FloatTensor`, *optional*):
  71. A `torch.FloatTensor` of size (batch_size * num_patches, num_images, sequence_length, hidden_size)`.
  72. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
  73. """
  74. loss: Optional[torch.FloatTensor] = None
  75. logits: Optional[torch.FloatTensor] = None
  76. past_key_values: Optional[Cache] = None
  77. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  78. attentions: Optional[tuple[torch.FloatTensor]] = None
  79. image_hidden_states: Optional[torch.FloatTensor] = None
  80. @use_kernel_forward_from_hub("RMSNorm")
  81. class Ovis2RMSNorm(nn.Module):
  82. def __init__(self, hidden_size, eps=1e-6):
  83. """
  84. Ovis2RMSNorm is equivalent to T5LayerNorm
  85. """
  86. super().__init__()
  87. self.weight = nn.Parameter(torch.ones(hidden_size))
  88. self.variance_epsilon = eps
  89. def forward(self, hidden_states):
  90. input_dtype = hidden_states.dtype
  91. hidden_states = hidden_states.to(torch.float32)
  92. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  93. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  94. return self.weight * hidden_states.to(input_dtype)
  95. def extra_repr(self):
  96. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  97. class Ovis2VisionMLP(nn.Module):
  98. def __init__(self, config):
  99. super().__init__()
  100. self.config = config
  101. self.hidden_size = config.hidden_size
  102. self.intermediate_size = config.intermediate_size
  103. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  104. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  105. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
  106. self.act_fn = ACT2FN[config.hidden_act]
  107. def forward(self, x):
  108. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  109. return down_proj
  110. class Ovis2VisionEmbeddings(nn.Module):
  111. def __init__(self, config: Ovis2VisionConfig):
  112. super().__init__()
  113. self.config = config
  114. self.embed_dim = config.hidden_size
  115. self.image_size = config.image_size
  116. self.patch_size = config.patch_size
  117. self.patch_embedding = nn.Conv2d(
  118. in_channels=config.num_channels,
  119. out_channels=self.embed_dim,
  120. kernel_size=self.patch_size,
  121. stride=self.patch_size,
  122. padding="valid",
  123. )
  124. self.num_patches = (self.image_size // self.patch_size) ** 2
  125. self.num_positions = self.num_patches
  126. self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
  127. self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
  128. self.rms_norm = Ovis2RMSNorm(config.hidden_size, config.rms_norm_eps)
  129. def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
  130. target_dtype = self.patch_embedding.weight.dtype
  131. patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
  132. embeddings = patch_embeds.flatten(2).transpose(1, 2)
  133. embeddings = self.rms_norm(embeddings)
  134. embeddings = embeddings + self.position_embedding(self.position_ids)
  135. return embeddings
  136. def eager_attention_forward(
  137. module: nn.Module,
  138. query: torch.Tensor,
  139. key: torch.Tensor,
  140. value: torch.Tensor,
  141. attention_mask: Optional[torch.Tensor],
  142. scaling: float,
  143. dropout: float = 0.0,
  144. **kwargs,
  145. ):
  146. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  147. if attention_mask is not None:
  148. attn_weights = attn_weights + attention_mask
  149. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  150. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  151. attn_output = torch.matmul(attn_weights, value)
  152. attn_output = attn_output.transpose(1, 2).contiguous()
  153. return attn_output, attn_weights
  154. class Ovis2VisionAttention(nn.Module):
  155. """Multi-headed attention from 'Attention Is All You Need' paper"""
  156. def __init__(self, config):
  157. super().__init__()
  158. self.config = config
  159. self.embed_dim = config.hidden_size
  160. self.num_heads = config.num_attention_heads
  161. self.head_dim = self.embed_dim // self.num_heads
  162. if self.head_dim * self.num_heads != self.embed_dim:
  163. raise ValueError(
  164. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  165. f" {self.num_heads})."
  166. )
  167. self.scale = self.head_dim**-0.5
  168. self.dropout = config.attention_dropout
  169. self.is_causal = False
  170. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
  171. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
  172. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
  173. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
  174. def forward(
  175. self,
  176. hidden_states: torch.Tensor,
  177. attention_mask: Optional[torch.Tensor] = None,
  178. **kwargs,
  179. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  180. """Input shape: Batch x Time x Channel"""
  181. batch_size, seq_length, embed_dim = hidden_states.shape
  182. queries = self.q_proj(hidden_states)
  183. keys = self.k_proj(hidden_states)
  184. values = self.v_proj(hidden_states)
  185. queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  186. keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  187. values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  188. attention_interface: Callable = eager_attention_forward
  189. if self.config._attn_implementation != "eager":
  190. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  191. attn_output, attn_weights = attention_interface(
  192. self,
  193. queries,
  194. keys,
  195. values,
  196. attention_mask,
  197. is_causal=self.is_causal,
  198. scaling=self.scale,
  199. dropout=0.0 if not self.training else self.dropout,
  200. )
  201. attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
  202. attn_output = self.out_proj(attn_output)
  203. return attn_output, attn_weights
  204. class Ovis2MLP(nn.Module):
  205. def __init__(self, config):
  206. super().__init__()
  207. self.config = config
  208. self.hidden_size = config.hidden_size
  209. self.intermediate_size = config.intermediate_size
  210. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  211. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  212. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
  213. self.act_fn = ACT2FN[config.hidden_act]
  214. def forward(self, x):
  215. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  216. return down_proj
  217. class Ovis2Attention(nn.Module):
  218. """Multi-headed attention from 'Attention Is All You Need' paper"""
  219. def __init__(self, config):
  220. super().__init__()
  221. self.config = config
  222. self.embed_dim = config.hidden_size
  223. self.num_heads = config.num_attention_heads
  224. self.head_dim = self.embed_dim // self.num_heads
  225. if self.head_dim * self.num_heads != self.embed_dim:
  226. raise ValueError(
  227. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  228. f" {self.num_heads})."
  229. )
  230. self.scale = self.head_dim**-0.5
  231. self.dropout = config.attention_dropout
  232. self.is_causal = False
  233. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
  234. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
  235. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
  236. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
  237. def forward(
  238. self,
  239. hidden_states: torch.Tensor,
  240. attention_mask: Optional[torch.Tensor] = None,
  241. **kwargs,
  242. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  243. """Input shape: Batch x Time x Channel"""
  244. batch_size, seq_length, embed_dim = hidden_states.shape
  245. queries = self.q_proj(hidden_states)
  246. keys = self.k_proj(hidden_states)
  247. values = self.v_proj(hidden_states)
  248. queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  249. keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  250. values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  251. attention_interface: Callable = eager_attention_forward
  252. if self.config._attn_implementation != "eager":
  253. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  254. attn_output, attn_weights = attention_interface(
  255. self,
  256. queries,
  257. keys,
  258. values,
  259. attention_mask,
  260. is_causal=self.is_causal,
  261. scaling=self.scale,
  262. dropout=0.0 if not self.training else self.dropout,
  263. )
  264. attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
  265. attn_output = self.out_proj(attn_output)
  266. return attn_output, attn_weights
  267. class Ovis2VisionEncoderLayer(GradientCheckpointingLayer):
  268. def __init__(self, config: Ovis2VisionConfig):
  269. super().__init__()
  270. self.attention = Ovis2Attention(config)
  271. self.ffn = Ovis2MLP(config)
  272. self.rms_norm1 = Ovis2RMSNorm(config.hidden_size, config.rms_norm_eps)
  273. self.rms_norm2 = Ovis2RMSNorm(config.hidden_size, config.rms_norm_eps)
  274. def forward(
  275. self,
  276. hidden_states: torch.Tensor,
  277. attention_mask: Optional[torch.Tensor] = None,
  278. **kwargs: Unpack[TransformersKwargs],
  279. ) -> torch.Tensor:
  280. norm_hidden_states = self.rms_norm1(hidden_states)
  281. attn_output, _ = self.attention(hidden_states=norm_hidden_states, attention_mask=attention_mask, **kwargs)
  282. hidden_states = hidden_states + attn_output
  283. norm_hidden_states = self.rms_norm2(hidden_states)
  284. mlp_output = self.ffn(norm_hidden_states)
  285. hidden_states = hidden_states + mlp_output
  286. return hidden_states
  287. class Ovis2VisionEncoder(nn.Module):
  288. """
  289. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  290. [`Ovis2VisionEncoderLayer`].
  291. Args:
  292. config: Ovis2VisionConfig
  293. """
  294. def __init__(self, config: Ovis2VisionConfig):
  295. super().__init__()
  296. self.config = config
  297. self.layers = nn.ModuleList([Ovis2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  298. self.gradient_checkpointing = False
  299. # Ignore copy
  300. @can_return_tuple
  301. @auto_docstring
  302. def forward(
  303. self,
  304. inputs_embeds,
  305. attention_mask: Optional[torch.Tensor] = None,
  306. **kwargs: Unpack[TransformersKwargs],
  307. ) -> BaseModelOutput:
  308. hidden_states = inputs_embeds
  309. for encoder_layer in self.layers:
  310. hidden_states = encoder_layer(hidden_states, attention_mask, **kwargs)
  311. return BaseModelOutput(last_hidden_state=hidden_states)
  312. class Ovis2VisionTransformer(nn.Module):
  313. def __init__(self, config: Ovis2VisionConfig):
  314. super().__init__()
  315. self.config = config
  316. self.embeddings = Ovis2VisionEmbeddings(config)
  317. self.encoder = Ovis2VisionEncoder(config)
  318. self.rms_norm = Ovis2RMSNorm(config.hidden_size, config.rms_norm_eps)
  319. self.gradient_checkpointing = False
  320. @can_return_tuple
  321. def forward(
  322. self,
  323. pixel_values,
  324. attention_mask: Optional[torch.Tensor] = None,
  325. **kwargs,
  326. ):
  327. hidden_states = self.embeddings(pixel_values)
  328. encoder_outputs: BaseModelOutput = self.encoder(
  329. inputs_embeds=hidden_states,
  330. attention_mask=attention_mask,
  331. **kwargs,
  332. )
  333. last_hidden_state = encoder_outputs.last_hidden_state
  334. last_hidden_state = self.rms_norm(last_hidden_state)
  335. return BaseModelOutput(last_hidden_state=last_hidden_state)
  336. class Ovis2VisualEmbeddingTable(nn.Embedding):
  337. def forward(self, visual_tokens: torch.Tensor) -> torch.Tensor:
  338. if visual_tokens.dtype in [torch.int8, torch.int16, torch.int32, torch.int64, torch.long]:
  339. return super().forward(visual_tokens)
  340. return torch.matmul(visual_tokens, self.weight)
  341. class Ovis2PreTrainedModel(PreTrainedModel):
  342. config: Ovis2Config
  343. base_model_prefix = "model"
  344. supports_gradient_checkpointing = True
  345. _no_split_modules = ["Ovis2VisionAttention"]
  346. _skip_keys_device_placement = "past_key_values"
  347. _supports_cache_class = True
  348. _supports_flash_attn = True
  349. _supports_flex_attn = True
  350. _supports_sdpa = True
  351. _can_compile_fullgraph = True
  352. _supports_attention_backend = True
  353. def hard_softmax(logits: torch.Tensor, dim: int):
  354. y_soft = logits.softmax(dim)
  355. # Straight through.
  356. index = y_soft.max(dim, keepdim=True)[1]
  357. y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
  358. ret = y_hard - y_soft.detach() + y_soft
  359. return ret
  360. class Ovis2VisionModel(Ovis2PreTrainedModel):
  361. config: Ovis2VisionConfig
  362. def __init__(self, config: Ovis2VisionConfig):
  363. super().__init__(config)
  364. self.config = config
  365. self.transformer = Ovis2VisionTransformer(config)
  366. self.num_visual_indicator_tokens = config.num_visual_indicator_tokens
  367. self.vocab_size = config.vocab_size
  368. self.head_linear = nn.Linear(
  369. config.hidden_size * config.hidden_stride * config.hidden_stride,
  370. self.vocab_size - self.num_visual_indicator_tokens,
  371. bias=False,
  372. )
  373. self.head_norm = nn.LayerNorm(self.vocab_size - self.num_visual_indicator_tokens)
  374. def forward(self, pixel_values: torch.FloatTensor, **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
  375. outputs = self.transformer(pixel_values, **kwargs)
  376. last_hidden_state = outputs[0]
  377. if self.config.hidden_stride > 1:
  378. num_images, seq_len, hidden_dim = last_hidden_state.shape
  379. hidden_stride = self.config.hidden_stride
  380. sqrt_l = int(math.sqrt(seq_len))
  381. if sqrt_l * sqrt_l != seq_len:
  382. raise ValueError("Token sequence length must be a perfect square")
  383. pad_size = (hidden_stride - (sqrt_l % hidden_stride)) % hidden_stride
  384. last_hidden_state = nn.functional.pad(last_hidden_state, (0, 0, 0, pad_size, 0, pad_size), "constant", 0)
  385. sqrt_l += pad_size
  386. last_hidden_state = last_hidden_state.reshape(
  387. num_images, sqrt_l // hidden_stride, hidden_stride, sqrt_l // hidden_stride, hidden_stride, hidden_dim
  388. )
  389. last_hidden_state = last_hidden_state.permute(0, 1, 3, 2, 4, 5)
  390. last_hidden_state = last_hidden_state.reshape(
  391. num_images, -1, hidden_stride * hidden_stride * hidden_dim
  392. ) # (n, (sqrt_l//hs)^2, hs^2*d)
  393. logits = self.head_linear(last_hidden_state)
  394. logits = self.head_norm(logits)
  395. if self.config.tokenize_function == "gumbel_argmax":
  396. prob_token = nn.functional.gumbel_softmax(logits, dim=-1, hard=True)
  397. elif self.config.tokenize_function == "st_argmax":
  398. prob_token = hard_softmax(logits, dim=-1)
  399. elif self.config.tokenize_function == "softmax":
  400. prob_token = nn.functional.softmax(logits, dim=-1)
  401. return prob_token
  402. @auto_docstring(
  403. custom_intro="""
  404. The Ovis2 model which consists of a vision backbone and a language model, without a language modeling head.
  405. """
  406. )
  407. class Ovis2Model(Ovis2PreTrainedModel):
  408. _checkpoint_conversion_mapping = {}
  409. def __init__(self, config: Ovis2Config):
  410. super().__init__(config)
  411. self.vision_tower = Ovis2VisionModel(config.vision_config)
  412. self.language_model = AutoModel.from_config(config.text_config)
  413. self.visual_embeddings_table = Ovis2VisualEmbeddingTable(config.vision_config.vocab_size, config.hidden_size)
  414. self.visual_vocab_size = config.vision_config.vocab_size
  415. self.vocab_size = config.vocab_size
  416. self.visual_indicator_token_ids = config.visual_indicator_token_ids
  417. self.post_init()
  418. def get_input_embeddings(self):
  419. return self.language_model.get_input_embeddings()
  420. def set_input_embeddings(self, value):
  421. self.language_model.set_input_embeddings(value)
  422. def set_decoder(self, decoder):
  423. self.language_model = decoder
  424. def get_decoder(self):
  425. return self.language_model
  426. def get_image_features(
  427. self,
  428. pixel_values: torch.FloatTensor,
  429. ) -> torch.FloatTensor:
  430. """
  431. Obtains image last hidden states from the vision tower and apply multimodal projection.
  432. Args:
  433. pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
  434. The tensors corresponding to the input images.
  435. vision_feature_layer (`Union[int, list[int]]`, *optional*):
  436. The index of the layer to select the vision feature. If multiple indices are provided,
  437. the vision feature of the corresponding indices will be concatenated to form the
  438. vision features.
  439. vision_feature_select_strategy (`str`, *optional*):
  440. The feature selection strategy used to select the vision feature from the vision backbone.
  441. Can be one of `"default"` or `"full"`
  442. Returns:
  443. image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
  444. """
  445. image_features = self.vision_tower(pixel_values)
  446. batch_size, img_seq_len, _ = image_features.shape
  447. padding_tensor = torch.zeros(
  448. (batch_size, img_seq_len, self.vision_tower.num_visual_indicator_tokens),
  449. dtype=image_features.dtype,
  450. device=image_features.device,
  451. requires_grad=False,
  452. layout=image_features.layout,
  453. )
  454. image_features = torch.cat([image_features, padding_tensor], dim=2)
  455. image_features = self.visual_embeddings_table(image_features)
  456. visual_indicator = torch.arange(
  457. self.visual_vocab_size - self.vision_tower.num_visual_indicator_tokens,
  458. self.visual_vocab_size,
  459. dtype=torch.long,
  460. ).to(image_features.device)
  461. visual_indicator_features = self.visual_embeddings_table(visual_indicator)
  462. return image_features, visual_indicator_features
  463. def get_placeholder_mask(
  464. self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
  465. ):
  466. """
  467. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
  468. equal to the length of multimodal features. If the lengths are different, an error is raised.
  469. """
  470. if input_ids is None:
  471. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  472. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  473. )
  474. special_image_mask = special_image_mask.all(-1)
  475. else:
  476. special_image_mask = input_ids == self.config.image_token_id
  477. n_image_tokens = special_image_mask.sum()
  478. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  479. n_image_features = image_features.shape[0] * image_features.shape[1]
  480. if inputs_embeds[special_image_mask].numel() != image_features.numel():
  481. raise ValueError(
  482. f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
  483. )
  484. return special_image_mask
  485. @can_return_tuple
  486. @auto_docstring
  487. def forward(
  488. self,
  489. input_ids: Optional[torch.LongTensor] = None,
  490. pixel_values: Optional[torch.FloatTensor] = None,
  491. attention_mask: Optional[torch.Tensor] = None,
  492. position_ids: Optional[torch.LongTensor] = None,
  493. past_key_values: Optional[Cache] = None,
  494. inputs_embeds: Optional[torch.FloatTensor] = None,
  495. labels: Optional[torch.LongTensor] = None,
  496. use_cache: Optional[bool] = None,
  497. output_attentions: Optional[bool] = None,
  498. output_hidden_states: Optional[bool] = None,
  499. return_dict: Optional[bool] = None,
  500. cache_position: Optional[torch.LongTensor] = None,
  501. logits_to_keep: Union[int, torch.Tensor] = 0,
  502. **kwargs,
  503. ) -> Union[tuple, Ovis2ModelOutputWithPast]:
  504. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  505. output_hidden_states = (
  506. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  507. )
  508. if (input_ids is None) ^ (inputs_embeds is not None):
  509. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  510. if inputs_embeds is None:
  511. inputs_embeds = self.get_input_embeddings()(input_ids)
  512. if pixel_values is not None:
  513. image_features, visual_indicator_features = self.get_image_features(pixel_values=pixel_values)
  514. special_image_mask = self.get_placeholder_mask(
  515. input_ids,
  516. inputs_embeds=inputs_embeds,
  517. image_features=image_features,
  518. )
  519. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
  520. for i, visual_indicator_id in enumerate(self.visual_indicator_token_ids):
  521. if input_ids is None:
  522. mask = inputs_embeds == self.get_input_embeddings()(
  523. torch.tensor(visual_indicator_id, dtype=torch.long, device=inputs_embeds.device)
  524. )
  525. mask = mask.all(-1)
  526. else:
  527. mask = (input_ids == visual_indicator_id).to(inputs_embeds.device)
  528. if mask.any():
  529. inputs_embeds[mask] = (
  530. visual_indicator_features[i]
  531. .expand_as(inputs_embeds[mask])
  532. .to(inputs_embeds.device, inputs_embeds.dtype)
  533. )
  534. outputs = self.language_model(
  535. attention_mask=attention_mask,
  536. position_ids=position_ids,
  537. past_key_values=past_key_values,
  538. inputs_embeds=inputs_embeds,
  539. use_cache=use_cache,
  540. output_attentions=output_attentions,
  541. output_hidden_states=output_hidden_states,
  542. return_dict=True,
  543. cache_position=cache_position,
  544. logits_to_keep=logits_to_keep,
  545. **kwargs,
  546. )
  547. return Ovis2ModelOutputWithPast(
  548. last_hidden_state=outputs.last_hidden_state,
  549. past_key_values=outputs.past_key_values,
  550. hidden_states=outputs.hidden_states,
  551. attentions=outputs.attentions,
  552. image_hidden_states=image_features if pixel_values is not None else None,
  553. )
  554. @auto_docstring
  555. class Ovis2ForConditionalGeneration(Ovis2PreTrainedModel, GenerationMixin):
  556. _checkpoint_conversion_mapping = {}
  557. _tied_weights_keys = ["lm_head.weight"]
  558. def __init__(self, config: Ovis2Config):
  559. super().__init__(config)
  560. self.model = Ovis2Model(config)
  561. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  562. self.post_init()
  563. def get_input_embeddings(self):
  564. return self.model.get_input_embeddings()
  565. def set_input_embeddings(self, value):
  566. self.model.set_input_embeddings(value)
  567. def get_output_embeddings(self) -> nn.Module:
  568. return self.lm_head
  569. def set_decoder(self, decoder):
  570. self.model.set_decoder(decoder)
  571. def get_decoder(self):
  572. return self.model.get_decoder()
  573. def get_image_features(self, pixel_values: torch.FloatTensor):
  574. return self.model.get_image_features(pixel_values=pixel_values)
  575. # Make modules available through conditional class for BC
  576. @property
  577. def language_model(self):
  578. return self.model.language_model
  579. @property
  580. def vision_tower(self):
  581. return self.model.vision_tower
  582. @property
  583. def multi_modal_projector(self):
  584. raise AttributeError("Not needed for Ovis2")
  585. @can_return_tuple
  586. @auto_docstring
  587. def forward(
  588. self,
  589. input_ids: Optional[torch.LongTensor] = None,
  590. pixel_values: Optional[torch.FloatTensor] = None,
  591. attention_mask: Optional[torch.Tensor] = None,
  592. position_ids: Optional[torch.LongTensor] = None,
  593. past_key_values: Optional[Cache] = None,
  594. inputs_embeds: Optional[torch.FloatTensor] = None,
  595. labels: Optional[torch.LongTensor] = None,
  596. use_cache: Optional[bool] = None,
  597. output_attentions: Optional[bool] = None,
  598. output_hidden_states: Optional[bool] = None,
  599. return_dict: Optional[bool] = None,
  600. cache_position: Optional[torch.LongTensor] = None,
  601. logits_to_keep: Union[int, torch.Tensor] = 0,
  602. **kwargs,
  603. ) -> Union[tuple, Ovis2CausalLMOutputWithPast]:
  604. r"""
  605. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  606. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  607. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  608. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  609. Example:
  610. ```python
  611. >>> from PIL import Image
  612. >>> import requests
  613. >>> from transformers import AutoProcessor, Ovis2ForConditionalGeneration
  614. >>> model = Ovis2ForConditionalGeneration.from_pretrained("thisisiron/Ovis2-2B-hf")
  615. >>> processor = AutoProcessor.from_pretrained("thisisiron/Ovis2-2B-hf")
  616. >>> prompt = "<|im_start|>user\n<image>\nDescribe the image.<|im_end|>\n<|im_start|>assistant\n"
  617. >>> url = "http://images.cocodataset.org/val2014/COCO_val2014_000000537955.jpg"
  618. >>> image = Image.open(requests.get(url, stream=True).raw)
  619. >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
  620. >>> # Generate
  621. >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
  622. >>> processor.batch_decode(generate_ids, skip_special_tokens=True)[0]
  623. "user\n\nDescribe the image.\nassistant\nThe image features a brown dog standing on a wooden floor, looking up with"
  624. ```"""
  625. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  626. output_hidden_states = (
  627. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  628. )
  629. outputs = self.model(
  630. input_ids=input_ids,
  631. pixel_values=pixel_values,
  632. attention_mask=attention_mask,
  633. position_ids=position_ids,
  634. past_key_values=past_key_values,
  635. inputs_embeds=inputs_embeds,
  636. use_cache=use_cache,
  637. output_attentions=output_attentions,
  638. output_hidden_states=output_hidden_states,
  639. return_dict=True,
  640. cache_position=cache_position,
  641. **kwargs,
  642. )
  643. hidden_states = outputs[0]
  644. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  645. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  646. logits = self.lm_head(hidden_states[:, slice_indices, :])
  647. loss = None
  648. if labels is not None:
  649. loss = self.loss_function(
  650. logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
  651. )
  652. return Ovis2CausalLMOutputWithPast(
  653. loss=loss,
  654. logits=logits,
  655. past_key_values=outputs.past_key_values,
  656. hidden_states=outputs.hidden_states,
  657. attentions=outputs.attentions,
  658. image_hidden_states=outputs.image_hidden_states,
  659. )
  660. def prepare_inputs_for_generation(
  661. self,
  662. input_ids,
  663. past_key_values=None,
  664. inputs_embeds=None,
  665. pixel_values=None,
  666. attention_mask=None,
  667. cache_position=None,
  668. logits_to_keep=None,
  669. **kwargs,
  670. ):
  671. # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
  672. model_inputs = super().prepare_inputs_for_generation(
  673. input_ids,
  674. past_key_values=past_key_values,
  675. inputs_embeds=inputs_embeds,
  676. attention_mask=attention_mask,
  677. cache_position=cache_position,
  678. logits_to_keep=logits_to_keep,
  679. **kwargs,
  680. )
  681. if cache_position[0] == 0:
  682. # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
  683. # Otherwise we need pixel values to be passed to model
  684. model_inputs["pixel_values"] = pixel_values
  685. return model_inputs
# Public names exported by `from <this module> import *`.
__all__ = ["Ovis2PreTrainedModel", "Ovis2Model", "Ovis2ForConditionalGeneration"]