generation_dia.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463
  1. # coding=utf-8
  2. # Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from typing import Any, Callable, Optional, Union
  16. import torch
  17. import torch.distributed as dist
  18. from ...generation.logits_process import (
  19. DiaClassifierFreeGuidanceLogitsProcessor,
  20. DiaEOSChannelFilterLogitsProcessor,
  21. DiaEOSDelayPatternLogitsProcessor,
  22. LogitsProcessorList,
  23. TemperatureLogitsWarper,
  24. )
  25. from ...generation.stopping_criteria import StoppingCriteriaList
  26. from ...generation.streamers import BaseStreamer
  27. from ...generation.utils import GenerateOutput, GenerationConfig, GenerationMixin, GenerationMode
  28. from ...integrations.deepspeed import is_deepspeed_zero3_enabled
  29. from ...integrations.fsdp import is_fsdp_managed_module
  30. from ...modeling_utils import PreTrainedModel
  31. from ...utils import logging
  32. logger = logging.get_logger(__name__)
class DiaGenerationMixin(GenerationMixin):
    """Generation mixin for Dia.

    Overrides parts of `GenerationMixin` to install Dia-specific logits processors, to
    duplicate inputs for classifier-free guidance (CFG), and to apply the delay-pattern
    mask to the decoder input ids during and after generation.
    """

    # Internal CFG flag: `None` until `_prepare_generation_config` runs, afterwards a bool
    # telling the input-preparation methods whether unconditioned inputs must be added/repeated.
    _uses_cfg = None
  36. def _get_logits_processor(
  37. self,
  38. generation_config: GenerationConfig,
  39. input_ids_seq_length: Optional[int] = None,
  40. encoder_input_ids: Optional[torch.LongTensor] = None,
  41. prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
  42. logits_processor: Optional[LogitsProcessorList] = None,
  43. device: Optional[str] = None,
  44. model_kwargs: Optional[dict[str, Any]] = None,
  45. negative_prompt_ids: Optional[torch.Tensor] = None,
  46. negative_prompt_attention_mask: Optional[torch.Tensor] = None,
  47. ) -> LogitsProcessorList:
  48. # Need either custom order or custom processor instead
  49. # (Temporarily disabling those for the super function)
  50. original_guidance_scale = generation_config.guidance_scale
  51. original_temperature = generation_config.temperature
  52. generation_config.guidance_scale = None
  53. generation_config.temperature = None
  54. # Get base processors and those we can integrate easily
  55. custom_processors = LogitsProcessorList()
  56. if original_temperature is not None and original_temperature != 1.0:
  57. custom_processors.append(TemperatureLogitsWarper(original_temperature))
  58. custom_processors.append(
  59. DiaEOSChannelFilterLogitsProcessor(
  60. num_channels=len(self.config.delay_pattern),
  61. eos_token_id=self.config.eos_token_id,
  62. )
  63. )
  64. merged_processors = super()._get_logits_processor(
  65. generation_config=generation_config,
  66. input_ids_seq_length=input_ids_seq_length,
  67. encoder_input_ids=encoder_input_ids,
  68. prefix_allowed_tokens_fn=None,
  69. logits_processor=custom_processors,
  70. device=device,
  71. model_kwargs=model_kwargs,
  72. negative_prompt_ids=negative_prompt_ids,
  73. negative_prompt_attention_mask=negative_prompt_attention_mask,
  74. )
  75. # Custom processors we need at specific positions
  76. if original_guidance_scale is not None and original_guidance_scale != 1:
  77. cfg_processor = DiaClassifierFreeGuidanceLogitsProcessor(
  78. guidance_scale=original_guidance_scale,
  79. guidance_top_k=generation_config.top_k,
  80. )
  81. merged_processors.insert(0, cfg_processor)
  82. merged_processors.append(
  83. DiaEOSDelayPatternLogitsProcessor(
  84. delay_pattern=self.config.delay_pattern,
  85. eos_token_id=self.config.eos_token_id,
  86. max_generation_len=generation_config.max_length,
  87. device=device,
  88. )
  89. )
  90. # Enable temporarily disabled values back
  91. generation_config.guidance_scale = original_guidance_scale
  92. generation_config.temperature = original_temperature
  93. return merged_processors
  94. def _prepare_generation_config(
  95. self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: Any
  96. ) -> tuple[GenerationConfig, dict]:
  97. generation_config, model_kwargs = super()._prepare_generation_config(
  98. generation_config, use_model_defaults, **kwargs
  99. )
  100. # We allow generation up to max length + max delay pattern
  101. # (will revert back to max length after generation)
  102. generation_config.max_length += max(self.config.delay_pattern)
  103. # Internal flag to indicate CFG that needs to prepare unconditioned input
  104. self._uses_cfg = generation_config.guidance_scale is not None and generation_config.guidance_scale != 1
  105. return generation_config, model_kwargs
  106. def _prepare_model_inputs(
  107. self,
  108. inputs: Optional[torch.Tensor] = None,
  109. bos_token_id: Optional[torch.Tensor] = None,
  110. model_kwargs: Optional[dict[str, torch.Tensor]] = None,
  111. ) -> tuple[torch.Tensor, Optional[str], dict[str, torch.Tensor]]:
  112. inputs, input_name, model_kwargs = super()._prepare_model_inputs(
  113. inputs=inputs,
  114. bos_token_id=bos_token_id,
  115. model_kwargs=model_kwargs,
  116. )
  117. # If CFG is requested we fill in the unconditioned parts
  118. if self._uses_cfg:
  119. unconditioned_inputs = torch.zeros_like(inputs)
  120. inputs = torch.cat([inputs, unconditioned_inputs], dim=0)
  121. if model_kwargs.get("attention_mask", None) is not None:
  122. model_kwargs["attention_mask"] = model_kwargs["attention_mask"].repeat(2, 1)
  123. return inputs, input_name, model_kwargs
  124. def _prepare_decoder_input_ids_for_generation(
  125. self,
  126. batch_size: int,
  127. model_input_name: str,
  128. model_kwargs: dict[str, torch.Tensor],
  129. decoder_start_token_id: torch.Tensor,
  130. device: Optional[torch.device] = None,
  131. ) -> tuple[torch.LongTensor, dict[str, torch.Tensor]]:
  132. """Prepares `decoder_input_ids` for generation with encoder-decoder models"""
  133. # 1. Check whether the user has defined `decoder_input_ids` and `decoder_attention_mask`; if not error out
  134. decoder_input_ids = decoder_attention_mask = None
  135. if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
  136. decoder_input_ids = model_kwargs.pop("decoder_input_ids")
  137. if model_kwargs is not None and "decoder_attention_mask" in model_kwargs:
  138. decoder_attention_mask = model_kwargs.pop("decoder_attention_mask")
  139. # We allow generating without preparation (no proper delay) but discourage it
  140. if decoder_input_ids is None or decoder_attention_mask is None:
  141. logger.warning_once(
  142. "In order to generate with Dia, we need the processed audio input: Got `decoder_input_ids`:"
  143. f" {decoder_input_ids is not None} and got `decoder_attention_mask`={decoder_attention_mask is not None}."
  144. f" This can be achieved via the [`DiaProcessor`] but now defaulting to non-delayed generation."
  145. )
  146. num_channels = self.config.decoder_config.num_channels
  147. real_batch_size = batch_size // 2 if self._uses_cfg else batch_size
  148. if decoder_input_ids is None:
  149. decoder_input_ids = torch.full(
  150. (real_batch_size, 1, num_channels), decoder_start_token_id, dtype=torch.long, device=device
  151. )
  152. decoder_attention_mask = torch.ones(
  153. size=(real_batch_size, decoder_input_ids.shape[1]), dtype=torch.long, device=device
  154. )
  155. # 2. Determine the valid input and what works as mask within the input
  156. delay_mask = decoder_input_ids.long()
  157. valid_input_size = (
  158. decoder_input_ids.shape[1] - (decoder_input_ids[:, :, 0] == self.config.pad_token_id).sum(dim=-1).max()
  159. )
  160. decoder_input_ids = delay_mask[:, :valid_input_size].transpose(1, 2).long()
  161. decoder_attention_mask = decoder_attention_mask[:, :valid_input_size].long()
  162. # 3. Overwrite into model kwargs
  163. model_kwargs["decoder_attention_mask"] = decoder_attention_mask
  164. model_kwargs["decoder_delay_mask"] = delay_mask
  165. return decoder_input_ids, model_kwargs
  166. def prepare_inputs_for_generation(
  167. self,
  168. input_ids,
  169. encoder_outputs=None, # Using this to easily get the batch size
  170. decoder_delay_mask=None,
  171. **kwargs,
  172. ):
  173. # Reshape decoder input_ids to 3D to be compile friendly and to fit the expected model input shape
  174. batch_size = encoder_outputs[0].shape[0] // 2 if self._uses_cfg else encoder_outputs[0].shape[0]
  175. input_ids = input_ids.reshape(batch_size, self.config.decoder_config.num_channels, -1).transpose(1, 2)
  176. # Base method handles most things except CFG and the delay pattern mask
  177. model_inputs = super().prepare_inputs_for_generation(input_ids, encoder_outputs=encoder_outputs, **kwargs)
  178. # Post processing for CFG and overwriting via delay pattern mask
  179. # 1. Delay pattern mask -- force tokens if not allowed to predict (!= pad_token in mask)
  180. model_inputs["decoder_input_ids"] = self.apply_delay_mask(
  181. input_ids, self.config.pad_token_id, decoder_delay_mask
  182. )
  183. # Depending on cache usage we need to pass all or just one
  184. if model_inputs.get("use_cache", False) and model_inputs["cache_position"][0] > 0:
  185. model_inputs["decoder_input_ids"] = model_inputs["decoder_input_ids"][:, -1, :][:, None, :]
  186. # Be compile friendly
  187. model_inputs["decoder_input_ids"] = model_inputs["decoder_input_ids"].contiguous()
  188. # 2. Apply CFG duplication if needed
  189. if self._uses_cfg:
  190. for key in ["decoder_input_ids", "decoder_attention_mask", "decoder_position_ids"]:
  191. if model_inputs.get(key, None) is not None:
  192. # double first dimension and keep everything else the same
  193. repeat_pattern = tuple([2] + [1] * (model_inputs[key].ndim - 1))
  194. model_inputs[key] = model_inputs[key].repeat(*repeat_pattern)
  195. return model_inputs
  196. @staticmethod
  197. def apply_delay_mask(input_ids: torch.Tensor, pad_id: int, delay_mask: Optional[torch.Tensor]) -> torch.Tensor:
  198. if delay_mask is None:
  199. return input_ids
  200. mask_len = min(input_ids.shape[1], delay_mask.shape[1])
  201. valid_mask = delay_mask[:, :mask_len, :]
  202. valid_input = input_ids[:, :mask_len, :]
  203. # Overwrite the respective parts of the input
  204. input_ids[:, :mask_len, :] = torch.where(valid_mask == pad_id, valid_input, valid_mask)
  205. return input_ids
    def _main_generate_loop(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        use_model_defaults: Optional[bool] = None,
        custom_generate: Optional[str] = None,
        **kwargs,
    ):
        """Run the generation preparation steps and dispatch to `_sample`.

        Prepares, in order: the generation config, the model inputs, the encoder outputs,
        the decoder input ids, the length limits, the cache, the logits processors and the
        stopping criteria. The decoder input ids are then flattened to 2D and handed to
        `self._sample`.

        Returns:
            Whatever `self._sample` returns.

        Raises:
            ValueError: If `generation_config.num_return_sequences > 1`, or if the resolved
                generation mode is neither `GenerationMode.SAMPLE` nor
                `GenerationMode.GREEDY_SEARCH`.
        """
        # ********** mostly taken from main generate function up to calling the different methods (see NOTE) **********
        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
        generation_mode_kwargs = self._extract_generation_mode_kwargs(
            custom_generate,
            kwargs,
            synced_gpus,
            assistant_model,
            streamer,
        )
        generation_config, model_kwargs = self._prepare_generation_config(
            generation_config, use_model_defaults, **kwargs
        )
        generation_mode = generation_config.get_generation_mode(assistant_model)

        # A copy is validated so the real `model_kwargs` dict is not touched by the check
        self._validate_model_kwargs(model_kwargs.copy())
        self._validate_generation_mode(generation_mode, generation_config, generation_mode_kwargs)

        # 2. Set generation parameters if not already defined
        # NOTE(review): the local `synced_gpus` resolved here is not read again in this method;
        # `generation_mode_kwargs` above was built from the original argument -- confirm whether
        # this block is still needed.
        if synced_gpus is None:
            synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1

        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

        # 3. Define model inputs
        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
            inputs, generation_config.bos_token_id, model_kwargs
        )
        batch_size = inputs_tensor.shape[0]

        device = inputs_tensor.device
        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)

        # 4. Define other model kwargs
        if "encoder_outputs" not in model_kwargs:
            # if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
                inputs_tensor, model_kwargs, model_input_name, generation_config
            )

        # 5. Prepare `input_ids` which will be used for auto-regressive generation
        input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
            batch_size=batch_size,
            model_input_name=model_input_name,
            model_kwargs=model_kwargs,
            decoder_start_token_id=generation_config._decoder_start_token_tensor,
            device=inputs_tensor.device,
        )

        if generation_config.token_healing:
            input_ids = self.heal_tokens(input_ids, generation_mode_kwargs.get("tokenizer"))

        if streamer is not None:
            streamer.put(input_ids.cpu())

        # 6. Prepare `max_length` depending on other stopping criteria.
        # NOTE: incorrect `input_ids.shape[1]` previously
        # (the sequence length is read from the last dimension of `input_ids` here)
        input_ids_length = input_ids.shape[-1]
        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
        generation_config = self._prepare_generated_length(
            generation_config=generation_config,
            has_default_max_length=has_default_max_length,
            has_default_min_length=has_default_min_length,
            model_input_name=model_input_name,
            inputs_tensor=inputs_tensor,
            input_ids_length=input_ids_length,
        )

        # If the model supports `logits_to_keep` in forward(), set it to 1 to avoid computing the whole
        # logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding
        # dynamically overrides this value as it can need more than the last token logits
        if self._supports_logits_to_keep() and "logits_to_keep" not in model_kwargs:
            model_kwargs["logits_to_keep"] = 1

        self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

        # 7. Prepare the cache.
        # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
        # - different models have a different cache name expected by the model (default = "past_key_values")
        # - `max_length`, prepared above, is used to determine the maximum cache length
        max_cache_length = generation_config.max_length - 1
        if (
            inputs_tensor.shape[1] != input_ids_length
            and model_input_name == "inputs_embeds"
            and not self.config.is_encoder_decoder
        ):
            max_cache_length += inputs_tensor.shape[1]
        self._prepare_cache_for_generation(
            generation_config, model_kwargs, generation_mode, batch_size, max_cache_length
        )

        # 8. prepare logits processors and stopping criteria
        prepared_logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_length,
            encoder_input_ids=inputs_tensor,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
            device=inputs_tensor.device,
            model_kwargs=model_kwargs,
            negative_prompt_ids=negative_prompt_ids,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
        )
        prepared_stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config,
            stopping_criteria=stopping_criteria,
            tokenizer=generation_mode_kwargs.get("tokenizer"),
        )

        # Set model_kwargs `use_cache` so we can use it later in forward runs
        model_kwargs["use_cache"] = generation_config.use_cache
        # ******************* taken from main generate function up to calling the different methods *******************

        # Prepare inner 2D logic in generation loop
        # (all leading dimensions are merged; the last dimension is kept)
        input_ids = input_ids.reshape(-1, input_ids.shape[-1])

        # 10. go into different generation modes
        if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
            # 11. expand input_ids with `num_return_sequences` additional sequences per batch
            # (not supported here: anything above 1 is rejected instead of expanded)
            if generation_config.num_return_sequences > 1:
                raise ValueError("`num_return_sequences>1` is incompatible with Dia.")

            # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
            return self._sample(
                input_ids,
                logits_processor=prepared_logits_processor,
                stopping_criteria=prepared_stopping_criteria,
                generation_config=generation_config,
                **generation_mode_kwargs,
                **model_kwargs,
            )
        else:
            raise ValueError(
                "Got incompatible mode for generation, should be one of greedy or sampling. "
                "Ensure that beam search is de-activated by setting `num_beams=1`."
            )
  342. @torch.no_grad()
  343. def generate(
  344. self,
  345. inputs: Optional[torch.Tensor] = None,
  346. generation_config: Optional[GenerationConfig] = None,
  347. logits_processor: Optional[LogitsProcessorList] = None,
  348. stopping_criteria: Optional[StoppingCriteriaList] = None,
  349. prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
  350. synced_gpus: Optional[bool] = None,
  351. assistant_model: Optional["PreTrainedModel"] = None,
  352. streamer: Optional["BaseStreamer"] = None,
  353. negative_prompt_ids: Optional[torch.Tensor] = None,
  354. negative_prompt_attention_mask: Optional[torch.Tensor] = None,
  355. use_model_defaults: Optional[bool] = None,
  356. custom_generate: Optional[str] = None,
  357. **kwargs,
  358. ) -> Union[GenerateOutput, torch.LongTensor]:
  359. # We expect the initial input ids to be the complete mask (delayed input)
  360. delay_mask = kwargs.get("decoder_input_ids")
  361. if delay_mask is not None:
  362. delay_mask = delay_mask.clone()
  363. output = self._main_generate_loop(
  364. inputs=inputs,
  365. generation_config=generation_config,
  366. logits_processor=logits_processor,
  367. stopping_criteria=stopping_criteria,
  368. prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
  369. synced_gpus=synced_gpus,
  370. assistant_model=assistant_model,
  371. streamer=streamer,
  372. negative_prompt_ids=negative_prompt_ids,
  373. negative_prompt_attention_mask=negative_prompt_attention_mask,
  374. use_model_defaults=use_model_defaults,
  375. custom_generate=custom_generate,
  376. **kwargs,
  377. )
  378. return_dict_in_generate = not isinstance(output, torch.Tensor)
  379. if return_dict_in_generate:
  380. output_sequences = output.sequences
  381. else:
  382. output_sequences = output
  383. # Reshape from 2D (bsz * channels, seq_len) to 3D (bsz, seq_len, channels)
  384. num_channels = self.config.decoder_config.num_channels
  385. bsz = output_sequences.shape[0] // num_channels
  386. output_sequences = output_sequences.reshape(bsz, num_channels, -1).transpose(1, 2)
  387. # Apply delay mask
  388. output_sequences = self.apply_delay_mask(output_sequences, self.config.pad_token_id, delay_mask)
  389. if return_dict_in_generate:
  390. output.sequences = output_sequences
  391. else:
  392. output = output_sequences
  393. return output