processing_csm.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373
  1. # coding=utf-8
  2. # Copyright 2025 Sesame and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import math
  16. from pathlib import Path
  17. from typing import Any, Optional, Union
  18. import numpy as np
  19. from ...utils import is_soundfile_available, is_torch_available
  20. if is_torch_available():
  21. import torch
  22. if is_soundfile_available():
  23. import soundfile as sf
  24. from ...audio_utils import AudioInput, make_list_of_audio
  25. from ...feature_extraction_utils import BatchFeature
  26. from ...processing_utils import AudioKwargs, ProcessingKwargs, ProcessorMixin, Unpack
  27. from ...tokenization_utils_base import PreTokenizedInput, TextInput
  28. class CsmAudioKwargs(AudioKwargs, total=False):
  29. encoded_length_kwargs: Optional[dict[str, Any]]
  30. class CsmProcessorKwargs(ProcessingKwargs, total=False):
  31. audio_kwargs: CsmAudioKwargs
  32. _defaults = {
  33. "text_kwargs": {
  34. "padding": True,
  35. "padding_side": "left",
  36. "add_special_tokens": False,
  37. },
  38. "audio_kwargs": {
  39. "encoded_length_kwargs": {
  40. "kernel_sizes": [7, 3, 1, 8, 3, 1, 10, 3, 1, 12, 3, 1, 16, 3, 4],
  41. "strides": [1, 1, 1, 4, 1, 1, 5, 1, 1, 6, 1, 1, 8, 1, 2],
  42. "dilations": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  43. "use_causal_conv": True,
  44. },
  45. "sampling_rate": 24000,
  46. },
  47. "common_kwargs": {"return_tensors": "pt"},
  48. }
  49. class CsmProcessor(ProcessorMixin):
  50. r"""
  51. Constructs a Csm processor which wraps [`EncodecFeatureExtractor`] and
  52. [`PretrainedTokenizerFast`] into a single processor that inherits both the audio feature extraction and
  53. tokenizer functionalities. See the [`~CsmProcessor.__call__`] for more
  54. information.
  55. The preferred way of passing kwargs is as a dictionary per modality, see usage example below.
  56. ```python
  57. from transformers import CsmProcessor
  58. from datasets import load_dataset
  59. ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
  60. audio = ds[0]["audio"]["array"]
  61. processor = CsmProcessor.from_pretrained("sesame/csm-1b")
  62. processor(
  63. text=["<|begin_of_text|>[0]What are you working on?<|end_of_text|><|AUDIO|><|audio_eos|><|begin_of_text|>[1]I'm figuring out my budget.<|end_of_text|>"],
  64. audio=audio,
  65. text_kwargs = {"padding": False},
  66. audio_kwargs = {"sampling_rate": 16000},
  67. common_kwargs = {"return_tensors": "pt"},
  68. )
  69. # this should error out because EncodecFeatureExtractor expects a 24kHz audio :)
  70. ```
  71. Args:
  72. feature_extractor ([`EncodecFeatureExtractor`]):
  73. The feature extractor is a required input.
  74. tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`]):
  75. The tokenizer is a required input.
  76. chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
  77. in a chat into a tokenizable string.
  78. """
  79. attributes = ["feature_extractor", "tokenizer"]
  80. feature_extractor_class = "EncodecFeatureExtractor"
  81. tokenizer_class = "PreTrainedTokenizerFast"
  82. def __init__(
  83. self,
  84. feature_extractor,
  85. tokenizer,
  86. chat_template=None,
  87. ):
  88. if not hasattr(tokenizer, "audio_token"):
  89. self.audio_token = "<|AUDIO|>"
  90. self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)
  91. else:
  92. self.audio_token = tokenizer.audio_token
  93. self.audio_token_id = tokenizer.audio_token_id
  94. if not hasattr(tokenizer, "audio_eos_token"):
  95. self.audio_eos_token = "<|audio_eos|>"
  96. self.audio_eos_token_id = tokenizer.convert_tokens_to_ids(self.audio_eos_token)
  97. else:
  98. self.audio_eos_token = tokenizer.audio_eos_token
  99. self.audio_eos_token_id = tokenizer.audio_eos_token_id
  100. super().__init__(feature_extractor, tokenizer, chat_template=chat_template)
  101. @staticmethod
  102. def _get_encoded_length(audio_length, kernel_sizes=None, strides=None, dilations=None, use_causal_conv=None):
  103. """
  104. Compute the length of the encoded audio sequence.
  105. Args:
  106. audio_length (int): The length of the audio sequence.
  107. kernel_sizes (list[int]): The kernel sizes for the convolutional layers.
  108. strides (list[int]): The strides for the convolutional layers.
  109. use_causal_conv (bool): Whether to use causal convolutions.
  110. """
  111. cur_length = audio_length
  112. if kernel_sizes is None or strides is None or dilations is None or use_causal_conv is None:
  113. return cur_length
  114. for kernel_size, stride, dilation in zip(kernel_sizes, strides, dilations):
  115. effective_kernel_size = (kernel_size - 1) * dilation + 1
  116. padding_total = kernel_size - stride
  117. padding_right = padding_total // 2
  118. padding_left = padding_total - padding_right
  119. n_frames = (cur_length - effective_kernel_size + padding_total) / stride + 1
  120. n_frames = math.ceil(n_frames) - 1
  121. ideal_length = n_frames * stride + kernel_size - padding_total
  122. extra_padding = ideal_length - cur_length
  123. if use_causal_conv:
  124. padding_left = padding_total
  125. padding_right = extra_padding
  126. else:
  127. padding_right = padding_right + extra_padding
  128. cur_length = cur_length + padding_left + padding_right
  129. cur_length = (cur_length - dilation * (kernel_size - 1) - 1) // stride + 1
  130. return cur_length
  131. def save_audio(
  132. self,
  133. audio: AudioInput,
  134. saving_path: Union[str, Path, list[Union[str, Path]]],
  135. **kwargs: Unpack[CsmProcessorKwargs],
  136. ):
  137. # TODO: @eustlb, this should be in AudioProcessor
  138. if not is_soundfile_available():
  139. raise ImportError("Please install `soundfile` to save audio files.")
  140. # ensure correct audio input
  141. audio = make_list_of_audio(audio)
  142. # ensure correct saving path
  143. if isinstance(saving_path, (str, Path)):
  144. saving_path = [saving_path]
  145. elif not (isinstance(saving_path, (list, tuple)) and all(isinstance(p, (str, Path)) for p in saving_path)):
  146. raise ValueError("Invalid input path. Please provide a string, or a list of strings")
  147. if len(audio) != len(saving_path):
  148. raise ValueError("The number of audio and saving paths must be the same")
  149. output_kwargs = self._merge_kwargs(
  150. CsmProcessorKwargs,
  151. **kwargs,
  152. )
  153. audio_kwargs = output_kwargs["audio_kwargs"]
  154. sampling_rate = audio_kwargs["sampling_rate"]
  155. for audio_value, p in zip(audio, saving_path):
  156. if isinstance(audio_value, torch.Tensor):
  157. audio_value = audio_value.cpu().float().numpy()
  158. sf.write(p, audio_value, sampling_rate)
  159. def __call__(
  160. self,
  161. text: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]],
  162. audio: Optional[AudioInput] = None,
  163. output_labels: Optional[bool] = False,
  164. depth_decoder_labels_ratio: Optional[float] = 1.0,
  165. **kwargs: Unpack[CsmProcessorKwargs],
  166. ):
  167. r"""
  168. Main method to prepare text(s) and audio to be fed as input to the model. This method forwards the `text`
  169. arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] to encode
  170. the text. To prepare the audio, this method forwards the `audio` arguments to
  171. EncodecFeatureExtractor's [`~EncodecFeatureExtractor.__call__`]. Please refer
  172. to the docstring of the above two methods for more information.
  173. Args:
  174. audio (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`):
  175. The audio or batch of audio to be prepared. Each audio can be a NumPy array or PyTorch
  176. tensor.
  177. text (`str`, `list[str]`, `list[list[str]]`):
  178. The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
  179. (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
  180. `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
  181. output_labels (bool, *optional*, default=False):
  182. Whether to return labels for training. Indices will be in `[config.audio_token_id, -100, -101]`.
  183. - `config.audio_token_id` indicates an audio frame (considering sequence length elements as frames)
  184. - `-100` will be ignored in the loss computation
  185. - `-101` indicates the audio frame will be used only for the backbone model (using the first codebook token as labels)
  186. depth_decoder_labels_ratio (float, *optional*, default=1.0):
  187. The ratio of audio frames to keep for the depth decoder labels.
  188. return_tensors (`str` or [`~utils.TensorType`], *optional*):
  189. If set, will return tensors of a particular framework. Acceptable values are:
  190. - `'tf'`: Return TensorFlow `tf.constant` objects.
  191. - `'pt'`: Return PyTorch `torch.Tensor` objects.
  192. - `'np'`: Return NumPy `np.ndarray` objects.
  193. - `'jax'`: Return JAX `jnp.ndarray` objects.
  194. Returns:
  195. [`BatchFeature`]: A [`BatchFeature`] with the following fields:
  196. - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
  197. - **input_values** -- List of audio values to be fed to a model. Returned when `audio` is not `None`.
  198. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
  199. `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
  200. `None`).
  201. - **labels** -- List of labels for the audio frames. Returned when `output_labels=True`.
  202. """
  203. output_kwargs = self._merge_kwargs(
  204. CsmProcessorKwargs,
  205. tokenizer_init_kwargs=self.tokenizer.init_kwargs,
  206. **kwargs,
  207. )
  208. text_kwargs = output_kwargs["text_kwargs"]
  209. audio_kwargs = output_kwargs["audio_kwargs"]
  210. common_kwargs = output_kwargs["common_kwargs"]
  211. return_tensors = common_kwargs.pop("return_tensors", None)
  212. if return_tensors != "pt":
  213. raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'`.")
  214. if isinstance(text, str):
  215. text = [text]
  216. elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
  217. raise ValueError("Invalid input text. Please provide a string, or a list of strings")
  218. n_audio_in_text = [t.count(self.audio_token) for t in text]
  219. n_audio = 0
  220. if audio is not None:
  221. audio = make_list_of_audio(audio)
  222. n_audio = len(audio)
  223. if sum(n_audio_in_text) > 0 and n_audio != sum(n_audio_in_text):
  224. if audio is None:
  225. raise ValueError("No audio were provided, but there are audio tokens in the prompt")
  226. else:
  227. raise ValueError(
  228. f"The number of audio tokens in each text ({n_audio_in_text}) should be the same as the "
  229. f"number of provided audios ({n_audio})."
  230. )
  231. if audio is not None:
  232. encoded_length_kwargs = audio_kwargs.pop("encoded_length_kwargs", {})
  233. num_audio_tokens_list = [
  234. self._get_encoded_length(audio_array.shape[-1], **encoded_length_kwargs) for audio_array in audio
  235. ]
  236. num_audio_tokens_list_copy = num_audio_tokens_list.copy()
  237. # expand the text to repeat the audio token for the corresponding number of frames
  238. expanded_text = []
  239. for sample in text:
  240. replace_str = []
  241. while self.audio_token in sample:
  242. num_audio_tokens = num_audio_tokens_list_copy.pop(0)
  243. expanded_audio_token = self.audio_token * num_audio_tokens
  244. replace_str.append(expanded_audio_token)
  245. sample = sample.replace(self.audio_token, "<placeholder>", 1)
  246. while "<placeholder>" in sample:
  247. sample = sample.replace("<placeholder>", replace_str.pop(0), 1)
  248. expanded_text.append(sample)
  249. text = expanded_text
  250. encoding = self.tokenizer(text, **text_kwargs)
  251. data = {}
  252. data.update(encoding)
  253. if audio is not None:
  254. audio_kwargs.pop("return_attention_mask", None) # not supported by the feature extractor
  255. concatenated_audio, input_values_cutoffs = [], []
  256. offset = 0
  257. for n_audio in n_audio_in_text:
  258. if n_audio == 0:
  259. concatenated_audio.append(np.zeros(0))
  260. input_values_cutoffs.append(torch.tensor([-1]))
  261. else:
  262. concatenated_audio.append(
  263. np.concatenate(
  264. [
  265. el.cpu().numpy() if isinstance(el, torch.Tensor) else el
  266. for el in audio[offset : offset + n_audio]
  267. ],
  268. axis=-1,
  269. )
  270. )
  271. input_values_cutoffs.append(
  272. torch.tensor([el.shape[-1] for el in audio[offset : offset + n_audio]]).cumsum(dim=-1)
  273. )
  274. offset += n_audio
  275. audio_inputs = self.feature_extractor(concatenated_audio, **audio_kwargs)
  276. audio_inputs.pop("padding_mask", None) # not applicable here
  277. data.update(audio_inputs)
  278. # pad and stack the audio cut idxs
  279. max_len = max(cut_idxs.shape[-1] for cut_idxs in input_values_cutoffs)
  280. input_values_cutoffs = [
  281. torch.nn.functional.pad(cut_idxs, (0, max_len - cut_idxs.shape[-1]), value=-1)
  282. for cut_idxs in input_values_cutoffs
  283. ]
  284. data["input_values_cutoffs"] = torch.stack(input_values_cutoffs, dim=0)
  285. if output_labels:
  286. audio_frame_idxs = (data["input_ids"] == self.audio_token_id).nonzero()
  287. n_audio_frames = audio_frame_idxs.shape[0]
  288. if depth_decoder_labels_ratio <= 1.0:
  289. rand_idxs = torch.randperm(n_audio_frames)[: int(n_audio_frames * (1 - depth_decoder_labels_ratio))]
  290. skip_frames_idxs = audio_frame_idxs[rand_idxs]
  291. else:
  292. skip_frames_idxs = audio_frame_idxs
  293. labels = torch.where(
  294. (data["input_ids"] == self.audio_token_id) | (data["input_ids"] == self.audio_eos_token_id),
  295. data["input_ids"],
  296. -100,
  297. )
  298. labels[skip_frames_idxs[:, 0], skip_frames_idxs[:, 1]] = -101
  299. data["labels"] = labels
  300. return BatchFeature(data=data, tensor_type=return_tensors)
  301. @property
  302. def model_input_names(self):
  303. tokenizer_input_names = self.tokenizer.model_input_names
  304. feature_extractor_input_names = self.feature_extractor.model_input_names
  305. # Remove `padding_mask`, it is popped and not used when processing. Make a copy of list when removing
  306. # otherwise `self.feature_extractor.model_input_names` is also modified
  307. feature_extractor_input_names = [name for name in feature_extractor_input_names if name != "padding_mask"]
  308. return list(tokenizer_input_names + feature_extractor_input_names + ["input_values_cutoffs"])
# Public names exported by this module.
__all__ = ["CsmProcessor"]