text_to_audio.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from typing import Any, Union, overload
  15. from ..generation import GenerationConfig
  16. from ..utils import is_torch_available
  17. from .base import Pipeline
  18. if is_torch_available():
  19. import torch
  20. from ..models.auto.modeling_auto import MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING
  21. from ..models.speecht5.modeling_speecht5 import SpeechT5HifiGan
  22. DEFAULT_VOCODER_ID = "microsoft/speecht5_hifigan"
  23. class TextToAudioPipeline(Pipeline):
  24. """
  25. Text-to-audio generation pipeline using any `AutoModelForTextToWaveform` or `AutoModelForTextToSpectrogram`. This
  26. pipeline generates an audio file from an input text and optional other conditional inputs.
  27. Unless the model you're using explicitly sets these generation parameters in its configuration files
  28. (`generation_config.json`), the following default values will be used:
  29. - max_new_tokens: 256
  30. Example:
  31. ```python
  32. >>> from transformers import pipeline
  33. >>> pipe = pipeline(model="suno/bark-small")
  34. >>> output = pipe("Hey it's HuggingFace on the phone!")
  35. >>> audio = output["audio"]
  36. >>> sampling_rate = output["sampling_rate"]
  37. ```
  38. Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
  39. <Tip>
  40. You can specify parameters passed to the model by using [`TextToAudioPipeline.__call__.forward_params`] or
  41. [`TextToAudioPipeline.__call__.generate_kwargs`].
  42. Example:
  43. ```python
  44. >>> from transformers import pipeline
  45. >>> music_generator = pipeline(task="text-to-audio", model="facebook/musicgen-small", framework="pt")
  46. >>> # diversify the music generation by adding randomness with a high temperature and set a maximum music length
  47. >>> generate_kwargs = {
  48. ... "do_sample": True,
  49. ... "temperature": 0.7,
  50. ... "max_new_tokens": 35,
  51. ... }
  52. >>> outputs = music_generator("Techno music with high melodic riffs", generate_kwargs=generate_kwargs)
  53. ```
  54. </Tip>
  55. This pipeline can currently be loaded from [`pipeline`] using the following task identifiers: `"text-to-speech"` or
  56. `"text-to-audio"`.
  57. See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=text-to-speech).
  58. """
  59. # Introducing the processor at load time for new behaviour
  60. _load_processor = True
  61. _pipeline_calls_generate = True
  62. _load_processor = False
  63. _load_image_processor = False
  64. _load_feature_extractor = False
  65. _load_tokenizer = True
  66. # Make sure the docstring is updated when the default generation config is changed
  67. _default_generation_config = GenerationConfig(
  68. max_new_tokens=256,
  69. )
  70. def __init__(self, *args, vocoder=None, sampling_rate=None, no_processor=True, **kwargs):
  71. super().__init__(*args, **kwargs)
  72. # Legacy behaviour just uses the tokenizer while new models use the processor as a whole at any given time
  73. self.no_processor = no_processor
  74. if self.framework == "tf":
  75. raise ValueError("The TextToAudioPipeline is only available in PyTorch.")
  76. self.vocoder = None
  77. if self.model.__class__ in MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING.values():
  78. self.vocoder = (
  79. SpeechT5HifiGan.from_pretrained(DEFAULT_VOCODER_ID).to(self.model.device)
  80. if vocoder is None
  81. else vocoder
  82. )
  83. self.sampling_rate = sampling_rate
  84. if self.vocoder is not None:
  85. self.sampling_rate = self.vocoder.config.sampling_rate
  86. if self.sampling_rate is None:
  87. # get sampling_rate from config and generation config
  88. config = self.model.config
  89. gen_config = self.model.__dict__.get("generation_config", None)
  90. if gen_config is not None:
  91. config.update(gen_config.to_dict())
  92. for sampling_rate_name in ["sample_rate", "sampling_rate"]:
  93. sampling_rate = getattr(config, sampling_rate_name, None)
  94. if sampling_rate is not None:
  95. self.sampling_rate = sampling_rate
  96. elif getattr(config, "codec_config", None) is not None:
  97. sampling_rate = getattr(config.codec_config, sampling_rate_name, None)
  98. if sampling_rate is not None:
  99. self.sampling_rate = sampling_rate
  100. # last fallback to get the sampling rate based on processor
  101. if self.sampling_rate is None and not self.no_processor and hasattr(self.processor, "feature_extractor"):
  102. self.sampling_rate = self.processor.feature_extractor.sampling_rate
  103. def preprocess(self, text, **kwargs):
  104. if isinstance(text, str):
  105. text = [text]
  106. if self.model.config.model_type == "bark":
  107. # bark Tokenizer is called with BarkProcessor which uses those kwargs
  108. new_kwargs = {
  109. "max_length": self.generation_config.semantic_config.get("max_input_semantic_length", 256),
  110. "add_special_tokens": False,
  111. "return_attention_mask": True,
  112. "return_token_type_ids": False,
  113. "padding": "max_length",
  114. }
  115. # priority is given to kwargs
  116. new_kwargs.update(kwargs)
  117. kwargs = new_kwargs
  118. preprocessor = self.tokenizer if self.no_processor else self.processor
  119. output = preprocessor(text, **kwargs, return_tensors="pt")
  120. return output
  121. def _forward(self, model_inputs, **kwargs):
  122. # we expect some kwargs to be additional tensors which need to be on the right device
  123. kwargs = self._ensure_tensor_on_device(kwargs, device=self.device)
  124. forward_params = kwargs["forward_params"]
  125. generate_kwargs = kwargs["generate_kwargs"]
  126. if self.model.can_generate():
  127. # we expect some kwargs to be additional tensors which need to be on the right device
  128. generate_kwargs = self._ensure_tensor_on_device(generate_kwargs, device=self.device)
  129. # User-defined `generation_config` passed to the pipeline call take precedence
  130. if "generation_config" not in generate_kwargs:
  131. generate_kwargs["generation_config"] = self.generation_config
  132. # generate_kwargs get priority over forward_params
  133. forward_params.update(generate_kwargs)
  134. output = self.model.generate(**model_inputs, **forward_params)
  135. else:
  136. if len(generate_kwargs):
  137. raise ValueError(
  138. "You're using the `TextToAudioPipeline` with a forward-only model, but `generate_kwargs` is non "
  139. "empty. For forward-only TTA models, please use `forward_params` instead of `generate_kwargs`. "
  140. f"For reference, the `generate_kwargs` used here are: {generate_kwargs.keys()}"
  141. )
  142. output = self.model(**model_inputs, **forward_params)[0]
  143. if self.vocoder is not None:
  144. # in that case, the output is a spectrogram that needs to be converted into a waveform
  145. output = self.vocoder(output)
  146. return output
  147. @overload
  148. def __call__(self, text_inputs: str, **forward_params: Any) -> dict[str, Any]: ...
  149. @overload
  150. def __call__(self, text_inputs: list[str], **forward_params: Any) -> list[dict[str, Any]]: ...
  151. def __call__(
  152. self, text_inputs: Union[str, list[str]], **forward_params
  153. ) -> Union[dict[str, Any], list[dict[str, Any]]]:
  154. """
  155. Generates speech/audio from the inputs. See the [`TextToAudioPipeline`] documentation for more information.
  156. Args:
  157. text_inputs (`str` or `list[str]`):
  158. The text(s) to generate.
  159. forward_params (`dict`, *optional*):
  160. Parameters passed to the model generation/forward method. `forward_params` are always passed to the
  161. underlying model.
  162. generate_kwargs (`dict`, *optional*):
  163. The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
  164. complete overview of generate, check the [following
  165. guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation). `generate_kwargs` are
  166. only passed to the underlying model if the latter is a generative model.
  167. Return:
  168. A `dict` or a list of `dict`: The dictionaries have two keys:
  169. - **audio** (`np.ndarray` of shape `(nb_channels, audio_length)`) -- The generated audio waveform.
  170. - **sampling_rate** (`int`) -- The sampling rate of the generated audio waveform.
  171. """
  172. return super().__call__(text_inputs, **forward_params)
  173. def _sanitize_parameters(
  174. self,
  175. preprocess_params=None,
  176. forward_params=None,
  177. generate_kwargs=None,
  178. ):
  179. if getattr(self, "assistant_model", None) is not None:
  180. generate_kwargs["assistant_model"] = self.assistant_model
  181. if getattr(self, "assistant_tokenizer", None) is not None:
  182. generate_kwargs["tokenizer"] = self.tokenizer
  183. generate_kwargs["assistant_tokenizer"] = self.assistant_tokenizer
  184. params = {
  185. "forward_params": forward_params if forward_params else {},
  186. "generate_kwargs": generate_kwargs if generate_kwargs else {},
  187. }
  188. if preprocess_params is None:
  189. preprocess_params = {}
  190. postprocess_params = {}
  191. return preprocess_params, params, postprocess_params
  192. def postprocess(self, audio):
  193. output_dict = {}
  194. if self.model.config.model_type == "csm":
  195. waveform_key = "audio"
  196. else:
  197. waveform_key = "waveform"
  198. # We directly get the waveform
  199. if self.no_processor:
  200. if isinstance(audio, dict):
  201. waveform = audio[waveform_key]
  202. elif isinstance(audio, tuple):
  203. waveform = audio[0]
  204. else:
  205. waveform = audio
  206. # Or we need to postprocess to get the waveform
  207. else:
  208. waveform = self.processor.decode(audio)
  209. if isinstance(audio, list):
  210. output_dict["audio"] = [el.to(device="cpu", dtype=torch.float).numpy() for el in waveform]
  211. else:
  212. output_dict["audio"] = waveform.to(device="cpu", dtype=torch.float).numpy()
  213. output_dict["sampling_rate"] = self.sampling_rate
  214. return output_dict