modular_csm.py 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766
  1. # coding=utf-8
  2. # Copyright 2025 Sesame and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from dataclasses import dataclass
  16. from typing import Optional, Union
  17. import torch
  18. import torch.nn as nn
  19. from transformers.utils.generic import check_model_inputs
  20. from ...cache_utils import Cache, DynamicCache
  21. from ...generation import GenerationMixin
  22. from ...masking_utils import create_causal_mask
  23. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  24. from ...modeling_utils import PreTrainedModel
  25. from ...processing_utils import Unpack
  26. from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
  27. from ..auto import AutoModel
  28. from ..llama.modeling_llama import (
  29. LlamaAttention,
  30. LlamaDecoderLayer,
  31. LlamaForCausalLM,
  32. LlamaMLP,
  33. LlamaModel,
  34. LlamaRMSNorm,
  35. LlamaRotaryEmbedding,
  36. TransformersKwargs,
  37. )
  38. from .configuration_csm import CsmConfig, CsmDepthDecoderConfig
  39. from .generation_csm import CsmGenerationMixin
# Module-level logger; used below for the depth decoder's position_ids / backbone_last_hidden_state warnings.
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for the model autoregressive outputs.
    """
)
class CsmOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    depth_decoder_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction) of the depth decoder model.
    depth_decoder_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the depth decoder (scores for each vocabulary token before SoftMax).
    depth_decoder_past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
    depth_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
        one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

        Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    depth_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.
    backbone_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction) of the backbone model.
    """

    # NOTE(review): field order presumably defines the positional/tuple order of this ModelOutput;
    # do not reorder fields without checking callers that index outputs by position.

    # Backbone-model outputs (combined `loss` plus the backbone's logits / cache / hidden states / attentions).
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    # Depth-decoder outputs, mirrored with a `depth_decoder_` prefix.
    depth_decoder_loss: Optional[torch.FloatTensor] = None
    depth_decoder_logits: Optional[torch.FloatTensor] = None
    depth_decoder_past_key_values: Optional[Cache] = None
    depth_decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    depth_decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    # Backbone-only loss, reported separately from the combined `loss`.
    backbone_loss: Optional[torch.FloatTensor] = None
# manually specify names for correct naming when converting from modular
# Each class below is a plain alias of its Llama counterpart (no overrides). The Csm-prefixed names are
# what the rest of this file refers to (e.g. `_no_split_modules` and `_can_record_outputs` in
# `CsmPreTrainedModel`).
class CsmRMSNorm(LlamaRMSNorm):
    pass


class CsmRotaryEmbedding(LlamaRotaryEmbedding):
    pass


class CsmMLP(LlamaMLP):
    pass


class CsmAttention(LlamaAttention):
    pass


class CsmDecoderLayer(LlamaDecoderLayer):
    pass
@auto_docstring(
    custom_intro="""
    The bare Csm Model outputting raw hidden-states without any specific head on top.
    """
)
# NOTE(review): `auto_docstring` is applied twice here; the inner bare decorator looks redundant — confirm
# which generated docstring is intended before removing either one.
@auto_docstring
class CsmPreTrainedModel(PreTrainedModel):
    # Shared base for the CSM models: declares capability flags and extends weight initialization
    # to `CsmCodebooksHead`.
    config: CsmConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["CsmDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    # flex attention is not supported because of the Mimi codec model
    # _supports_flex_attn = True
    _can_compile_fullgraph = True
    _supports_attention_backend = True
    # Maps each recordable output name to the module class whose outputs are recorded for it.
    _can_record_outputs = {
        "hidden_states": CsmDecoderLayer,
        "attentions": CsmAttention,
    }

    def _init_weights(self, module):
        """Initialize `module`'s weights: parent initialization plus a normal init for `CsmCodebooksHead`.

        Args:
            module: the submodule being initialized.
        """
        super()._init_weights(module)
        if isinstance(module, CsmCodebooksHead):
            num_codebooks = module.num_codebooks
            # `CsmCodebooksHead.weight` stacks `num_codebooks - 1` matrices, so this loop draws every one of them
            # from N(0, initializer_range).
            for i in range(num_codebooks - 1):
                module.weight.data[i].normal_(mean=0.0, std=self.config.initializer_range)
@auto_docstring
class CsmDepthDecoderModel(LlamaModel, CsmPreTrainedModel):
    # Llama-style decoder stack used as CSM's depth decoder: it predicts the codebook tokens that follow the
    # first one (the first codebook token comes from the backbone model).
    config: CsmDepthDecoderConfig

    def __init__(self, config):
        super().__init__(config)
        # One table shared by all codebooks: codebook `k` uses rows `[k * vocab_size, (k + 1) * vocab_size)`
        # (see the offset computed in `forward`). Its width is the *backbone* hidden size so that the backbone's
        # last hidden state can be written directly into position 0 of the embeddings.
        self.embed_tokens = nn.Embedding((config.num_codebooks * config.vocab_size), config.backbone_hidden_size)
        # Maps the backbone-sized embeddings to this decoder's own `hidden_size`.
        self.inputs_embeds_projector = nn.Linear(config.backbone_hidden_size, config.hidden_size, bias=False)

    @check_model_inputs()
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        backbone_last_hidden_state: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, BaseModelOutputWithPast]:
        r"""
        backbone_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, backbone_hidden_size)`, *optional*):
            The last hidden state of the backbone model. Such input is required when the first codebook token (the one generated by the backbone model)
            is provided in the `input_ids` argument.
        """
        # Positions are always rebuilt from `cache_position` before the decoder layers, so custom ids are dropped.
        # NOTE(review): the reset only happens outside torch.compile; under compilation user ids still reach
        # `create_causal_mask` (they are overwritten before the layers) — confirm this is intended.
        if position_ids is not None and not torch.compiler.is_compiling():
            logger.warning_once(
                "Custom `position_ids` were provided but will be ignored. CSM depth decoder automatically determines position_ids "
                "from `cache_position` and as it requires them to be identical across the batch, the provided position_ids will be ignored."
            )
            position_ids = None

        # Exactly one of `input_ids` / `inputs_embeds` must be given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds.")

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            inputs_seq_length = inputs_embeds.shape[1] if inputs_embeds is not None else input_ids.shape[1]
            device = inputs_embeds.device if inputs_embeds is not None else input_ids.device
            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_seq_length, device=device)

        if inputs_embeds is None:
            # Position p is embedded with codebook max(p - 1, 0)'s slice of the shared table.
            codebook_idxs = torch.clamp(cache_position - 1, min=0)
            offset = codebook_idxs * self.vocab_size
            inputs_embeds = self.embed_tokens(input_ids + offset)

            input_ids_are_first_codebook = cache_position[0] == 0
            if backbone_last_hidden_state is not None:
                # Position 0 carries the backbone's last hidden state instead of a token embedding.
                inputs_embeds[:, 0] = backbone_last_hidden_state
            else:
                if not torch.compiler.is_compiling() and input_ids_are_first_codebook:
                    logger.warning(
                        "When the first codebook token is provided, `backbone_last_hidden_state` should also be provided for correct inference."
                    )

        # Applied to both computed and user-provided embeddings (assumed backbone-sized — TODO confirm for the latter).
        inputs_embeds = self.inputs_embeds_projector(inputs_embeds)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        # (the same positions, taken from cache_position, are used for every batch element)
        position_ids = cache_position.unsqueeze(0)
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )
  205. class CsmCodebooksHead(nn.Module):
  206. def __init__(self, hidden_size, num_codebooks, vocab_size):
  207. super().__init__()
  208. self.num_codebooks = num_codebooks
  209. self.weight = nn.Parameter(torch.empty(self.num_codebooks - 1, hidden_size, vocab_size))
  210. def forward(self, hidden_states, cache_position=None):
  211. if cache_position is None:
  212. seq_length = hidden_states.shape[1]
  213. codebook_weight = self.weight[torch.arange(seq_length)]
  214. else:
  215. codebook_idxs = cache_position - 1
  216. codebook_weight = self.weight[codebook_idxs]
  217. hidden_states = [
  218. nn.functional.linear(hidden_states[:, codebook_idx, :], codebook_weight[codebook_idx].T)
  219. for codebook_idx in range(codebook_weight.shape[0])
  220. ]
  221. hidden_states = torch.stack(hidden_states, dim=1)
  222. return hidden_states
  223. @auto_docstring(
  224. custom_intro="""
  225. The CsmDepthDecoder Model transformer, with a [`CsmCodebooksHead`] on top,
  226. which can be seen a position-specific language modeling head, allowing to use a different linear layer for each codebook
  227. (e.g. position 0 is the first codebook and uses the first codebook head, etc.)
  228. """
  229. )
  230. class CsmDepthDecoderForCausalLM(LlamaForCausalLM, GenerationMixin):
  231. _tied_weights_keys = None
  232. _tp_plan = None
  233. _pp_plan = None
  234. def __init__(self, config):
  235. super().__init__(config)
  236. del self.lm_head
  237. self.codebooks_head = CsmCodebooksHead(config.hidden_size, config.num_codebooks, config.vocab_size)
  238. self.model = CsmDepthDecoderModel(config)
  239. def prepare_inputs_for_generation(
  240. self,
  241. input_ids: torch.LongTensor,
  242. past_key_values: Optional[Cache] = None,
  243. attention_mask: Optional[torch.LongTensor] = None,
  244. inputs_embeds: Optional[torch.FloatTensor] = None,
  245. cache_position: Optional[torch.LongTensor] = None,
  246. **kwargs,
  247. ):
  248. model_inputs = super().prepare_inputs_for_generation(
  249. input_ids, past_key_values, attention_mask, inputs_embeds, cache_position, **kwargs
  250. )
  251. is_first_generation_step = model_inputs["cache_position"][0] == 0
  252. if not is_first_generation_step:
  253. model_inputs.pop("backbone_last_hidden_state")
  254. # csm depth decoder does not use position_ids
  255. model_inputs.pop("position_ids")
  256. return model_inputs
  257. @can_return_tuple
  258. @auto_docstring
  259. def forward(
  260. self,
  261. input_ids: Optional[torch.LongTensor] = None,
  262. backbone_last_hidden_state: Optional[torch.FloatTensor] = None,
  263. attention_mask: Optional[torch.Tensor] = None,
  264. position_ids: Optional[torch.LongTensor] = None,
  265. past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
  266. inputs_embeds: Optional[torch.FloatTensor] = None,
  267. labels: Optional[torch.LongTensor] = None,
  268. use_cache: Optional[bool] = None,
  269. cache_position: Optional[torch.LongTensor] = None,
  270. logits_to_keep: Union[int, torch.Tensor] = 0,
  271. **kwargs: Unpack[TransformersKwargs],
  272. ) -> Union[tuple, CausalLMOutputWithPast]:
  273. r"""
  274. backbone_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, backbone_hidden_size)`, *optional*):
  275. The last hidden state of the backbone model. Such input is required when the first codebook token (the one generated by the backbone model)
  276. is provided in the `input_ids` argument.
  277. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  278. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  279. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  280. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  281. """
  282. outputs = self.model(
  283. input_ids=input_ids,
  284. backbone_last_hidden_state=backbone_last_hidden_state,
  285. attention_mask=attention_mask,
  286. position_ids=position_ids,
  287. past_key_values=past_key_values,
  288. inputs_embeds=inputs_embeds,
  289. use_cache=use_cache,
  290. cache_position=cache_position,
  291. **kwargs,
  292. )
  293. hidden_states = outputs[0]
  294. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  295. if isinstance(logits_to_keep, int):
  296. if logits_to_keep == 0:
  297. # skip idx 0 logits since it's for the concatenated backbone last hidden state
  298. slice_indices = slice(1, None)
  299. else:
  300. slice_indices = slice(-logits_to_keep, None)
  301. else:
  302. slice_indices = logits_to_keep
  303. logits = self.codebooks_head(
  304. hidden_states[:, slice_indices, :], cache_position[slice_indices] if cache_position is not None else None
  305. )
  306. logits = logits.contiguous()
  307. loss = None
  308. if labels is not None:
  309. shift_labels = labels[..., 1:].contiguous()
  310. loss = self.loss_function(
  311. logits=logits, labels=None, vocab_size=self.config.vocab_size, shift_labels=shift_labels, **kwargs
  312. )
  313. return CausalLMOutputWithPast(
  314. loss=loss,
  315. logits=logits,
  316. past_key_values=outputs.past_key_values,
  317. hidden_states=outputs.hidden_states,
  318. attentions=outputs.attentions,
  319. )
  320. class CsmBackboneModelEmbeddings(nn.Module):
  321. def __init__(self, config):
  322. super().__init__()
  323. self.embed_audio_tokens = nn.Embedding((config.num_codebooks * config.vocab_size), config.hidden_size)
  324. self.register_buffer(
  325. "audio_tokens_offsets", torch.arange(config.num_codebooks) * config.vocab_size, persistent=False
  326. )
  327. def forward(self, input_ids):
  328. input_embeds = self.embed_audio_tokens(input_ids + self.audio_tokens_offsets)
  329. input_embeds = input_embeds.sum(dim=2)
  330. return input_embeds
@auto_docstring
class CsmBackboneModel(LlamaModel):
    # Llama decoder whose token embedding is replaced by `CsmBackboneModelEmbeddings`, which embeds each frame as
    # the sum of its per-codebook embeddings.
    # NOTE(review): that embedding only handles the 3-D `(batch, seq, num_codebooks)` form of `input_ids`; the 2-D
    # text form is expected to be turned into `inputs_embeds` by the caller (as `CsmForConditionalGeneration.forward`
    # does) — confirm no other caller passes 2-D ids here.
    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = CsmBackboneModelEmbeddings(config)

    @check_model_inputs()
    @auto_docstring
    def forward(self, **super_kwargs):
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks) or (batch_size, sequence_length)`):
            1. (batch_size, sequence_length): corresponds to the input sequence prepared with the processor from the text prompt. Such input
            requires `input_values` to be provided so that audio can be encoded in codebook tokens and then merged with the text tokens.

            2. (batch_size, sequence_length, num_codebooks): codebook tokens generated during the autoregressive decoding. Such input is not meant to be used by end users.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        """
        # Computation is delegated unchanged to `LlamaModel.forward`; this override only adds the CSM-specific
        # `input_ids` documentation and the decorators above.
        return super().forward(**super_kwargs)
  349. @auto_docstring(
  350. custom_intro="""
  351. The Csm model consists of two llama-like auto-regressive transformer models: a backbone model that predicts the first codebook token and a depth decoder that predicts the other codebook tokens.
  352. """
  353. )
  354. class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
  355. _tied_weights_keys = [
  356. "backbone_model.embed_tokens.embed_audio_tokens.weight",
  357. "depth_decoder.model.embed_tokens.weight",
  358. ]
    def __init__(self, config):
        """Build the backbone, depth decoder, codec model and the text/backbone heads from a `CsmConfig`."""
        super().__init__(config)
        self.vocab_size = config.vocab_size
        # Backbone head: applied to the backbone hidden states to score the first codebook token.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Embeds the 2-D text `input_ids` (see `_merge_input_ids_with_input_values`).
        self.embed_text_tokens = nn.Embedding(config.text_vocab_size, config.hidden_size)
        self.backbone_model = CsmBackboneModel._from_config(config)
        self.depth_decoder = CsmDepthDecoderForCausalLM._from_config(config.depth_decoder_config)
        # Audio codec used to turn raw `input_values` into codebook tokens.
        self.codec_model = AutoModel.from_config(config.codec_config)
        self.post_init()
    def get_input_embeddings(self):
        """Return the backbone's audio-codebook embedding module (`CsmBackboneModelEmbeddings`).

        Note that text tokens are embedded separately by `self.embed_text_tokens`.
        """
        return self.backbone_model.embed_tokens
    def set_input_embeddings(self, value):
        """Replace the backbone's audio-codebook embedding module with `value`."""
        self.backbone_model.embed_tokens = value
    def _tie_weights(self):
        """Tie (or clone) the backbone audio embedding and the depth decoder token embedding.

        Only done when `config.tie_codebooks_embeddings` is set; these are the two weights listed in
        `_tied_weights_keys`.
        """
        if self.config.tie_codebooks_embeddings:
            self._tie_or_clone_weights(
                self.backbone_model.embed_tokens.embed_audio_tokens,
                self.depth_decoder.model.embed_tokens,
            )
  378. @classmethod
  379. def from_pretrained(cls, *args, **kwargs):
  380. if kwargs.get("output_loading_info", False):
  381. model, loading_info = super().from_pretrained(*args, **kwargs)
  382. else:
  383. model = super().from_pretrained(*args, **kwargs)
  384. # copy depth decoder generation conf attr to the depth decoder generation config
  385. prefix = "depth_decoder_"
  386. prefix_len = len(prefix)
  387. depth_decoder_attrs = {
  388. attr[prefix_len:]: value
  389. for attr, value in vars(model.generation_config).items()
  390. if attr.startswith(prefix)
  391. }
  392. vars(model.depth_decoder.generation_config).update({"_from_model_config": False, **depth_decoder_attrs})
  393. # remove the depth decoder generation conf attr from the model generation config
  394. for attr in depth_decoder_attrs:
  395. delattr(model.generation_config, prefix + attr)
  396. if "output_loading_info" in kwargs:
  397. return model, loading_info
  398. else:
  399. return model
  400. def save_pretrained(self, *args, **kwargs):
  401. # copy the depth decoder generation config attributes to the model generation config
  402. prefix = "depth_decoder_"
  403. depth_decoder_attrs = self.depth_decoder.generation_config.to_diff_dict()
  404. depth_decoder_attrs.pop("transformers_version", None)
  405. for attr, value in depth_decoder_attrs.items():
  406. setattr(self.generation_config, prefix + attr, value)
  407. super().save_pretrained(*args, **kwargs)
    def _merge_input_ids_with_input_values(
        self,
        input_ids: Optional[torch.Tensor] = None,
        input_values: Optional[torch.Tensor] = None,
        input_values_cutoffs: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> dict[str, Optional[torch.Tensor]]:
        """
        Merges the input_ids and input_values to produce a single inputs_embeds tensor:
        1 - Infers the codec model on the input_values to retrieve codebook token.
        2 - Embeds codebook tokens and places them at the correct positions in the inputs_embeds tensor.
        3 - If labels are provided, expands them to match codebook dimensions and places the target codebook tokens in the labels tensor.

        Args:
            input_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`):
                The input ids to embed.
            input_values (`torch.Tensor` of shape `(batch_size, channels, audio_sequence_length)`):
                The audio input values to embed.
            input_values_cutoffs (`torch.Tensor` of shape `(batch_size, max_num_audio)`):
                The cutoffs of the audio input values relative to its batch index, padded with -1 when no audio.
            labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels aligned with `input_ids`. Only expanded when `input_values` is also given: positions holding the
                audio token receive the codec's codebook ids, audio-EOS positions receive `codebook_eos_token_id` for
                every codebook, and frames labelled `-101` get `-100` for every codebook except the first.

        Returns:
            `dict` with `"inputs_embeds"` (text embeddings with audio / audio-EOS positions replaced) and `"labels"`
            (expanded to `(batch_size, sequence_length, num_codebooks)` when `labels` and `input_values` are given,
            otherwise the input `labels` unchanged).
        """
        inputs_embeds = self.embed_text_tokens(input_ids)

        if input_values is not None:
            # infer input_values_mask
            # A leading 0 is prepended so every row starts at offset 0. Flattening the non-negative cutoffs across the
            # batch and diffing yields per-segment lengths; diffs across row boundaries are <= 0 and are dropped.
            input_values_cutoffs = nn.functional.pad(input_values_cutoffs, (1, 0))
            audio_lengths = input_values_cutoffs[input_values_cutoffs >= 0].diff()
            audio_lengths = audio_lengths[audio_lengths > 0]
            # Boolean mask, one row per audio segment, True for the samples belonging to that segment.
            input_values_mask = torch.arange(input_values_cutoffs.max(), device=input_values.device).expand(
                len(audio_lengths), -1
            )
            input_values_mask = input_values_mask < audio_lengths.unsqueeze(1)

            # =======================================
            # TODO: @eustlb, this should be batched !!!
            # but requires making sure batched inference of the codec model works as intended
            with torch.no_grad():
                # Encode each audio segment separately; one entry per segment, frames first.
                audio_tokens_list = []
                for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
                    batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
                    for i in range(batch_input_values_cutoffs.shape[0] - 1):
                        start_idx = batch_input_values_cutoffs[i]
                        end_idx = batch_input_values_cutoffs[i + 1]
                        audio_batch = batch_input_values[..., start_idx:end_idx]
                        codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0))
                        codebook_ids = codec_outputs.audio_codes.transpose(1, -1)
                        audio_tokens_list.append(codebook_ids[0])

                # Right-pad every segment along the frame dimension to the longest one, then stack.
                # NOTE(review): `max` raises if no segment was found (e.g. all cutoffs are -1) — confirm callers
                # never pass `input_values` without audio.
                max_audio_frames = max(el.shape[0] for el in audio_tokens_list)
                batched_audio_token_ids = torch.stack(
                    [nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list]
                )
                audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask)
            # =======================================
            # Write the valid (unpadded) audio frame embeddings into the audio-token positions, in order.
            audio_token_id = self.config.audio_token_id
            audio_token_mask = input_ids == audio_token_id

            audio_embeds = self.backbone_model.embed_tokens(batched_audio_token_ids)
            inputs_embeds[audio_token_mask] = audio_embeds[audio_codes_mask]

            # same for the audio eos token
            audio_eos_frame_ids = (
                torch.ones((1, 1, self.config.num_codebooks), device=input_ids.device, dtype=torch.long)
                * self.config.codebook_eos_token_id
            )
            audio_eos_embeds = self.backbone_model.embed_tokens(audio_eos_frame_ids).squeeze(1)

            audio_eos_token_mask = input_ids == self.config.audio_eos_token_id
            inputs_embeds[audio_eos_token_mask] = audio_eos_embeds.repeat(audio_eos_token_mask.sum(), 1)

            # if the labels are provided, we need to expand the labels to (batch_size, seq_length, num_codebooks)
            if labels is not None:
                labels_expanded = labels.unsqueeze(-1).repeat(1, 1, self.config.num_codebooks)
                labels_expanded[audio_token_mask] = batched_audio_token_ids[audio_codes_mask]
                labels_expanded[audio_eos_token_mask] = audio_eos_frame_ids
                # mask depth decoder
                # frames labelled -101 keep only their first-codebook target; the other codebooks are set to -100
                depth_decoder_ignore_frames_idxs = (labels == -101).nonzero(as_tuple=True)
                labels_expanded[depth_decoder_ignore_frames_idxs[0], depth_decoder_ignore_frames_idxs[1], 1:] = -100
                labels = labels_expanded

        return {"inputs_embeds": inputs_embeds, "labels": labels}
  480. def prepare_inputs_for_generation(
  481. self,
  482. input_ids: torch.LongTensor,
  483. past_key_values: Optional[Cache] = None,
  484. attention_mask: Optional[torch.LongTensor] = None,
  485. inputs_embeds: Optional[torch.FloatTensor] = None,
  486. cache_position: Optional[torch.LongTensor] = None,
  487. **kwargs,
  488. ):
  489. model_inputs = super().prepare_inputs_for_generation(
  490. input_ids=input_ids,
  491. past_key_values=past_key_values,
  492. attention_mask=attention_mask,
  493. inputs_embeds=inputs_embeds,
  494. cache_position=cache_position,
  495. **kwargs,
  496. )
  497. if input_ids is not None and input_ids.ndim == 2 and model_inputs.get("inputs_embeds") is None:
  498. merged_inputs = self._merge_input_ids_with_input_values(
  499. input_ids=input_ids,
  500. input_values=kwargs.get("input_values"),
  501. input_values_cutoffs=kwargs.get("input_values_cutoffs"),
  502. labels=kwargs.get("labels"),
  503. )
  504. model_inputs.update(
  505. {"inputs_embeds": merged_inputs["inputs_embeds"], "labels": merged_inputs["labels"], "input_ids": None}
  506. )
  507. return model_inputs
  508. @can_return_tuple
  509. @auto_docstring
  510. def forward(
  511. self,
  512. input_ids: Optional[torch.LongTensor] = None,
  513. input_values: Optional[torch.Tensor] = None,
  514. attention_mask: Optional[torch.Tensor] = None,
  515. input_values_cutoffs: Optional[torch.Tensor] = None,
  516. position_ids: Optional[torch.LongTensor] = None,
  517. past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
  518. inputs_embeds: Optional[torch.FloatTensor] = None,
  519. labels: Optional[torch.LongTensor] = None,
  520. use_cache: Optional[bool] = None,
  521. cache_position: Optional[torch.LongTensor] = None,
  522. logits_to_keep: Union[int, torch.Tensor] = 0,
  523. **kwargs: Unpack[TransformersKwargs],
  524. ) -> Union[tuple, CsmOutputWithPast]:
  525. r"""
  526. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks) or (batch_size, sequence_length)`):
  527. 1. (batch_size, sequence_length): corresponds to the input sequence prepared with the processor from the text prompt. Such input
  528. requires `input_values` to be provided so that audio can be encoded in codebook tokens and then merged with the text tokens.
  529. 2. (batch_size, sequence_length, num_codebooks): codebook tokens generated during the autoregressive decoding. Such input is not meant to be used by end users.
  530. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  531. [`PreTrainedTokenizer.__call__`] for details.
  532. [What are input IDs?](../glossary#input-ids)
  533. input_values_cutoffs (`torch.Tensor` of shape `(batch_size, max_num_audio)`, *optional*):
  534. Specify the end positions of audio segments within each batch entry, relative to the concatenated audio input.
  535. If a batch entry has fewer segments than the maximum, it is padded with -1. For example, in a batch of 2 sequences
  536. where the first contains 2 audio segments of length l1, and the second contains 1 audio segment of length l2,
  537. the input_values_cutoffs would be: [[l1, 2 * l1], [l2, -1]].
  538. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  539. Labels for computing the masked language modeling loss. Indices should be in `[config.audio_token_id, -100, -101]`.
  540. Requires targeted `input_values` to be provided as audio tokens will be inferred from it using the `codec_model`.
  541. - `config.audio_token_id` indicates an audio frames (considering sequence length elements as frames)
  542. - `-100` will be ignored in the loss computation
  543. - `-101` indicates the audio frame will be used only for the backbone model (using the first codebook token as labels)
  544. Such labels can be prepared using `output_labels=True` when calling [`CsmProcessor`].
  545. logits_to_keep (`int` or `torch.Tensor`, *optional*):
  546. Kept for compatibility. Does not support another value than:
  547. 1. `0`, which is equivalent to keeping all logits, used in the training regime
  548. 2. `1`, which is equivalent to keeping only the last logit, used in the generation regime
  549. Example:
  550. ```python
  551. >>> import torch
  552. >>> from transformers import CsmForConditionalGeneration, AutoProcessor
  553. >>> from datasets import load_dataset, Audio
  554. >>> model_id = "sesame/csm-1b"
  555. >>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"
  556. >>> processor = AutoProcessor.from_pretrained(model_id)
  557. >>> ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
  558. >>> # ensure the audio is 24kHz
  559. >>> ds = ds.cast_column("audio", Audio(sampling_rate=24000))
  560. >>> conversation = []
  561. >>> # prepare a conversation with text and corresponding audio
  562. >>> for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
  563. ... conversation.append(
  564. ... {
  565. ... "role": f"{speaker_id}",
  566. ... "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
  567. ... }
  568. ... )
  569. >>> inputs = processor.apply_chat_template(
  570. ... conversation,
  571. ... tokenize=True,
  572. ... return_dict=True,
  573. ... output_labels=True,
  574. ... ).to(torch_device)
  575. >>> model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
  576. >>> output = model(**inputs)
  577. >>> output.loss.backward()
  578. ```"""
  579. if input_ids is not None and input_ids.ndim == 2:
  580. merged_inputs = self._merge_input_ids_with_input_values(
  581. input_ids, input_values, input_values_cutoffs, labels
  582. )
  583. inputs_embeds = merged_inputs["inputs_embeds"]
  584. labels = merged_inputs["labels"]
  585. input_ids = None
  586. backbone_outputs = self.backbone_model(
  587. input_ids=input_ids,
  588. attention_mask=attention_mask,
  589. position_ids=position_ids,
  590. past_key_values=past_key_values,
  591. inputs_embeds=inputs_embeds,
  592. use_cache=use_cache,
  593. cache_position=cache_position,
  594. **kwargs,
  595. )
  596. backbone_hidden_states = backbone_outputs[0]
  597. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  598. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  599. backbone_logits = self.lm_head(backbone_hidden_states[:, slice_indices, :])
  600. loss = None
  601. backbone_loss = None
  602. depth_decoder_loss = None
  603. depth_decoder_outputs = None
  604. if labels is not None:
  605. # select first codebook as labels for the backbone model
  606. backbone_labels = labels[:, :, 0]
  607. backbone_loss = self.loss_function(
  608. logits=backbone_logits, labels=backbone_labels, vocab_size=self.config.vocab_size, **kwargs
  609. )
  610. # for the depth decoder, we need to select the frames to train on
  611. # those are frames where the label is not uniformly `ignore_index` along the codebook dimension
  612. train_mask = ~(labels[:, :, 1:] == -100).all(dim=-1)
  613. depth_decoder_input_ids = labels[train_mask][..., : self.config.num_codebooks - 1]
  614. # add place holder in position 0 that will be replaced by the backbone_last_hidden_state
  615. depth_decoder_input_ids = nn.functional.pad(depth_decoder_input_ids, (1, 0), value=0)
  616. train_idxs = train_mask.nonzero(as_tuple=True)
  617. backbone_last_hidden_states = backbone_hidden_states[train_idxs[0], train_idxs[1] - 1, :]
  618. depth_decoder_labels = labels[train_mask]
  619. depth_decoder_outputs = self.depth_decoder(
  620. input_ids=depth_decoder_input_ids,
  621. backbone_last_hidden_state=backbone_last_hidden_states,
  622. use_cache=use_cache,
  623. return_dict=True,
  624. labels=depth_decoder_labels,
  625. **kwargs,
  626. )
  627. depth_decoder_loss = depth_decoder_outputs.loss
  628. loss = backbone_loss + depth_decoder_loss
  629. return CsmOutputWithPast(
  630. loss=loss,
  631. backbone_loss=backbone_loss,
  632. depth_decoder_loss=depth_decoder_loss,
  633. logits=backbone_logits,
  634. past_key_values=backbone_outputs.past_key_values,
  635. hidden_states=backbone_outputs.hidden_states,
  636. attentions=backbone_outputs.attentions,
  637. depth_decoder_logits=depth_decoder_outputs.logits if depth_decoder_outputs is not None else None,
  638. depth_decoder_past_key_values=depth_decoder_outputs.past_key_values
  639. if depth_decoder_outputs is not None
  640. else None,
  641. depth_decoder_hidden_states=depth_decoder_outputs.hidden_states
  642. if depth_decoder_outputs is not None
  643. else None,
  644. depth_decoder_attentions=depth_decoder_outputs.attentions if depth_decoder_outputs is not None else None,
  645. )
  646. __all__ = [
  647. "CsmPreTrainedModel",
  648. "CsmBackboneModel",
  649. "CsmDepthDecoderModel",
  650. "CsmDepthDecoderForCausalLM",
  651. "CsmForConditionalGeneration",
  652. ]