processing_dia.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474
  1. # coding=utf-8
  2. # Copyright 2025 The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Processor class for Dia"""
  16. import math
  17. from pathlib import Path
  18. from typing import Optional, Union
  19. from ...audio_utils import AudioInput, make_list_of_audio
  20. from ...feature_extraction_utils import BatchFeature
  21. from ...processing_utils import AudioKwargs, ProcessingKwargs, ProcessorMixin, Unpack
  22. from ...utils import is_soundfile_available, is_torch_available
  23. if is_torch_available():
  24. import torch
  25. if is_soundfile_available():
  26. import soundfile as sf
class DiaAudioKwargs(AudioKwargs, total=False):
    """Audio-side keyword arguments understood by `DiaProcessor`, in addition to the base `AudioKwargs`."""

    # Token placed in front of the audio codes and wherever the delay shift reaches before the sequence start.
    bos_token_id: int
    # Token appended after the audio codes when `generation` is False.
    eos_token_id: int
    # Token used to fill positions that are not covered by audio codes.
    pad_token_id: int
    # Per-channel shift along the sequence axis; its length is used as the number of audio channels.
    delay_pattern: list[int]
    # True to prepare inputs for generation; False additionally appends an eos token to the audio codes.
    generation: bool
class DiaProcessorKwargs(ProcessingKwargs, total=False):
    """Keyword arguments for `DiaProcessor`, grouped per modality, together with their default values."""

    audio_kwargs: DiaAudioKwargs
    # NOTE(review): presumably consumed by `ProcessorMixin._merge_kwargs` as the fallback for every
    # value the caller does not pass explicitly -- confirm against `processing_utils`.
    _defaults = {
        "text_kwargs": {
            "padding": True,
            "padding_side": "right",
            "add_special_tokens": False,
        },
        "audio_kwargs": {
            # special audio token ids
            "eos_token_id": 1024,
            "pad_token_id": 1025,
            "bos_token_id": 1026,
            # one delay per audio channel (nine channels); the first channel is not shifted
            "delay_pattern": [0, 8, 9, 10, 11, 12, 13, 14, 15],
            "generation": True,
            "sampling_rate": 44100,
        },
        # `DiaProcessor.__call__` rejects any `return_tensors` value other than "pt"
        "common_kwargs": {"return_tensors": "pt"},
    }
  51. class DiaProcessor(ProcessorMixin):
  52. r"""
  53. Constructs a Dia processor which wraps a [`DiaFeatureExtractor`], [`DiaTokenizer`], and a [`DacModel`] into
  54. a single processor. It inherits the audio feature extraction, tokenizer, and audio encode/decode
  55. functionalities. See [`~DiaProcessor.__call__`], [`~DiaProcessor.batch_decode`], and [`~DiaProcessor.decode`] for more
  56. information.
  57. Args:
  58. feature_extractor (`DiaFeatureExtractor`):
  59. An instance of [`DiaFeatureExtractor`]. The feature extractor is a required input.
  60. tokenizer (`DiaTokenizer`):
  61. An instance of [`DiaTokenizer`]. The tokenizer is a required input.
  62. audio_tokenizer (`DacModel`):
  63. An instance of [`DacModel`] used to encode/decode audio into/from codebooks. It is a required input.
  64. """
  65. feature_extractor_class = "DiaFeatureExtractor"
  66. tokenizer_class = "DiaTokenizer"
  67. audio_tokenizer_class = "DacModel"
    def __init__(self, feature_extractor, tokenizer, audio_tokenizer):
        """
        Hand the three wrapped components over to the parent class.

        Args:
            feature_extractor: The audio feature extractor (passed positionally to the parent).
            tokenizer: The text tokenizer (passed positionally to the parent).
            audio_tokenizer: The audio codec model (passed by keyword to the parent).
        """
        super().__init__(feature_extractor, tokenizer, audio_tokenizer=audio_tokenizer)
    def __call__(
        self,
        text: Union[str, list[str]],
        audio: Optional[AudioInput] = None,
        output_labels: Optional[bool] = False,
        **kwargs: Unpack[DiaProcessorKwargs],
    ):
        """
        Main method to prepare text(s) and audio to be fed as input to the model. The `audio` argument is
        forwarded to the DiaFeatureExtractor's [`~DiaFeatureExtractor.__call__`] and subsequently to the
        DacModel's [`~DacModel.encode`]. The `text` argument is forwarded to [`~DiaTokenizer.__call__`]. Please
        refer to the docstring of the above methods for more information.

        Args:
            text (`str` or `list[str]`):
                The text(s) to tokenize. A single string is treated as a batch of one.
            audio (`AudioInput`, *optional*):
                Audio to encode into codebooks (voice cloning or training). When omitted, only a single
                bos step per sample is prepared, which requires `generation=True`.
            output_labels (`bool`, *optional*, defaults to `False`):
                Whether to additionally return `labels`. Incompatible with `generation=True`.
            **kwargs:
                Overrides for the values in `DiaProcessorKwargs`.

        Returns:
            `BatchFeature` holding the tokenizer outputs plus `decoder_input_ids` of shape
            `(batch_size, seq_len, num_channels)` with the delay pattern applied and
            `decoder_attention_mask` of shape `(batch_size, seq_len)`. With `output_labels`, both are
            shortened by one step and `labels` of shape `(batch_size * num_channels, seq_len - 1)` is added.

        Raises:
            ValueError: If torch is unavailable, `text` is missing or malformed, `return_tensors` is not
                `"pt"`, a required audio kwarg is `None`, `generation` is combined with `output_labels`,
                no audio is given while `generation` is False, or the text and audio batch sizes differ.
        """
        if not is_torch_available():
            raise ValueError(
                "The `DiaProcessor` relies on the `audio_tokenizer` which requires `torch` but we couldn't "
                "find it in your environment. You can install torch via `pip install torch`."
            )

        if text is None:
            raise ValueError("You need to specify the `text` input to process.")

        output_kwargs = self._merge_kwargs(
            DiaProcessorKwargs,
            **kwargs,
        )

        text_kwargs = output_kwargs["text_kwargs"]
        audio_kwargs = output_kwargs["audio_kwargs"]
        common_kwargs = output_kwargs["common_kwargs"]

        return_tensors = common_kwargs.pop("return_tensors", None)
        if return_tensors != "pt":
            raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'`.")

        data = {}

        # Text
        if isinstance(text, str):
            text = [text]
        elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
            raise ValueError("Invalid input text. Please provide a string, or a list of strings")

        encodings = self.tokenizer(text, **text_kwargs)
        data.update(encodings)

        # Audio
        # the Dia-specific entries are popped here so that only the remaining kwargs reach the feature extractor
        delay_pattern = audio_kwargs.pop("delay_pattern", None)
        audio_bos_token_id = audio_kwargs.pop("bos_token_id", None)
        audio_eos_token_id = audio_kwargs.pop("eos_token_id", None)
        audio_pad_token_id = audio_kwargs.pop("pad_token_id", None)
        generation = audio_kwargs.pop("generation", True)
        if (
            audio_bos_token_id is None
            or audio_eos_token_id is None
            or audio_pad_token_id is None
            or delay_pattern is None
        ):
            raise ValueError(
                "To enable processing for Dia, we need the `bos_token_id`, `eos_token_id`, "
                "`pad_token_id`, and `delay_pattern`. You may have accidentally overwritten one of those."
            )

        if generation and output_labels:
            raise ValueError(
                f"Labels with `generation` is incompatible, got generation={generation}, output_labels={output_labels}."
            )

        batch_size = data["input_ids"].shape[0]
        num_channels = len(delay_pattern)
        max_delay = max(delay_pattern)

        # Voice cloning generation / general training
        if audio is not None:
            audio = make_list_of_audio(audio)
            input_audios = self.feature_extractor(audio, **audio_kwargs)

            # number of waveform samples represented by one codebook step
            compression_rate = math.prod(self.audio_tokenizer.config.downsampling_ratios)
            # encoded length of the padded batch, taken from the first sample's padding mask
            max_encoded_sequence_len = input_audios["padding_mask"][0].shape[-1] // compression_rate

            decoder_input_ids = []
            decoder_attention_mask = []
            # TODO: dac with batching is currently broken, but non-batch is working
            # refer to https://gist.github.com/vasqu/643a45b680cf39fd7467271ee2eb6f80 for a validation script
            # NOTE: the loop variable below rebinds the `audio` argument
            for padding_mask, audio in zip(input_audios["padding_mask"], input_audios["input_values"]):
                # get current length with hop length in mind (as if it were sampled as a single audio)
                base_pad_len = self.feature_extractor.hop_length
                current_audio_len = math.ceil(padding_mask.sum(dim=-1) / base_pad_len) * base_pad_len

                encoded_sequence_len = current_audio_len // compression_rate
                padding_len = max_encoded_sequence_len - encoded_sequence_len

                # compute non-padded forward pass; one extra bos (and eos if training) is added
                with torch.no_grad():
                    audio = audio[None, ..., :current_audio_len].to(self.audio_tokenizer.device)
                    # presumably (1, seq_len, num_channels) after the transpose -- TODO confirm against `DacModel.encode`
                    input_ids = self.audio_tokenizer.encode(audio).audio_codes.transpose(1, 2)

                if not generation:
                    # one eos step appended at the end of the second-to-last (sequence) dimension
                    input_ids = torch.nn.functional.pad(
                        input_ids, pad=(0, 0, 0, 1, 0, 0), mode="constant", value=audio_eos_token_id
                    )

                # apply padding
                # +1 for the bos within the real sequence
                # (left side of the sequence dimension: `padding_len` filler steps plus the actual bos, all set to bos)
                input_ids = torch.nn.functional.pad(
                    input_ids, pad=(0, 0, padding_len + 1, 0, 0, 0), mode="constant", value=audio_bos_token_id
                )

                num_valid_inputs = encoded_sequence_len + 1 + max_delay  # sequence + bos + delay
                num_valid_inputs += 0 if generation else 1  # eos if training
                # zeros for the left padding, ones for every valid step
                attention_mask = torch.tensor([0] * padding_len + [1] * num_valid_inputs, dtype=torch.long)[None, :]

                decoder_input_ids.append(input_ids)
                decoder_attention_mask.append(attention_mask)

            decoder_input_ids = torch.cat(decoder_input_ids, dim=0)
            decoder_attention_mask = torch.cat(decoder_attention_mask, dim=0)
        # TTS generation
        elif generation:
            # all bos to start with TTS
            decoder_input_ids = torch.full((batch_size, 1, num_channels), audio_bos_token_id, dtype=torch.long)

            # we preemptively add the delay
            decoder_attention_mask = torch.ones(size=(batch_size, 1 + max_delay), dtype=torch.long)
        else:
            raise ValueError("If you try to train, you should provide audio data as well.")

        if batch_size != decoder_input_ids.shape[0]:
            raise ValueError(
                f"Need the same amount of samples for both text and audio, but got text samples={batch_size} and "
                f"audio samples = {decoder_input_ids.shape[0]} instead."
            )

        # prepare shift indices per delay
        # the attention mask already includes the `max_delay` extra steps, the audio ids do not
        max_seq_len = decoder_attention_mask.shape[-1]
        max_audio_len = max_seq_len - max_delay
        precomputed_idx = self.build_indices(
            bsz=batch_size,
            seq_len=max_seq_len,
            num_channels=num_channels,
            delay_pattern=delay_pattern,
            revert=False,
        )

        # create delay pattern input
        # the pad token will be used for masking which input is valid for prediction during generation
        prefill = torch.full(
            (batch_size, max_seq_len, num_channels),
            fill_value=audio_pad_token_id,
            dtype=torch.int,
        )
        # copy the audio ids into the front; the last `max_delay` steps keep the pad token
        prefill[:, :max_audio_len] = decoder_input_ids

        delayed_decoder_input_ids = self.apply_audio_delay(
            audio=prefill,
            pad_token_id=audio_pad_token_id,
            bos_token_id=audio_bos_token_id,
            precomputed_idx=precomputed_idx,
        )

        data.update({"decoder_input_ids": delayed_decoder_input_ids, "decoder_attention_mask": decoder_attention_mask})

        if output_labels:
            # Base idea is to shift on the sequence dim
            labels = data["decoder_input_ids"].clone()[:, 1:]
            # pad and bos positions are replaced by -100 in the labels
            labels[labels == audio_pad_token_id] = -100
            labels[labels == audio_bos_token_id] = -100

            # channels are folded into the batch dimension: (batch_size * num_channels, seq_len - 1)
            data["labels"] = labels.transpose(1, 2).reshape(batch_size * num_channels, -1).contiguous().long()
            # drop the last step so that inputs and labels have the same length
            data["decoder_input_ids"] = data["decoder_input_ids"][:, :-1]
            data["decoder_attention_mask"] = data["decoder_attention_mask"][:, :-1]

        return BatchFeature(data=data, tensor_type=return_tensors)
  214. def batch_decode(
  215. self,
  216. decoder_input_ids: "torch.Tensor",
  217. audio_prompt_len: Optional[int] = None,
  218. **kwargs: Unpack[DiaProcessorKwargs],
  219. ) -> list["torch.Tensor"]:
  220. """
  221. Decodes a batch of audio codebook sequences into their respective audio waveforms via the
  222. `audio_tokenizer`. See [`~DacModel.decode`] for more information.
  223. Args:
  224. decoder_input_ids (`torch.Tensor`): The complete output sequence of the decoder.
  225. audio_prompt_len (`int`): The audio prefix length (e.g. when using voice cloning).
  226. """
  227. output_kwargs = self._merge_kwargs(
  228. DiaProcessorKwargs,
  229. **kwargs,
  230. )
  231. audio_kwargs = output_kwargs["audio_kwargs"]
  232. delay_pattern = audio_kwargs.pop("delay_pattern", None)
  233. audio_bos_token_id = audio_kwargs.pop("bos_token_id", None)
  234. audio_pad_token_id = audio_kwargs.pop("pad_token_id", None)
  235. if audio_bos_token_id is None or audio_pad_token_id is None or delay_pattern is None:
  236. raise ValueError(
  237. "To enable decoding for Dia, we need the `bos_token_id`, `pad_token_id`, "
  238. "and `delay_pattern`. You may have accidentally overwritten one of those."
  239. )
  240. # either decode the whole audio sequence or only the generated parts
  241. if audio_prompt_len is not None:
  242. audio_prompt_len = torch.tensor(audio_prompt_len, device=decoder_input_ids.device, dtype=torch.long)
  243. start_of_generation_idx = audio_prompt_len[None].expand(decoder_input_ids.shape[0])
  244. else:
  245. start_of_generation_idx = (decoder_input_ids[:, :, 0] == audio_bos_token_id).sum(dim=-1)
  246. # -1 for the eos token
  247. end_of_generation_idx = (
  248. decoder_input_ids.shape[1] - (decoder_input_ids[:, :, 0] == audio_pad_token_id).sum(dim=-1) - 1
  249. )
  250. # revert delay
  251. bsz, seq_len, num_channels = decoder_input_ids.shape
  252. precomputed_idx = self.build_indices(
  253. bsz=bsz,
  254. seq_len=seq_len,
  255. num_channels=num_channels,
  256. delay_pattern=delay_pattern,
  257. revert=True,
  258. )
  259. output_sequences = self.apply_audio_delay(
  260. audio=decoder_input_ids,
  261. # We do not care about these values as we cut them out
  262. # with `start_of_generation_idx` and `end_of_generation_idx`
  263. pad_token_id=-1,
  264. bos_token_id=-1,
  265. precomputed_idx=precomputed_idx,
  266. ).transpose(1, 2)
  267. # retrieve the correct sequences each
  268. audios = []
  269. # TODO: see above, dac doesn't work in batches yet
  270. with torch.no_grad():
  271. for i in range(start_of_generation_idx.shape[0]):
  272. output_i = output_sequences[i, :, start_of_generation_idx[i] : end_of_generation_idx[i]][None, ...]
  273. output_i = output_i.to(self.audio_tokenizer.device)
  274. audio_i = self.audio_tokenizer.decode(audio_codes=output_i).audio_values.cpu().squeeze()
  275. audios.append(audio_i)
  276. return audios
  277. def decode(
  278. self,
  279. decoder_input_ids: "torch.Tensor",
  280. audio_prompt_len: Optional[int] = None,
  281. **kwargs: Unpack[DiaProcessorKwargs],
  282. ) -> "torch.Tensor":
  283. """
  284. Decodes a single sequence of audio codebooks into the respective audio waveform via the
  285. `audio_tokenizer`. See [`~DacModel.decode`] and [`~DiaProcessor.batch_decode`] for more information.
  286. """
  287. if decoder_input_ids.shape[0] != 1:
  288. raise ValueError(
  289. f"Expecting a single output to be decoded but received {decoder_input_ids.shape[0]} samples instead."
  290. )
  291. return self.batch_decode(decoder_input_ids, audio_prompt_len, **kwargs)[0]
  292. def get_audio_prompt_len(
  293. self,
  294. decoder_attention_mask: "torch.Tensor",
  295. **kwargs: Unpack[DiaProcessorKwargs],
  296. ) -> int:
  297. """Utility function to get the audio prompt length."""
  298. output_kwargs = self._merge_kwargs(
  299. DiaProcessorKwargs,
  300. **kwargs,
  301. )
  302. audio_kwargs = output_kwargs["audio_kwargs"]
  303. delay_pattern = audio_kwargs.pop("delay_pattern", None)
  304. if delay_pattern is None:
  305. raise ValueError(
  306. "To enable the utility of retrieving the prompt length for Dia, we need the "
  307. "`delay_pattern`. You may have accidentally overwritten this."
  308. )
  309. return decoder_attention_mask.shape[1] - max(delay_pattern)
  310. # Copied from transformers.models.csm.processing_csm.CsmProcessor.save_audio with Csm->Dia
  311. def save_audio(
  312. self,
  313. audio: AudioInput,
  314. saving_path: Union[str, Path, list[Union[str, Path]]],
  315. **kwargs: Unpack[DiaProcessorKwargs],
  316. ):
  317. # TODO: @eustlb, this should be in AudioProcessor
  318. if not is_soundfile_available():
  319. raise ImportError("Please install `soundfile` to save audio files.")
  320. # ensure correct audio input
  321. audio = make_list_of_audio(audio)
  322. # ensure correct saving path
  323. if isinstance(saving_path, (str, Path)):
  324. saving_path = [saving_path]
  325. elif not (isinstance(saving_path, (list, tuple)) and all(isinstance(p, (str, Path)) for p in saving_path)):
  326. raise ValueError("Invalid input path. Please provide a string, or a list of strings")
  327. if len(audio) != len(saving_path):
  328. raise ValueError("The number of audio and saving paths must be the same")
  329. output_kwargs = self._merge_kwargs(
  330. DiaProcessorKwargs,
  331. **kwargs,
  332. )
  333. audio_kwargs = output_kwargs["audio_kwargs"]
  334. sampling_rate = audio_kwargs["sampling_rate"]
  335. for audio_value, p in zip(audio, saving_path):
  336. if isinstance(audio_value, torch.Tensor):
  337. audio_value = audio_value.cpu().float().numpy()
  338. sf.write(p, audio_value, sampling_rate)
  339. @staticmethod
  340. def build_indices(
  341. bsz: int,
  342. seq_len: int,
  343. num_channels: int,
  344. delay_pattern: list[int],
  345. revert: bool = False,
  346. ) -> tuple["torch.Tensor", "torch.Tensor"]:
  347. """
  348. Precompute (sequence_idx, all_idx) so that out[seq, channel] = in[seq - delay[channel], channel]
  349. or in[seq, channel] = out[seq + delay[channel], channel] if `revert`.
  350. Negative sequence_idx => BOS; sequence_idx >= seq_len => PAD.
  351. """
  352. delay_array = torch.tensor(delay_pattern, dtype=torch.int32)
  353. # (0..seq_len-1)
  354. sequence_idx = torch.arange(seq_len, dtype=torch.int32)[None, :].expand(bsz, seq_len)[..., None]
  355. # + or - delay depending if we delay or revert the delay
  356. if not revert:
  357. sequence_idx = sequence_idx - delay_array[None, None, :]
  358. else:
  359. sequence_idx = sequence_idx + delay_array[None, None, :]
  360. # if delay goes over the range we clamp back to valid values
  361. valid_sequence_idx = torch.clamp(sequence_idx, 0, seq_len - 1)
  362. batch_idx = torch.arange(bsz, dtype=torch.int32)[:, None, None].expand(bsz, seq_len, num_channels)
  363. channel_idx = torch.arange(num_channels, dtype=torch.int32)[None, None, :].expand(bsz, seq_len, num_channels)
  364. all_idx = torch.stack(
  365. [batch_idx.reshape(-1), valid_sequence_idx.reshape(-1), channel_idx.reshape(-1)],
  366. dim=1,
  367. ).long()
  368. return sequence_idx, all_idx
  369. @staticmethod
  370. def apply_audio_delay(
  371. audio: "torch.Tensor",
  372. pad_token_id: int,
  373. bos_token_id: int,
  374. precomputed_idx: tuple["torch.Tensor", "torch.Tensor"],
  375. ) -> "torch.Tensor":
  376. """
  377. Applies or reverts the delay pattern to batched audio tokens using precomputed indices,
  378. inserting BOS where sequence_idx < 0 and PAD where sequence_idx >= seq_len.
  379. Args:
  380. audio: audio tokens of shape [bsz, seq_len, num_channels]
  381. pad_token_id: the PAD token
  382. bos_token_id: the BOS token
  383. precomputed_idx: from `build_indices`
  384. Returns:
  385. final_audio: delayed or reverted audio tokens of shape [bsz, seq_len, num_channels]
  386. """
  387. # Move everything to the same device
  388. device = audio.device
  389. sequence_idx, all_idx = precomputed_idx
  390. sequence_idx = sequence_idx.to(device)
  391. all_idx = all_idx.to(device)
  392. # Gather per precomputed indices
  393. batch_idx, valid_sequence_idx, channel_idx = torch.unbind(all_idx, dim=-1)
  394. gathered_audio = audio[batch_idx, valid_sequence_idx, channel_idx].view(audio.size())
  395. # Mask according to negative sequence_idx => BOS; sequence_idx >= seq_len => PAD
  396. mask_bos = sequence_idx < 0
  397. mask_pad = sequence_idx >= audio.shape[1]
  398. final_audio = torch.where(mask_bos, bos_token_id, torch.where(mask_pad, pad_token_id, gathered_audio))
  399. return final_audio
  400. __all__ = ["DiaProcessor"]