modeling_voxtral.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/voxtral/modular_voxtral.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_voxtral.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. import math
  22. import warnings
  23. from typing import Callable, Optional, Union
  24. import torch
  25. from torch import nn
  26. from ...activations import ACT2FN
  27. from ...cache_utils import Cache
  28. from ...generation import GenerationMixin
  29. from ...modeling_layers import GradientCheckpointingLayer
  30. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast
  31. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  32. from ...processing_utils import Unpack
  33. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  34. from ...utils.generic import check_model_inputs
  35. from ..auto import AutoModel, AutoModelForCausalLM
  36. from .configuration_voxtral import VoxtralConfig, VoxtralEncoderConfig
# Module-level logger named after this module; used below for one-time warnings.
logger = logging.get_logger(__name__)
  38. def eager_attention_forward(
  39. module: nn.Module,
  40. query: torch.Tensor,
  41. key: torch.Tensor,
  42. value: torch.Tensor,
  43. attention_mask: Optional[torch.Tensor],
  44. scaling: Optional[float] = None,
  45. dropout: float = 0.0,
  46. head_mask: Optional[torch.Tensor] = None,
  47. **kwargs,
  48. ):
  49. if scaling is None:
  50. scaling = query.size(-1) ** -0.5
  51. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  52. if attention_mask is not None and attention_mask.ndim == 4:
  53. attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
  54. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  55. if head_mask is not None:
  56. attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
  57. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  58. attn_output = torch.matmul(attn_weights, value)
  59. attn_output = attn_output.transpose(1, 2).contiguous()
  60. return attn_output, attn_weights
  61. class VoxtralAttention(nn.Module):
  62. """Multi-headed attention from 'Attention Is All You Need' paper"""
  63. def __init__(
  64. self,
  65. embed_dim: int,
  66. num_heads: int,
  67. dropout: float = 0.0,
  68. is_decoder: bool = False,
  69. bias: bool = True,
  70. is_causal: bool = False,
  71. layer_idx: Optional[int] = None,
  72. config: Optional[VoxtralConfig] = None,
  73. ):
  74. super().__init__()
  75. self.embed_dim = embed_dim
  76. self.num_heads = num_heads
  77. self.dropout = dropout
  78. self.head_dim = embed_dim // num_heads
  79. self.config = config
  80. if (self.head_dim * num_heads) != self.embed_dim:
  81. raise ValueError(
  82. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  83. f" and `num_heads`: {num_heads})."
  84. )
  85. self.scaling = self.head_dim**-0.5
  86. self.is_decoder = is_decoder
  87. self.is_causal = is_causal
  88. if layer_idx is None and is_decoder:
  89. logger.warning_once(
  90. f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
  91. "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
  92. "when creating this class."
  93. )
  94. self.layer_idx = layer_idx
  95. self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
  96. self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  97. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  98. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  99. def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
  100. return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
  101. def forward(
  102. self,
  103. hidden_states: torch.Tensor,
  104. attention_mask: Optional[torch.Tensor] = None,
  105. layer_head_mask: Optional[torch.Tensor] = None,
  106. output_attentions: bool = False,
  107. **kwargs,
  108. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  109. """Input shape: Batch x Time x Channel"""
  110. bsz, tgt_len, _ = hidden_states.size()
  111. # Scaling is susceptible to floating point arithmetics' inprecisions
  112. # which can lead to different results (this is dependent from model
  113. # to model, e.g. whisper is one such case). We therefore keep the
  114. # original order of scaling to follow the original implementation
  115. # and enforce no scaling (1.0) in the attention call below.
  116. query_states = self._shape(self.q_proj(hidden_states) * self.scaling, tgt_len, bsz)
  117. key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
  118. value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
  119. attention_interface: Callable = eager_attention_forward
  120. if self.config._attn_implementation != "eager":
  121. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  122. attn_output, attn_weights = attention_interface(
  123. self,
  124. query_states,
  125. key_states,
  126. value_states,
  127. attention_mask,
  128. dropout=0.0 if not self.training else self.dropout,
  129. scaling=1.0,
  130. output_attentions=output_attentions,
  131. head_mask=layer_head_mask,
  132. **kwargs,
  133. )
  134. attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
  135. attn_output = self.out_proj(attn_output)
  136. return attn_output, attn_weights
  137. class VoxtralEncoderLayer(GradientCheckpointingLayer):
  138. def __init__(self, config: VoxtralConfig):
  139. super().__init__()
  140. self.embed_dim = config.d_model
  141. self.self_attn = VoxtralAttention(
  142. embed_dim=self.embed_dim,
  143. num_heads=config.encoder_attention_heads,
  144. dropout=config.attention_dropout,
  145. config=config,
  146. )
  147. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  148. self.dropout = config.dropout
  149. self.activation_fn = ACT2FN[config.activation_function]
  150. self.activation_dropout = config.activation_dropout
  151. self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
  152. self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
  153. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  154. def forward(
  155. self,
  156. hidden_states: torch.Tensor,
  157. attention_mask: torch.Tensor,
  158. layer_head_mask: torch.Tensor,
  159. output_attentions: bool = False,
  160. ) -> torch.Tensor:
  161. """
  162. Args:
  163. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  164. attention_mask (`torch.FloatTensor`): attention mask of size
  165. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  166. layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
  167. `(encoder_attention_heads,)`.
  168. output_attentions (`bool`, *optional*):
  169. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  170. returned tensors for more detail.
  171. """
  172. residual = hidden_states
  173. hidden_states = self.self_attn_layer_norm(hidden_states)
  174. hidden_states, attn_weights = self.self_attn(
  175. hidden_states=hidden_states,
  176. attention_mask=attention_mask,
  177. layer_head_mask=layer_head_mask,
  178. output_attentions=output_attentions,
  179. )
  180. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  181. hidden_states = residual + hidden_states
  182. residual = hidden_states
  183. hidden_states = self.final_layer_norm(hidden_states)
  184. hidden_states = self.activation_fn(self.fc1(hidden_states))
  185. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  186. hidden_states = self.fc2(hidden_states)
  187. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  188. hidden_states = residual + hidden_states
  189. if hidden_states.dtype == torch.float16:
  190. clamp_value = torch.finfo(hidden_states.dtype).max - 1000
  191. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  192. return hidden_states, attn_weights
  193. @auto_docstring
  194. class VoxtralPreTrainedModel(PreTrainedModel):
  195. config: VoxtralConfig
  196. base_model_prefix = "model"
  197. supports_gradient_checkpointing = True
  198. _no_split_modules = None
  199. _skip_keys_device_placement = "past_key_values"
  200. _supports_flash_attn = True
  201. _supports_sdpa = True
  202. _supports_flex_attn = True
  203. _supports_cache_class = True
  204. _supports_attention_backend = True
  205. _can_compile_fullgraph = True
  206. def _init_weights(self, module):
  207. # important: this ported version of Voxtral isn't meant for training from scratch - only
  208. # inference and fine-tuning - so the proper init weights code has been removed
  209. std = (
  210. self.config.initializer_range
  211. if hasattr(self.config, "initializer_range")
  212. else self.config.audio_config.initializer_range
  213. )
  214. if isinstance(module, (nn.Linear, nn.Conv1d)):
  215. module.weight.data.normal_(mean=0.0, std=std)
  216. if module.bias is not None:
  217. module.bias.data.zero_()
  218. elif isinstance(module, nn.LayerNorm):
  219. module.weight.data.fill_(1.0)
  220. module.bias.data.zero_()
  221. elif isinstance(module, nn.Embedding):
  222. module.weight.data.normal_(mean=0.0, std=std)
  223. if module.padding_idx is not None:
  224. module.weight.data[module.padding_idx].zero_()
  225. @auto_docstring(
  226. custom_intro="""
  227. The Voxtral encoder, which is a Whisper encoder.
  228. """
  229. )
  230. class VoxtralEncoder(VoxtralPreTrainedModel):
  231. """
  232. Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
  233. [`VoxtralEncoderLayer`].
  234. Args:
  235. config: VoxtralEncoderConfig
  236. """
  237. # Ignore copy
  238. config: VoxtralEncoderConfig
  239. main_input_name = "input_features"
  240. _no_split_modules = ["VoxtralEncoderLayer"]
  241. _can_record_outputs = {
  242. "attentions": VoxtralAttention,
  243. "hidden_states": VoxtralEncoderLayer,
  244. }
  245. def __init__(self, config: VoxtralEncoderConfig):
  246. super().__init__(config)
  247. self.dropout = config.dropout
  248. self.layerdrop = config.encoder_layerdrop
  249. embed_dim = config.d_model
  250. self.num_mel_bins = config.num_mel_bins
  251. self.padding_idx = config.pad_token_id
  252. self.max_source_positions = config.max_source_positions
  253. self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
  254. self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
  255. self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
  256. self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
  257. self.embed_positions.requires_grad_(False)
  258. self.layers = nn.ModuleList([VoxtralEncoderLayer(config) for _ in range(config.encoder_layers)])
  259. self.layer_norm = nn.LayerNorm(config.d_model)
  260. # Ignore copy
  261. self.avg_pooler = nn.AvgPool1d(2, stride=2)
  262. self.gradient_checkpointing = False
  263. # Initialize weights and apply final processing
  264. self.post_init()
  265. def _freeze_parameters(self):
  266. for param in self.parameters():
  267. param.requires_grad = False
  268. self._requires_grad = False
  269. def get_input_embeddings(self) -> nn.Module:
  270. return self.conv1
  271. def set_input_embeddings(self, value: nn.Module):
  272. self.conv1 = value
  273. @check_model_inputs()
  274. def forward(
  275. self,
  276. input_features,
  277. attention_mask=None,
  278. **kwargs: Unpack[TransformersKwargs],
  279. ):
  280. r"""
  281. Args:
  282. input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
  283. Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
  284. obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
  285. `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
  286. `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
  287. and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
  288. attention_mask (`torch.Tensor`)`, *optional*):
  289. Voxtral does not support masking of the `input_features`, this argument is preserved for compatibility,
  290. but it is not used. By default the silence in the input log mel spectrogram are ignored.
  291. """
  292. expected_seq_length = self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0]
  293. if input_features.shape[-1] != expected_seq_length:
  294. raise ValueError(
  295. f"Qwen2Audio expects the mel input features to be of length {expected_seq_length}, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
  296. )
  297. input_features = input_features.to(dtype=self.conv1.weight.dtype, device=self.conv1.weight.device)
  298. inputs_embeds = nn.functional.gelu(self.conv1(input_features))
  299. inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
  300. inputs_embeds = inputs_embeds.permute(0, 2, 1)
  301. embed_pos = self.embed_positions.weight
  302. hidden_states = (inputs_embeds + embed_pos).to(inputs_embeds.dtype)
  303. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  304. for idx, encoder_layer in enumerate(self.layers):
  305. layer_outputs = encoder_layer(
  306. hidden_states,
  307. attention_mask=attention_mask,
  308. layer_head_mask=None,
  309. )
  310. hidden_states = layer_outputs[0]
  311. hidden_states = self.layer_norm(hidden_states)
  312. return BaseModelOutput(
  313. last_hidden_state=hidden_states,
  314. )
  315. # Ignore copy
  316. def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
  317. """
  318. Computes the output length of the convolutional layers and the output length of the audio encoder
  319. """
  320. input_lengths = (input_lengths - 1) // 2 + 1
  321. output_lengths = (input_lengths - 2) // 2 + 1
  322. return input_lengths, output_lengths
  323. class VoxtralMultiModalProjector(nn.Module):
  324. def __init__(self, config: VoxtralConfig):
  325. super().__init__()
  326. self.linear_1 = nn.Linear(config.audio_config.intermediate_size, config.text_config.hidden_size, bias=False)
  327. self.act = ACT2FN[config.projector_hidden_act]
  328. self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=False)
  329. def forward(self, audio_features):
  330. hidden_states = self.linear_1(audio_features)
  331. hidden_states = self.act(hidden_states)
  332. hidden_states = self.linear_2(hidden_states)
  333. return hidden_states
  334. @auto_docstring(
  335. custom_intro="""
  336. The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a LLama language model.
  337. """
  338. )
  339. class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin):
  340. _tied_weights_keys = ["lm_head.weight"]
  341. _tp_plan = {"lm_head": "colwise_rep"}
  342. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  343. _keep_in_fp32_modules_strict = ["embed_positions"]
  344. def __init__(self, config):
  345. super().__init__(config)
  346. self.vocab_size = config.text_config.vocab_size
  347. self.audio_tower = AutoModel.from_config(config.audio_config)
  348. self.language_model = AutoModelForCausalLM.from_config(config.text_config)
  349. self.multi_modal_projector = VoxtralMultiModalProjector(config)
  350. # Initialize weights and apply final processing
  351. self.post_init()
  352. def get_input_embeddings(self):
  353. return self.language_model.get_input_embeddings()
  354. def set_input_embeddings(self, value):
  355. self.language_model.set_input_embeddings(value)
  356. def get_output_embeddings(self):
  357. return self.language_model.get_output_embeddings()
  358. def set_output_embeddings(self, new_embeddings):
  359. self.language_model.set_output_embeddings(new_embeddings)
  360. def set_decoder(self, decoder):
  361. self.language_model.set_decoder(decoder)
  362. def get_decoder(self):
  363. return self.language_model.get_decoder()
  364. def get_audio_features(self, input_features: torch.FloatTensor):
  365. """
  366. This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector.
  367. Args:
  368. input_features (`torch.FloatTensor`):
  369. Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
  370. obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
  371. `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
  372. `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
  373. and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
  374. Returns:
  375. `torch.FloatTensor`:
  376. The audio embeddings.
  377. """
  378. audio_outputs = self.audio_tower(input_features)
  379. audio_hidden_states = audio_outputs.last_hidden_state
  380. audio_hidden_states = audio_hidden_states.reshape(-1, self.config.audio_config.intermediate_size)
  381. audio_embeds = self.multi_modal_projector(audio_hidden_states)
  382. return audio_embeds
  383. def get_audio_embeds(self, input_features: torch.FloatTensor):
  384. warnings.warn(
  385. "The method `get_audio_embeds` is deprecated. Please use `get_audio_features` instead.", FutureWarning
  386. )
  387. return self.get_audio_features(input_features)
  388. @can_return_tuple
  389. @auto_docstring
  390. def forward(
  391. self,
  392. input_ids: Optional[torch.LongTensor] = None,
  393. input_features: Optional[torch.FloatTensor] = None,
  394. attention_mask: Optional[torch.Tensor] = None,
  395. position_ids: Optional[torch.LongTensor] = None,
  396. past_key_values: Optional[Cache] = None,
  397. inputs_embeds: Optional[torch.FloatTensor] = None,
  398. labels: Optional[torch.LongTensor] = None,
  399. use_cache: Optional[bool] = None,
  400. cache_position: Optional[torch.LongTensor] = None,
  401. logits_to_keep: Union[int, torch.Tensor] = 0,
  402. **kwargs: Unpack[TransformersKwargs],
  403. ) -> CausalLMOutputWithPast:
  404. r"""
  405. Example:
  406. ```python
  407. >>> from transformers import VoxtralForConditionalGeneration, AutoProcessor
  408. >>> import torch
  409. >>> device = "cuda" if torch.cuda.is_available() else "cpu"
  410. >>> repo_id = "mistralai/Voxtral-Mini-3B-2507"
  411. >>> processor = AutoProcessor.from_pretrained(repo_id)
  412. >>> model = VoxtralForConditionalGeneration.from_pretrained(repo_id, dtype=torch.bfloat16, device_map=device)
  413. >>> conversation = [
  414. {
  415. "role": "user",
  416. "content": [
  417. {
  418. "type": "audio",
  419. "url": "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/dude_where_is_my_car.wav",
  420. },
  421. {"type": "text", "text": "What can you tell me about this audio?"},
  422. ],
  423. }
  424. ]
  425. >>> inputs = processor.apply_chat_template(conversation)
  426. >>> inputs = inputs.to(device, dtype=torch.bfloat16)
  427. >>> outputs = model.generate(**inputs, max_new_tokens=30)
  428. >>> processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
  429. ["This audio is a humorous conversation between two friends, likely in English, where one of them is trying to figure out what the other's tattoo says."]
  430. ```"""
  431. if inputs_embeds is None:
  432. inputs_embeds = self.get_input_embeddings()(input_ids)
  433. if input_features is not None and input_ids is not None:
  434. audio_embeds = self.get_audio_features(input_features)
  435. # replace text-audio token placeholders with audio embeddings
  436. audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1)
  437. inputs_embeds = inputs_embeds.masked_scatter(
  438. audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device)
  439. )
  440. outputs: BaseModelOutputWithPast = self.language_model(
  441. attention_mask=attention_mask,
  442. position_ids=position_ids,
  443. past_key_values=past_key_values,
  444. inputs_embeds=inputs_embeds,
  445. labels=labels,
  446. use_cache=use_cache,
  447. cache_position=cache_position,
  448. logits_to_keep=logits_to_keep,
  449. **kwargs,
  450. )
  451. return outputs
  452. def prepare_inputs_for_generation(self, *args, **kwargs):
  453. # Overwritten -- we should not pass input_features when we are in cached decoding stage
  454. input_features = kwargs.pop("input_features", None)
  455. cache_position = kwargs.get("cache_position")
  456. model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)
  457. if cache_position is not None and cache_position[0] == 0:
  458. # input_features should only be passed when we are not in cached decoding stage
  459. model_inputs["input_features"] = input_features
  460. return model_inputs
# Public names exported by this module.
__all__ = ["VoxtralPreTrainedModel", "VoxtralEncoder", "VoxtralForConditionalGeneration"]