generation_csm.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  15. import os
  16. from dataclasses import dataclass
  17. from typing import TYPE_CHECKING, Any, Optional, Union
  18. import torch
  19. import torch.nn as nn
  20. from ...generation import (
  21. GenerateDecoderOnlyOutput,
  22. GenerationConfig,
  23. GenerationMixin,
  24. GenerationMode,
  25. )
  26. from ...generation.logits_process import LogitsProcessorList
  27. from ...generation.stopping_criteria import MaxLengthCriteria, StoppingCriteriaList
  28. from ...generation.utils import GenerateNonBeamOutput
  29. from ...utils import logging
  30. if TYPE_CHECKING:
  31. from ...generation.streamers import BaseStreamer
  32. logger = logging.get_logger(__name__)
  33. @dataclass
  34. class CsmGenerateOutput(GenerateDecoderOnlyOutput):
  35. """
  36. Outputs of CsmForConditionalGeneration.generate.
  37. Args:
  38. sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  39. The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
  40. if all batches finished early due to the `eos_token_id`.
  41. scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
  42. Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
  43. at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
  44. each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
  45. logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
  46. Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
  47. at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
  48. each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
  49. attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
  50. Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
  51. `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
  52. hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
  53. Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
  54. `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
  55. past_key_values (`Cache`, *optional*, returned when `use_cache=True`):
  56. Returns the model cache, used to speed up decoding. Different models have a different cache format, check
  57. audio (`list(torch.FloatTensor)` of length `batch_size`):
  58. The generated audio.
  59. """
  60. audio: Optional[list[torch.Tensor]] = None
  61. class CsmGenerationMixin(GenerationMixin):
  62. def _get_stopping_criteria(
  63. self,
  64. *args,
  65. **kwargs,
  66. ) -> StoppingCriteriaList:
  67. criteria = super()._get_stopping_criteria(*args, **kwargs)
  68. kept_criteria = StoppingCriteriaList()
  69. for criterion in criteria:
  70. if not isinstance(criterion, MaxLengthCriteria):
  71. logger.warning(
  72. f"Csm does not support {criterion.__class__.__name__} stopping criteria, it will be ignored."
  73. )
  74. else:
  75. kept_criteria.append(criterion)
  76. return kept_criteria
  77. def _prepare_generation_config(
  78. self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: Any
  79. ) -> tuple[GenerationConfig, dict]:
  80. """
  81. This method overrides [~generation.utils.GenerationMixin._prepare_generation_config].
  82. It ensures that the depth decoder generation config is initialized and that passed args as depth_decoder_* are properly handled.
  83. """
  84. # extract depth decoder kwargs and remove them from the main kwargs
  85. depth_decoder_kwargs = {
  86. k[len("depth_decoder_") :]: v for k, v in kwargs.items() if k.startswith("depth_decoder_")
  87. }
  88. # remove the depth decoder keys from the original kwargs
  89. kwargs = {k: v for k, v in kwargs.items() if not k.startswith("depth_decoder_")}
  90. # initialize the generation config
  91. generation_config, model_kwargs = super()._prepare_generation_config(
  92. generation_config, use_model_defaults, **kwargs
  93. )
  94. self.depth_decoder.generation_config.update(**depth_decoder_kwargs)
  95. # ensure the depth decoder generation config is valid
  96. depth_decoder_min_new_tokens = getattr(self.depth_decoder.generation_config, "min_new_tokens") or (
  97. self.config.num_codebooks - 1
  98. )
  99. depth_decoder_max_new_tokens = getattr(self.depth_decoder.generation_config, "max_new_tokens") or (
  100. self.config.num_codebooks - 1
  101. )
  102. if {depth_decoder_min_new_tokens, depth_decoder_max_new_tokens} != {self.config.num_codebooks - 1}:
  103. raise ValueError(
  104. f"depth_decoder_generation_config's min_new_tokens ({depth_decoder_min_new_tokens}) and max_new_tokens ({depth_decoder_max_new_tokens}) must be equal to self.config.num_codebooks - 1 ({self.config.num_codebooks - 1})"
  105. )
  106. elif self.depth_decoder.generation_config.return_dict_in_generate:
  107. logger.warning(
  108. "depth_decoder_generation_config.return_dict_in_generate is set to True, but this will be ignored as the depth decoder model does not return a dictionary in generate"
  109. )
  110. self.depth_decoder.generation_config.return_dict_in_generate = False
  111. self.depth_decoder.generation_config.min_new_tokens = depth_decoder_min_new_tokens
  112. self.depth_decoder.generation_config.max_new_tokens = depth_decoder_max_new_tokens
  113. # Monkey patch the get_generation_mode method to support CSM model
  114. original_get_generation_mode = generation_config.get_generation_mode
  115. def patched_get_generation_mode(assistant_model=None):
  116. generation_mode = original_get_generation_mode(assistant_model)
  117. if generation_mode not in [GenerationMode.GREEDY_SEARCH, GenerationMode.SAMPLE]:
  118. raise ValueError(
  119. f"Generation mode {generation_mode} is not supported for CSM model. Please set generation parameters to use greedy or sampling generation."
  120. )
  121. return generation_mode
  122. generation_config.get_generation_mode = patched_get_generation_mode
  123. return generation_config, model_kwargs
  124. def _sample(
  125. self,
  126. input_ids: torch.LongTensor,
  127. logits_processor: LogitsProcessorList,
  128. stopping_criteria: StoppingCriteriaList,
  129. generation_config: GenerationConfig,
  130. synced_gpus: bool = False,
  131. streamer: Optional["BaseStreamer"] = None,
  132. **model_kwargs,
  133. ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
  134. """
  135. This method overrides [~generation.utils.GenerationMixin._sample].
  136. To ease maintenance, modifications are marked with the comment "Csm specific".
  137. Indeed, Csm model requires a custom generation sampling step:
  138. 1. Infer the backbone model to sample the first codebook token
  139. 2. Call generate on the depth decoder with the first codebook token as input_ids to sample the next codebook tokens
  140. 3. Use these generated codebook tokens as input_ids to sample the next first codebook token using the backbone model
  141. 4. Repeat until stopping criteria is met
  142. Csm supports two stopping criteria:
  143. - stop when the generated sequence is at max_length
  144. - stop when all the generated codebook tokens are the codebook_eos_token_id
  145. """
  146. # init values
  147. # *************** Csm specific ***************
  148. pad_token_id = self.config.codebook_pad_token_id
  149. has_eos_stopping_criteria = generation_config._eos_token_tensor is not None
  150. # ============================================
  151. output_attentions = generation_config.output_attentions
  152. output_hidden_states = generation_config.output_hidden_states
  153. output_scores = generation_config.output_scores
  154. output_logits = generation_config.output_logits
  155. return_dict_in_generate = generation_config.return_dict_in_generate
  156. do_sample = generation_config.do_sample
  157. # init attention / hidden states / scores tuples
  158. scores = () if (return_dict_in_generate and output_scores) else None
  159. raw_logits = () if (return_dict_in_generate and output_logits) else None
  160. decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
  161. decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
  162. # keep track of which sequences are already finished
  163. batch_size, cur_len = input_ids.shape[:2]
  164. this_peer_finished = False
  165. unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
  166. model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
  167. # *************** Csm specific ***************
  168. if input_ids.ndim == 2 and model_kwargs.get("inputs_embeds") is None:
  169. # in the case where the passed input_ids correspond to text tokens, i.e. don't have a third dimension for codebook ids,
  170. # we need to remove the input length to the MaxLengthCriteria stopping criteria has such input are not returned
  171. for criterion in stopping_criteria:
  172. if isinstance(criterion, MaxLengthCriteria):
  173. criterion.max_length -= cur_len
  174. # ============================================
  175. model_forward = self.__call__
  176. compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
  177. if compile_forward:
  178. os.environ["TOKENIZERS_PARALLELISM"] = "0"
  179. model_forward = self.get_compiled_call(generation_config.compile_config)
  180. is_prefill = True
  181. while self._has_unfinished_sequences(
  182. this_peer_finished,
  183. synced_gpus,
  184. device=input_ids.device,
  185. ):
  186. # prepare model inputs
  187. model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
  188. # prepare variable output controls (note: some models won't accept all output controls)
  189. model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
  190. # *************** Csm specific ***************
  191. model_inputs.update({"output_hidden_states": True})
  192. # ============================================
  193. if is_prefill:
  194. outputs = self(**model_inputs, return_dict=True)
  195. is_prefill = False
  196. else:
  197. outputs = model_forward(**model_inputs, return_dict=True)
  198. # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
  199. model_kwargs = self._update_model_kwargs_for_generation(
  200. outputs,
  201. model_kwargs,
  202. )
  203. if synced_gpus and this_peer_finished:
  204. continue
  205. # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
  206. # (the clone itself is always small)
  207. next_token_logits = outputs.logits[:, -1, :].clone().float()
  208. next_token_logits = next_token_logits.to(input_ids.device)
  209. # pre-process distribution
  210. next_token_scores = logits_processor(input_ids, next_token_logits)
  211. # Store scores, attentions and hidden_states when required
  212. if return_dict_in_generate:
  213. if output_scores:
  214. scores += (next_token_scores,)
  215. if output_logits:
  216. raw_logits += (next_token_logits,)
  217. if output_attentions:
  218. decoder_attentions += (outputs.attentions,)
  219. if output_hidden_states:
  220. decoder_hidden_states += (outputs.hidden_states,)
  221. # token selection
  222. if do_sample:
  223. probs = nn.functional.softmax(next_token_scores, dim=-1)
  224. # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
  225. next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
  226. else:
  227. next_tokens = torch.argmax(next_token_scores, dim=-1)
  228. # *************** Csm specific ***************
  229. # infer the depth decoder
  230. first_codebook_ids = next_tokens[:, None]
  231. # adds place holder in position 0 that will be replaced by the backbone_last_hidden_state
  232. depth_decoder_input_ids = nn.functional.pad(first_codebook_ids, (1, 0), value=0)
  233. backbone_last_hidden_state = outputs.hidden_states[-1][:, -1, :]
  234. depth_decoder_outputs = self.depth_decoder.generate(
  235. input_ids=depth_decoder_input_ids, backbone_last_hidden_state=backbone_last_hidden_state.clone()
  236. )
  237. codebook_ids = (
  238. depth_decoder_outputs
  239. if isinstance(depth_decoder_outputs, torch.Tensor)
  240. else depth_decoder_outputs.sequences
  241. )
  242. # remove the place holder in position 0
  243. codebook_ids = codebook_ids[:, 1:]
  244. next_tokens = codebook_ids
  245. # finished sentences should have their next token be a padding token
  246. if has_eos_stopping_criteria:
  247. next_tokens = next_tokens * unfinished_sequences.unsqueeze(-1) + pad_token_id * (
  248. 1 - unfinished_sequences.unsqueeze(-1)
  249. )
  250. # update generated ids, model inputs, and length for next step
  251. if input_ids.ndim == 2:
  252. input_ids = next_tokens[:, None, :]
  253. else:
  254. input_ids = torch.cat([input_ids, next_tokens[:, None, :]], dim=1)
  255. # ============================================
  256. if streamer is not None:
  257. streamer.put(next_tokens.cpu())
  258. # *************** Csm specific ***************
  259. # for the eos stopping criteria, is it expected that the eos token is the same for each codebook !!!!
  260. unfinished_sequences = unfinished_sequences & ~(
  261. input_ids[:, -1, :-1] == self.config.codebook_eos_token_id
  262. ).all(-1)
  263. # ============================================
  264. unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
  265. this_peer_finished = unfinished_sequences.max() == 0
  266. cur_len += 1
  267. # This is needed to properly delete outputs.logits which may be very large for first iteration
  268. # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
  269. del outputs
  270. # *************** Csm specific ***************
  271. del depth_decoder_outputs
  272. # ============================================
  273. if streamer is not None:
  274. streamer.end()
  275. if return_dict_in_generate:
  276. return GenerateDecoderOnlyOutput(
  277. sequences=input_ids,
  278. scores=scores,
  279. logits=raw_logits,
  280. attentions=decoder_attentions,
  281. hidden_states=decoder_hidden_states,
  282. past_key_values=model_kwargs.get("past_key_values"),
  283. )
  284. else:
  285. return input_ids
  286. def generate(
  287. self,
  288. input_ids: Optional[torch.Tensor] = None,
  289. input_values: Optional[torch.Tensor] = None,
  290. input_values_cutoffs: Optional[torch.Tensor] = None,
  291. generation_config: Optional[GenerationConfig] = None,
  292. logits_processor: Optional[LogitsProcessorList] = None,
  293. stopping_criteria: Optional[StoppingCriteriaList] = None,
  294. synced_gpus: Optional[bool] = None,
  295. streamer: Optional["BaseStreamer"] = None,
  296. output_audio: Optional[bool] = False,
  297. **kwargs,
  298. ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
  299. r"""
  300. This method overrides [`~generation.utils.GenerationMixin.generate`] to match the specifics of the Csm model.
  301. Indeed, Csm model requires a custom generation sampling step:
  302. 1. Infer the backbone model to sample the first codebook token
  303. 2. Call generate on the depth decoder with the first codebook token as `input_ids` to sample the next codebook tokens
  304. 3. Use these generated codebook tokens as `input_ids` to sample the next first codebook token using the backbone model
  305. 4. Repeat until stopping criteria is met
  306. <Tip warning={true}>
  307. Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
  308. model's default generation configuration. You can override any `generation_config` by passing the corresponding
  309. parameters to generate(), e.g. `.generate(inputs, do_sample=True)`.
  310. </Tip>
  311. Parameters:
  312. inputs_ids (`torch.Tensor` of shape (batch_size, seq_length), *optional*):
  313. The sequence used as a prompt for the backbone model.
  314. input_values (`torch.Tensor` of shape (batch_size, channels, max_concatenated_audio_length), *optional*):
  315. The batched audio input values, where each batch entry contains the concatenation of all audio segments for that entry.
  316. These values will be encoded into codebook tokens using the codec model and merged with the text input ids provided in `input_ids`.
  317. input_values_cutoffs (`torch.Tensor` of shape (batch_size, max_num_audio), *optional*):
  318. Specify the end positions of audio segments within each batch entry, relative to the concatenated audio input.
  319. If a batch entry has fewer segments than the maximum, it is padded with -1. For example, in a batch of 2 sequences
  320. where the first contains 2 audio segments of length l1, and the second contains 1 audio segment of length l2,
  321. the input_values_cutoffs would be: [[l1, 2 * l1], [l2, -1]].
  322. generation_config ([`~generation.GenerationConfig`], *optional*):
  323. The generation configuration to be used as base parametrization for the generation call. `**kwargs`
  324. passed to generate matching the attributes of `generation_config` will override them. If
  325. `generation_config` is not provided, the default will be used, which has the following loading
  326. priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
  327. configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
  328. default values, whose documentation should be checked to parameterize generation.
  329. logits_processor (`LogitsProcessorList`, *optional*):
  330. Custom logits processors that complement the default logits processors built from arguments and
  331. generation config. If a logit processor is passed that is already created with the arguments or a
  332. generation config an error is thrown. This feature is intended for advanced users.
  333. stopping_criteria (`StoppingCriteriaList`, *optional*):
  334. Custom stopping criteria that complements the default stopping criteria built from arguments and a
  335. generation config. If a stopping criteria is passed that is already created with the arguments or a
  336. generation config an error is thrown. If your stopping criteria depends on the `scores` input, make
  337. sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`. This feature is
  338. intended for advanced users.
  339. synced_gpus (`bool`, *optional*):
  340. Whether to continue running the while loop until max_length. Unless overridden, this flag will be set
  341. to `True` if using `FullyShardedDataParallel` or DeepSpeed ZeRO Stage 3 with multiple GPUs to avoid
  342. deadlocking if one GPU finishes generating before other GPUs. Otherwise, defaults to `False`.
  343. streamer (`BaseStreamer`, *optional*):
  344. Streamer object that will be used to stream the generated sequences. Generated tokens are passed
  345. through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
  346. output_audio (`bool`, *optional*):
  347. Whether to return the generated audio.
  348. kwargs (`dict[str, Any]`, *optional*):
  349. Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
  350. forwarded to the `forward` function of the model. Depth decoder specific kwargs should be prefixed with *depth_decoder_*.
  351. Return:
  352. [`CsmGenerateOutput`] or `torch.LongTensor` or `list[torch.FloatTensor]`: A [`CsmGenerateOutput`]
  353. (if `return_dict_in_generate=True` or when `config.return_dict_in_generate=True`) or a `torch.LongTensor` when `output_audio=False`
  354. or a `list[torch.FloatTensor]` otherwise.
  355. Example:
  356. ```python
  357. >>> from transformers import CsmProcessor, CsmForConditionalGeneration
  358. >>> from datasets import load_dataset, Audio
  359. >>> model_id = "sesame/csm-1b"
  360. >>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"
  361. >>> processor = AutoProcessor.from_pretrained(model_id)
  362. >>> ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
  363. >>> # ensure the audio is 24kHz
  364. >>> ds = ds.cast_column("audio", Audio(sampling_rate=24000))
  365. >>> conversation = []
  366. >>> # prepare a conversation with text and corresponding audio
  367. >>> for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
  368. ... conversation.append(
  369. ... {
  370. ... "role": f"{speaker_id}",
  371. ... "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
  372. ... }
  373. ... )
  374. >>> # text prompt
  375. >>> conversation.append({"role": f"{ds[4]['speaker_id']}", "content": [{"type": "text", "text": ds[4]["text"]}]})
  376. >>> inputs = processor.apply_chat_template(
  377. ... conversation,
  378. ... tokenize=True,
  379. ... return_dict=True,
  380. ... ).to(torch_device)
  381. >>> model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
  382. >>> audio = model.generate(**inputs, output_audio=True)
  383. >>> processor.save_audio(audio, "output.wav")
  384. ```
  385. """
  386. generate_output = super().generate(
  387. input_ids=input_ids,
  388. input_values=input_values,
  389. input_values_cutoffs=input_values_cutoffs,
  390. generation_config=generation_config,
  391. logits_processor=logits_processor,
  392. stopping_criteria=stopping_criteria,
  393. synced_gpus=synced_gpus,
  394. streamer=streamer,
  395. **kwargs,
  396. )
  397. generate_returned_dict = not isinstance(generate_output, torch.Tensor)
  398. audio = None
  399. if output_audio:
  400. generated_audio_codes = generate_output.sequences if generate_returned_dict else generate_output
  401. # infer the codec model
  402. audio = []
  403. with torch.no_grad():
  404. # =======================================
  405. # TODO: @eustlb, this should be batched !!!
  406. # but requires making sure batched inference of the codec model works as intended
  407. for audio_codes_batch in generated_audio_codes:
  408. eos_idxs = (audio_codes_batch == self.config.codebook_eos_token_id).all(dim=-1).nonzero()
  409. if eos_idxs.numel() != 0:
  410. cutoff_idx = eos_idxs.min()
  411. else:
  412. cutoff_idx = audio_codes_batch.shape[0]
  413. audio_codes_batch = audio_codes_batch[:cutoff_idx]
  414. codec_decode_output = self.codec_model.decode(audio_codes_batch.transpose(0, 1).unsqueeze(0))
  415. audio.append(codec_decode_output.audio_values[0, 0])
  416. # =======================================
  417. if generate_returned_dict:
  418. return CsmGenerateOutput(audio=audio, **generate_output)
  419. elif output_audio:
  420. return audio
  421. else:
  422. return generate_output