processing_voxtral.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. # coding=utf-8
  2. # Copyright 2025 Sesame and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import io
  16. from typing import Optional, Union
  17. from ...utils import is_mistral_common_available, is_soundfile_available, is_torch_available, logging
  18. if is_torch_available():
  19. import torch
  20. if is_soundfile_available():
  21. import soundfile as sf
  22. if is_mistral_common_available():
  23. from mistral_common.protocol.transcription.request import TranscriptionRequest
  24. from ...audio_utils import AudioInput, load_audio_as, make_list_of_audio
  25. from ...feature_extraction_utils import BatchFeature
  26. from ...processing_utils import AllKwargsForChatTemplate, AudioKwargs, ProcessingKwargs, ProcessorMixin, Unpack
  27. from ...tokenization_utils_base import PreTokenizedInput, TextInput
# Module-level logger used for the warnings emitted by VoxtralProcessor.
logger = logging.get_logger(__name__)
class VoxtralAudioKwargs(AudioKwargs, total=False):
    # Voxtral-specific audio kwarg: number of mel frames per chunk when the extracted input features are split
    # and stacked along the batch dimension (consumed by `VoxtralProcessor._retrieve_input_features`).
    max_source_positions: Optional[int]
  31. class VoxtralProcessorKwargs(ProcessingKwargs, total=False):
  32. _defaults = {
  33. "text_kwargs": {
  34. "padding": True,
  35. },
  36. "audio_kwargs": {
  37. "sampling_rate": 16000,
  38. "padding": True,
  39. "truncation": False,
  40. "pad_to_multiple_of": 480000,
  41. "max_source_positions": 3000,
  42. },
  43. "common_kwargs": {
  44. "return_tensors": "pt",
  45. "return_dict": True,
  46. "tokenize": True,
  47. },
  48. }
  49. class VoxtralProcessor(ProcessorMixin):
  50. r"""
  51. Constructs a Voxtral processor which wraps [`WhisperFeatureExtractor`] and
  52. [`MistralCommonTokenizer`] into a single processor that inherits both the audio feature extraction and
  53. tokenizer functionalities.
  54. Args:
  55. feature_extractor ([`WhisperFeatureExtractor`]):
  56. The feature extractor is a required input.
  57. tokenizer ([`MistralCommonTokenizer`]):
  58. The tokenizer is a required input.
  59. """
  60. attributes = ["feature_extractor", "tokenizer"]
  61. feature_extractor_class = "WhisperFeatureExtractor"
  62. tokenizer_class = "MistralCommonTokenizer"
  63. def __init__(
  64. self,
  65. feature_extractor,
  66. tokenizer,
  67. ):
  68. self.audio_token_id = 24
  69. self.audio_token = tokenizer.convert_ids_to_tokens(self.audio_token_id)
  70. super().__init__(feature_extractor, tokenizer)
  71. def _retrieve_input_features(self, audio, max_source_positions, **kwargs):
  72. """
  73. Handles specific logic of Voxtral expected input features: audio arrays should be padded to next multiple of 480000 (duration is a multiple of 30s), see VoxtralProcessorKwargs' default audio_kwargs.
  74. Then mel input features are extracted and stacked along batch dimension, splitting into chunks of max_source_positions.
  75. """
  76. input_features_list = []
  77. for audio_array in audio:
  78. audio_inputs = self.feature_extractor(audio_array, **kwargs)
  79. # let's split into chunks of max_source_positions, and then stack them along batch dimension
  80. input_features = audio_inputs["input_features"].reshape(
  81. self.feature_extractor.feature_size, -1, max_source_positions
  82. )
  83. input_features_list.append(input_features.transpose(0, 1))
  84. return torch.cat(input_features_list)
  85. def apply_chat_template(
  86. self,
  87. conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]],
  88. **kwargs: Unpack[AllKwargsForChatTemplate],
  89. ) -> str:
  90. """
  91. This method applies the model's chat completion template given a conversation. It relies on MistralCommonTokenizer's
  92. [`~MistralCommonTokenizer.apply_chat_template`] to prepare input ids to the model and on WhisperFeatureExtractor's
  93. [`~WhisperFeatureExtractor.__call__`] to prepare input features to the model.
  94. Note that audio is padded to the nearest 30-second multiple prior to mel feature extraction.
  95. A `conversation` is a list of messages, where each message is a dictionary with a `role` and a `content` field.
  96. For Voxtral, `role` can be `"user"` or `"assistant"`.
  97. The `content` field can be a string or a list of dictionaries with a `type` field. See example below.
  98. ```python
  99. from huggingface_hub import hf_hub_download
  100. from transformers.audio_utils import load_audio_as
  101. audio_url = "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/bcn_weather.mp3"
  102. audio_path = hf_hub_download(repo_id="hf-internal-testing/dummy-audio-samples", filename="bcn_weather.mp3", repo_type="dataset")
  103. audio_base64 = load_audio_as(audio_path, return_format="base64", force_mono=True)
  104. # audio + text
  105. conversation = [
  106. {
  107. "role": "user",
  108. "content": [
  109. {"type": "audio", "url": audio_url},
  110. {"type": "audio", "path": audio_path},
  111. {"type": "audio", "base64": audio_base64},
  112. {"type": "text", "text": "How many audio do you hear?"},
  113. ],
  114. },
  115. ]
  116. processor = VoxtralProcessor.from_pretrained("mistralai/Voxtral-Mini-3B-2507")
  117. inputs = processor.apply_chat_template(conversation)
  118. ```
  119. Args:
  120. conversation (`Union[list[Dict, [str, str]], list[list[dict[str, str]]]]`):
  121. The conversation to format.
  122. """
  123. if kwargs.get("continue_final_message", False):
  124. if kwargs.get("add_generation_prompt", False):
  125. raise ValueError(
  126. "continue_final_message and add_generation_prompt are not compatible. Use continue_final_message when you want the model to continue the final message, and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead."
  127. )
  128. if kwargs.get("return_assistant_tokens_mask", False):
  129. raise ValueError("continue_final_message is not compatible with return_assistant_tokens_mask.")
  130. # Fill sets of kwargs that should be used by different parts of template
  131. processed_kwargs = {
  132. "mm_load_kwargs": {},
  133. "template_kwargs": {},
  134. }
  135. for kwarg_type in processed_kwargs:
  136. for key in AllKwargsForChatTemplate.__annotations__[kwarg_type].__annotations__:
  137. kwarg_type_defaults = AllKwargsForChatTemplate.__annotations__[kwarg_type]
  138. default_value = getattr(kwarg_type_defaults, key, None)
  139. value = kwargs.pop(key, default_value)
  140. if value is not None and not isinstance(value, dict):
  141. processed_kwargs[kwarg_type][key] = value
  142. # Pass unprocessed custom kwargs
  143. processed_kwargs["template_kwargs"].update(kwargs)
  144. if isinstance(conversation, (list, tuple)) and (
  145. isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "content")
  146. ):
  147. is_batched = True
  148. conversations = conversation
  149. else:
  150. is_batched = False
  151. conversations = [conversation]
  152. # Check for any overlapping keys between mm_load_kwargs and kwargs
  153. mm_load_kwargs = processed_kwargs["mm_load_kwargs"]
  154. if any(key in kwargs for key in mm_load_kwargs):
  155. overlapping_keys = [key for key in mm_load_kwargs if key in kwargs]
  156. logger.warning(
  157. f"{overlapping_keys[0] if len(overlapping_keys) == 1 else ', '.join(overlapping_keys)} load multimodal data kwarg{'s' if len(overlapping_keys) > 1 else ''} {'have' if len(overlapping_keys) > 1 else 'has'} been passed to the processor, but {'they are' if len(overlapping_keys) > 1 else 'it is'} not supported for VoxtralProcessor since it relies on mistral_common directly. {'They' if len(overlapping_keys) > 1 else 'It'} will be ignored."
  158. )
  159. output_kwargs = self._merge_kwargs(
  160. VoxtralProcessorKwargs,
  161. **kwargs,
  162. )
  163. text_kwargs = output_kwargs["text_kwargs"]
  164. audio_kwargs = output_kwargs["audio_kwargs"]
  165. common_kwargs = output_kwargs["common_kwargs"]
  166. return_tensors = common_kwargs.pop("return_tensors", None)
  167. if return_tensors != "pt":
  168. raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'`.")
  169. tokenizer_kwargs = {**processed_kwargs["template_kwargs"], **text_kwargs}
  170. tokenizer_kwargs["return_tensors"] = None # let's not return tensors here
  171. tokenize = tokenizer_kwargs.pop("tokenize", False)
  172. return_dict = tokenizer_kwargs.pop("return_dict", False)
  173. encoded_instruct_inputs = self.tokenizer.apply_chat_template(
  174. conversations,
  175. tokenize=tokenize,
  176. return_dict=return_dict,
  177. **tokenizer_kwargs,
  178. )
  179. if tokenize:
  180. if return_dict:
  181. audio = encoded_instruct_inputs.pop("audio", None)
  182. data = dict(encoded_instruct_inputs)
  183. if audio is not None:
  184. max_source_positions = audio_kwargs.pop("max_source_positions")
  185. data["input_features"] = self._retrieve_input_features(audio, max_source_positions, **audio_kwargs)
  186. return BatchFeature(data=data, tensor_type=return_tensors)
  187. if not is_batched:
  188. return encoded_instruct_inputs[0]
  189. return encoded_instruct_inputs
  190. def __call__(
  191. self,
  192. text: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]],
  193. **kwargs: Unpack[VoxtralProcessorKwargs],
  194. ):
  195. r"""
  196. Method to prepare text to be fed as input to the model. This method forwards the `text`
  197. arguments to MistralCommonTokenizer's [`~MistralCommonTokenizer.__call__`] to encode
  198. the text. Please refer to the docstring of the above methods for more information.
  199. This methods does not support audio. To prepare the audio, please use:
  200. 1. `apply_chat_template` [`~VoxtralProcessor.apply_chat_template`] method.
  201. 2. `apply_transcription_request` [`~VoxtralProcessor.apply_transcription_request`] method.
  202. Args:
  203. text (`str`, `list[str]`, `list[list[str]]`):
  204. The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
  205. (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
  206. `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
  207. return_tensors (`str` or [`~utils.TensorType`], *optional*):
  208. If set, will return tensors of a particular framework. Acceptable values are:
  209. - `'tf'`: Return TensorFlow `tf.constant` objects.
  210. - `'pt'`: Return PyTorch `torch.Tensor` objects.
  211. - `'np'`: Return NumPy `np.ndarray` objects.
  212. - `'jax'`: Return JAX `jnp.ndarray` objects.
  213. Returns:
  214. [`BatchFeature`]: A [`BatchFeature`] with the following fields:
  215. - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
  216. - **input_features** -- List of audio values to be fed to a model. Returned when `audio` is not `None`.
  217. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
  218. `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
  219. `None`).
  220. """
  221. if isinstance(text, str):
  222. text = [text]
  223. if any(self.audio_token in t for t in text):
  224. raise ValueError(
  225. f"{self.audio_token} is present in the provided text which is not supported by VoxtralProcessor. Please use the `apply_chat_template` method instead."
  226. )
  227. output_kwargs = self._merge_kwargs(
  228. VoxtralProcessorKwargs,
  229. **kwargs,
  230. )
  231. text_kwargs = output_kwargs["text_kwargs"]
  232. common_kwargs = output_kwargs["common_kwargs"]
  233. out = self.tokenizer(text, **text_kwargs)
  234. return BatchFeature(data=out, tensor_type=common_kwargs.pop("return_tensors", None))
  235. # TODO: @eustlb, this should be moved to mistral_common + testing
  236. def apply_transcription_request(
  237. self,
  238. language: Union[str, list[str]],
  239. audio: Union[str, list[str], AudioInput],
  240. model_id: str,
  241. sampling_rate: Optional[int] = None,
  242. format: Optional[Union[str, list[str]]] = None,
  243. **kwargs: Unpack[VoxtralProcessorKwargs],
  244. ):
  245. """
  246. This method applies the model's transcription request template given a language and audio.
  247. It relies on MistralCommonTokenizer and WhisperFeatureExtractor to prepare input ids and input features to the model.
  248. ```python
  249. from transformers import VoxtralProcessor
  250. model_id = "mistralai/Voxtral-Mini-3B-2507"
  251. processor = VoxtralProcessor.from_pretrained(model_id)
  252. language = "en"
  253. audio = "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/obama.mp3"
  254. inputs = processor.apply_transcription_request(language=language, audio=audio, model_id=model_id)
  255. ```
  256. Args:
  257. language (`str`, `list[str]`):
  258. The language or languages of the audio. If provided as a string, will be applied uniformly to all audio.
  259. If provided as a list, will be applied to each audio individually with a one-to-one mapping.
  260. audio (`str`, `list[str]`, `np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`):
  261. The audio or batch of audio to be prepared. If provided as a string, it should correspond to the path or url of the audio file.
  262. model_id (`str`:
  263. The hub model id of the model to use for transcription.
  264. sampling_rate (`int`, *optional*):
  265. The sampling rate of the audio. Necessary if it is provided as `np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`.
  266. Used to avoid silent errors when passing audio that is not in the expected sampling rate.
  267. format (`str`, `list[str]`, *optional*):
  268. The format of the audio, necessary if is provided as `np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`.
  269. """
  270. output_kwargs = self._merge_kwargs(
  271. VoxtralProcessorKwargs,
  272. **kwargs,
  273. )
  274. text_kwargs = output_kwargs["text_kwargs"]
  275. audio_kwargs = output_kwargs["audio_kwargs"]
  276. common_kwargs = output_kwargs["common_kwargs"]
  277. is_str = isinstance(audio, str)
  278. is_list_of_str = all(isinstance(el, str) for el in audio)
  279. is_list_of_audio = not (is_str or is_list_of_str)
  280. if is_list_of_audio:
  281. if sampling_rate is None:
  282. logger.warning_once(
  283. f"You've provided audio without specifying the sampling rate. It will be assumed to be {audio_kwargs['sampling_rate']}, which can result in silent errors."
  284. )
  285. elif sampling_rate != audio_kwargs["sampling_rate"]:
  286. raise ValueError(
  287. f"The sampling rate of the audio ({sampling_rate}) does not match the sampling rate of the processor ({audio_kwargs['sampling_rate']}). Please provide resampled the audio to the expected sampling rate."
  288. )
  289. sampling_rate = audio_kwargs["sampling_rate"]
  290. return_dict = common_kwargs.pop("return_dict", False)
  291. tokenize = common_kwargs.pop("tokenize", False)
  292. # make sure to remove from text_kwargs and audio_kwargs
  293. for k in ("return_dict", "tokenize"):
  294. text_kwargs.pop(k, None)
  295. audio_kwargs.pop(k, None)
  296. return_tensors = common_kwargs.pop("return_tensors", None)
  297. if return_tensors != "pt":
  298. raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'`.")
  299. # validate audio input
  300. if is_str:
  301. audio = [load_audio_as(audio, return_format="buffer", force_mono=True, sampling_rate=sampling_rate)]
  302. elif is_list_of_str:
  303. audio = [
  304. load_audio_as(el, return_format="buffer", force_mono=True, sampling_rate=sampling_rate) for el in audio
  305. ]
  306. else:
  307. audio = make_list_of_audio(audio)
  308. if len(audio) != len(format):
  309. raise ValueError(
  310. f"When passed as a list of audio, the length ({len(audio)}) must match the number of format ({len(format)})"
  311. )
  312. audio_buffers = []
  313. for array, f in zip(audio, format):
  314. # Create new BytesIO object and write audio data to it
  315. buffer = io.BytesIO()
  316. # Convert to mono if needed
  317. if array.ndim == 2:
  318. array = array.mean(axis=1)
  319. # Write to buffer with default format and sampling rate
  320. sf.write(buffer, array, samplerate=audio_kwargs["sampling_rate"], format=f)
  321. buffer.seek(0)
  322. audio_buffers.append(buffer)
  323. audio = audio_buffers
  324. # validate language input
  325. n_audio = len(audio)
  326. if isinstance(language, str):
  327. language = [language] * n_audio
  328. if len(language) != n_audio:
  329. raise ValueError(
  330. f"When passed as a list of languages, the length ({len(language)}) must match the number of audio ({n_audio})"
  331. )
  332. input_ids = []
  333. texts = []
  334. audio_arrays = []
  335. for audio_el, language_el in zip(audio, language):
  336. openai_transcription_request = {
  337. "model": model_id,
  338. "file": audio_el,
  339. "language": language_el,
  340. }
  341. transcription_request = TranscriptionRequest.from_openai(openai_transcription_request)
  342. tokenized_transcription_request = self.tokenizer.tokenizer.encode_transcription(transcription_request)
  343. input_ids.append(tokenized_transcription_request.tokens)
  344. texts.append(tokenized_transcription_request.text)
  345. audio_arrays.extend([el.audio_array for el in tokenized_transcription_request.audios])
  346. if tokenize:
  347. if return_dict:
  348. # text are already tokenized but we need to pad etc
  349. encoding = self.tokenizer(
  350. input_ids,
  351. add_special_tokens=False,
  352. **text_kwargs,
  353. )
  354. data = dict(encoding)
  355. # extract the input features
  356. max_source_positions = audio_kwargs.pop("max_source_positions")
  357. data["input_features"] = self._retrieve_input_features(
  358. audio_arrays, max_source_positions, **audio_kwargs
  359. )
  360. return BatchFeature(data=data, tensor_type=return_tensors)
  361. return texts
# Public API of this module.
__all__ = ["VoxtralProcessor"]