modular_dia.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773
  1. # coding=utf-8
  2. # Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch Dia model."""
  16. from typing import Callable, Optional, Union
  17. import torch
  18. from torch import nn
  19. from ...cache_utils import DynamicCache, EncoderDecoderCache
  20. from ...masking_utils import create_causal_mask
  21. from ...modeling_attn_mask_utils import (
  22. _prepare_4d_attention_mask,
  23. _prepare_4d_attention_mask_for_sdpa,
  24. )
  25. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  26. from ...modeling_layers import GradientCheckpointingLayer
  27. from ...modeling_outputs import (
  28. BaseModelOutput,
  29. BaseModelOutputWithPastAndCrossAttentions,
  30. Seq2SeqLMOutput,
  31. Seq2SeqModelOutput,
  32. )
  33. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  34. from ...processing_utils import Unpack
  35. from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, is_torchdynamo_compiling, logging
  36. from ..llama.modeling_llama import (
  37. LlamaAttention,
  38. LlamaRMSNorm,
  39. LlamaRotaryEmbedding,
  40. eager_attention_forward,
  41. )
  42. from ..phi3.modeling_phi3 import Phi3MLP
  43. from .configuration_dia import DiaConfig, DiaDecoderConfig, DiaEncoderConfig
  44. from .generation_dia import DiaGenerationMixin
  45. if is_torch_flex_attn_available():
  46. from ...integrations.flex_attention import make_flex_block_causal_mask
# Module-level logger; used e.g. by `DiaModel.forward` for the one-time
# `use_cache` vs. gradient-checkpointing warning.
logger = logging.get_logger(__name__)
  48. @auto_docstring
  49. class DiaPreTrainedModel(PreTrainedModel):
  50. config: DiaConfig
  51. base_model_prefix = "model"
  52. supports_gradient_checkpointing = True
  53. _supports_flash_attn = True
  54. _supports_sdpa = True
  55. _supports_flex_attn = True
  56. _can_compile_fullgraph = True
  57. main_input_name = "input_ids"
  58. _no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"]
  59. class DiaMultiChannelEmbedding(nn.Module):
  60. """In order to efficiently compute the audio embedding from the 9 different channels,
  61. we vectorize the embedding process by using a single embedding layer and an offset.
  62. Example:
  63. - num_embeds = 4
  64. - vocab_size = 8
  65. - num_channels = 3
  66. We would have offsets = [0, 8, 16]
  67. If audio_codes = [0, 1, 2, 3], [1, 3, 4, 7], [5, 6, 7, 8],
  68. then tokens = audio_codes + offsets
  69. = [0, 1, 2, 3, 9, 11, 12, 15, 21, 22, 23, 24]
  70. This allows us to use a single embedding layer for all channels.
  71. """
  72. def __init__(self, config: DiaDecoderConfig):
  73. super().__init__()
  74. self.embed = nn.Embedding(config.vocab_size * config.num_channels, config.hidden_size)
  75. self.hidden_size = config.hidden_size
  76. self.num_channels = config.num_channels
  77. offsets = torch.arange(config.num_channels, dtype=torch.long) * config.vocab_size # (C,)
  78. self.register_buffer("offsets", offsets, persistent=False)
  79. def forward(self, audio_codes: torch.Tensor) -> torch.Tensor:
  80. tokens = (audio_codes + self.offsets.to(audio_codes.device)).squeeze(1)
  81. embeds = self.embed(tokens).view(tokens.shape[0], audio_codes.shape[1], -1, self.hidden_size)
  82. return embeds.sum(dim=2)
  83. class DiaMLP(Phi3MLP):
  84. pass
  85. class DiaRMSNorm(LlamaRMSNorm):
  86. pass
  87. class DiaRotaryEmbedding(LlamaRotaryEmbedding):
  88. pass
  89. class DiaSelfAttention(LlamaAttention):
  90. """Multi-headed attention from 'Attention Is All You Need' paper"""
  91. def __init__(self, config: Union[DiaEncoderConfig, DiaDecoderConfig], layer_idx: int, is_causal: bool = False):
  92. nn.Module.__init__(self)
  93. self.config = config
  94. self.layer_idx = layer_idx
  95. self.hidden_size = config.hidden_size
  96. self.num_heads = self.config.num_attention_heads
  97. self.num_key_value_heads = self.config.num_key_value_heads or self.num_heads
  98. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  99. self.head_dim = getattr(config, "head_dim", config.hidden_size // self.num_heads)
  100. self.scaling = 1
  101. self.attention_dropout = 0.0
  102. self.is_causal = is_causal
  103. self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
  104. self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  105. self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  106. self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
  107. class DiaCrossAttention(nn.Module):
  108. """Multi-headed attention from 'Attention Is All You Need' paper"""
  109. def __init__(self, config: DiaDecoderConfig, layer_idx: int):
  110. super().__init__()
  111. self.config = config
  112. self.layer_idx = layer_idx
  113. self.hidden_size = config.hidden_size
  114. self.cross_hidden_size = config.cross_hidden_size
  115. self.num_heads = self.config.cross_num_attention_heads
  116. self.num_key_value_heads = self.config.cross_num_key_value_heads
  117. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  118. self.head_dim = config.cross_head_dim
  119. self.scaling = 1
  120. self.attention_dropout = 0.0
  121. self.is_causal = False
  122. self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
  123. self.k_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  124. self.v_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  125. self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
  126. def forward(
  127. self,
  128. hidden_states: torch.Tensor,
  129. cross_attention_states: torch.Tensor,
  130. attention_mask: Optional[torch.Tensor] = None,
  131. past_key_values: Optional[EncoderDecoderCache] = None,
  132. **kwargs: Unpack[FlashAttentionKwargs],
  133. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  134. input_shape = hidden_states.shape[:-1]
  135. hidden_shape = (*input_shape, -1, self.head_dim)
  136. cross_shape = (*cross_attention_states.shape[:-1], -1, self.head_dim)
  137. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  138. is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
  139. if past_key_values is not None and is_updated:
  140. # reuse k,v, cross_attentions
  141. key_states = past_key_values.cross_attention_cache.layers[self.layer_idx].keys
  142. value_states = past_key_values.cross_attention_cache.layers[self.layer_idx].values
  143. else:
  144. key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
  145. value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
  146. if past_key_values is not None:
  147. # save all states to the cache
  148. key_states, value_states = past_key_values.cross_attention_cache.update(
  149. key_states,
  150. value_states,
  151. self.layer_idx,
  152. )
  153. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  154. past_key_values.is_updated[self.layer_idx] = True
  155. attention_interface: Callable = eager_attention_forward
  156. if self.config._attn_implementation != "eager":
  157. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  158. attn_output, attn_weights = attention_interface(
  159. self,
  160. query_states,
  161. key_states,
  162. value_states,
  163. attention_mask,
  164. scaling=self.scaling,
  165. **kwargs,
  166. )
  167. attn_output = attn_output.reshape((*input_shape, -1)).contiguous()
  168. attn_output = self.o_proj(attn_output)
  169. return attn_output, attn_weights
  170. class DiaEncoderLayer(GradientCheckpointingLayer):
  171. def __init__(self, config: DiaEncoderConfig, layer_idx: int):
  172. super().__init__()
  173. self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
  174. self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=False)
  175. self.post_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
  176. self.mlp = DiaMLP(config)
  177. def forward(
  178. self,
  179. hidden_states: torch.Tensor,
  180. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
  181. attention_mask: Optional[torch.Tensor] = None,
  182. **kwargs: Unpack[FlashAttentionKwargs],
  183. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  184. residual = hidden_states
  185. normed_states = self.pre_sa_norm(hidden_states)
  186. self_attn_output, self_attn_weights = self.self_attention(
  187. normed_states,
  188. position_embeddings=position_embeddings,
  189. attention_mask=attention_mask,
  190. **kwargs,
  191. )
  192. hidden_states = residual + self_attn_output
  193. residual = hidden_states
  194. normed_states = self.post_sa_norm(hidden_states)
  195. mlp_out = self.mlp(normed_states)
  196. hidden_states = residual + mlp_out
  197. return hidden_states, self_attn_weights
  198. class DiaEncoder(DiaPreTrainedModel):
  199. def __init__(self, config: DiaEncoderConfig):
  200. super().__init__(config)
  201. self.config = config
  202. self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
  203. self.layers = nn.ModuleList(
  204. [DiaEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  205. )
  206. self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
  207. self.rotary_embeddings = DiaRotaryEmbedding(config)
  208. @auto_docstring
  209. @can_return_tuple
  210. def forward(
  211. self,
  212. input_ids: torch.Tensor,
  213. attention_mask: Optional[torch.Tensor] = None,
  214. output_attentions: Optional[bool] = False,
  215. output_hidden_states: Optional[bool] = False,
  216. **kwargs: Unpack[FlashAttentionKwargs],
  217. ) -> Union[BaseModelOutput, tuple]:
  218. hidden_states = self.embedding(input_ids)
  219. # RoPE
  220. # Note: We expect right padding and hence always generate
  221. # the position ids on the fly to reduce preparation overhead
  222. position_ids = torch.arange(input_ids.shape[-1], device=input_ids.device)[None, :]
  223. position_embeddings = self.rotary_embeddings(hidden_states, position_ids)
  224. attention_mask = self._update_full_mask(
  225. attention_mask,
  226. hidden_states,
  227. )
  228. encoder_states = () if output_hidden_states else None
  229. all_attentions = () if output_attentions else None
  230. for encoder_layer in self.layers:
  231. if output_hidden_states:
  232. encoder_states = encoder_states + (hidden_states,)
  233. layer_outputs = encoder_layer(
  234. hidden_states,
  235. position_embeddings=position_embeddings,
  236. attention_mask=attention_mask,
  237. **kwargs,
  238. )
  239. hidden_states = layer_outputs[0]
  240. if output_attentions:
  241. all_attentions = all_attentions + (layer_outputs[1],)
  242. hidden_states = self.norm(hidden_states)
  243. if output_hidden_states:
  244. encoder_states += (hidden_states,)
  245. return BaseModelOutput(
  246. last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
  247. )
  248. # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
  249. def _update_full_mask(
  250. self,
  251. attention_mask: Union[torch.Tensor, None],
  252. inputs_embeds: torch.Tensor,
  253. ):
  254. if attention_mask is not None:
  255. if self.config._attn_implementation == "flash_attention_2":
  256. attention_mask = attention_mask if 0 in attention_mask else None
  257. elif self.config._attn_implementation == "sdpa":
  258. # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
  259. # the manual implementation that requires a 4D causal mask in all cases.
  260. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  261. attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
  262. elif self.config._attn_implementation == "flex_attention":
  263. if isinstance(attention_mask, torch.Tensor):
  264. attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
  265. else:
  266. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  267. attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
  268. return attention_mask
  269. class DiaDecoderLayer(GradientCheckpointingLayer):
  270. def __init__(self, config: DiaDecoderConfig, layer_idx: int):
  271. super().__init__()
  272. self.embed_dim = config.hidden_size
  273. self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=True)
  274. self.cross_attention = DiaCrossAttention(config, layer_idx)
  275. self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
  276. self.pre_ca_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
  277. self.pre_mlp_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
  278. self.mlp = DiaMLP(config)
  279. def forward(
  280. self,
  281. hidden_states: torch.Tensor,
  282. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
  283. attention_mask: Optional[torch.Tensor] = None,
  284. encoder_hidden_states: Optional[torch.Tensor] = None,
  285. encoder_attention_mask: Optional[torch.Tensor] = None,
  286. past_key_values: Optional[EncoderDecoderCache] = None,
  287. cache_position: Optional[torch.LongTensor] = None,
  288. **kwargs,
  289. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
  290. self_attn_cache = past_key_values
  291. if isinstance(self_attn_cache, EncoderDecoderCache):
  292. self_attn_cache = self_attn_cache.self_attention_cache
  293. residual = hidden_states
  294. normed_states = self.pre_sa_norm(hidden_states)
  295. self_attn_output, self_attn_weights = self.self_attention(
  296. normed_states,
  297. position_embeddings,
  298. attention_mask,
  299. # Needs to be an arg in order to function properly
  300. # on inplace operations to be carried (e.g. compile)
  301. self_attn_cache,
  302. cache_position=cache_position,
  303. **kwargs,
  304. )
  305. hidden_states = residual + self_attn_output
  306. residual = hidden_states
  307. normed_states = self.pre_ca_norm(hidden_states)
  308. cross_states, cross_attn_weights = self.cross_attention(
  309. normed_states,
  310. encoder_hidden_states,
  311. attention_mask=encoder_attention_mask,
  312. past_key_values=past_key_values,
  313. **kwargs,
  314. )
  315. hidden_states = residual + cross_states
  316. residual = hidden_states
  317. normed_states = self.pre_mlp_norm(hidden_states)
  318. mlp_out = self.mlp(normed_states)
  319. hidden_states = residual + mlp_out
  320. return hidden_states, self_attn_weights, cross_attn_weights
  321. class DiaDecoder(DiaPreTrainedModel):
  322. """Transformer Decoder Stack using DenseGeneral."""
  323. def __init__(self, config: DiaDecoderConfig):
  324. super().__init__(config)
  325. self.num_channels = config.num_channels
  326. self.vocab_size = config.vocab_size
  327. self.embeddings = DiaMultiChannelEmbedding(config)
  328. self.rotary_embeddings = DiaRotaryEmbedding(config)
  329. self.layers = nn.ModuleList(
  330. [DiaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  331. )
  332. self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
  333. @auto_docstring
  334. @can_return_tuple
  335. def forward(
  336. self,
  337. input_ids: torch.Tensor,
  338. position_ids: Optional[torch.LongTensor] = None,
  339. attention_mask: Optional[torch.Tensor] = None,
  340. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  341. encoder_attention_mask: Optional[torch.LongTensor] = None,
  342. past_key_values: Optional[EncoderDecoderCache] = None,
  343. output_attentions: Optional[bool] = False,
  344. output_hidden_states: Optional[bool] = False,
  345. cache_position: Optional[torch.LongTensor] = None,
  346. **kwargs,
  347. ) -> Union[BaseModelOutputWithPastAndCrossAttentions, tuple]:
  348. r"""
  349. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`):
  350. The original `decoder_input_ids` in 3D shape to facilitate more efficient computations.
  351. [What are input IDs?](../glossary#input-ids)
  352. """
  353. batch_size, seq_length = input_ids.size()[:-1]
  354. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  355. if cache_position is None:
  356. cache_position = torch.arange(
  357. past_key_values_length, past_key_values_length + seq_length, device=input_ids.device
  358. )
  359. if position_ids is None:
  360. position_ids = cache_position[None, :]
  361. # RoPE
  362. hidden_states = self.embeddings(input_ids)
  363. position_embeddings = self.rotary_embeddings(hidden_states, position_ids)
  364. if attention_mask is None and not is_torchdynamo_compiling():
  365. # required mask seq length can be calculated via length of past cache
  366. mask_seq_length = past_key_values_length + seq_length
  367. attention_mask = torch.ones(batch_size, mask_seq_length, device=input_ids.device)
  368. attention_mask = create_causal_mask(
  369. config=self.config,
  370. input_embeds=hidden_states,
  371. attention_mask=attention_mask,
  372. cache_position=cache_position,
  373. past_key_values=past_key_values,
  374. position_ids=position_ids,
  375. )
  376. encoder_attention_mask = self._update_cross_attn_mask(
  377. encoder_hidden_states,
  378. encoder_attention_mask,
  379. hidden_states.shape[:2],
  380. hidden_states,
  381. )
  382. all_hidden_states = () if output_hidden_states else None
  383. all_self_attns = () if output_attentions else None
  384. all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
  385. for layer in self.layers:
  386. if output_hidden_states:
  387. all_hidden_states += (hidden_states,)
  388. layer_outputs = layer(
  389. hidden_states,
  390. position_embeddings,
  391. attention_mask,
  392. encoder_hidden_states,
  393. encoder_attention_mask=encoder_attention_mask,
  394. past_key_values=past_key_values,
  395. cache_position=cache_position,
  396. **kwargs,
  397. )
  398. hidden_states = layer_outputs[0]
  399. if output_attentions:
  400. all_self_attns = all_self_attns + (layer_outputs[1],)
  401. if encoder_hidden_states is not None:
  402. all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
  403. hidden_states = self.norm(hidden_states)
  404. if output_hidden_states:
  405. all_hidden_states += (hidden_states,)
  406. return BaseModelOutputWithPastAndCrossAttentions(
  407. last_hidden_state=hidden_states,
  408. past_key_values=past_key_values,
  409. hidden_states=all_hidden_states,
  410. attentions=all_self_attns,
  411. cross_attentions=all_cross_attentions,
  412. )
  413. # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
  414. def _update_cross_attn_mask(
  415. self,
  416. encoder_hidden_states: Union[torch.Tensor, None],
  417. encoder_attention_mask: Union[torch.Tensor, None],
  418. input_shape: torch.Size,
  419. inputs_embeds: torch.Tensor,
  420. ):
  421. # expand encoder attention mask
  422. if encoder_hidden_states is not None and encoder_attention_mask is not None:
  423. if self.config._attn_implementation == "flash_attention_2":
  424. encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
  425. elif self.config._attn_implementation == "sdpa":
  426. # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
  427. # the manual implementation that requires a 4D causal mask in all cases.
  428. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  429. encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
  430. encoder_attention_mask,
  431. inputs_embeds.dtype,
  432. tgt_len=input_shape[-1],
  433. )
  434. elif self.config._attn_implementation == "flex_attention":
  435. if isinstance(encoder_attention_mask, torch.Tensor):
  436. encoder_attention_mask = make_flex_block_causal_mask(
  437. encoder_attention_mask,
  438. query_length=input_shape[-1],
  439. is_causal=False,
  440. )
  441. else:
  442. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  443. encoder_attention_mask = _prepare_4d_attention_mask(
  444. encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
  445. )
  446. return encoder_attention_mask
  447. @auto_docstring(
  448. custom_intro="""
  449. The bare Dia model outputting raw hidden-states without any specific head on top.
  450. """
  451. )
  452. class DiaModel(DiaPreTrainedModel):
  453. def __init__(self, config: DiaConfig):
  454. super().__init__(config)
  455. self.config = config
  456. self.encoder = DiaEncoder(config.encoder_config)
  457. self.decoder = DiaDecoder(config.decoder_config)
  458. self.post_init()
  459. def get_encoder(self):
  460. return self.encoder
  461. @auto_docstring
  462. @can_return_tuple
  463. def forward(
  464. self,
  465. input_ids: Optional[torch.LongTensor] = None,
  466. attention_mask: Optional[torch.LongTensor] = None,
  467. decoder_input_ids: Optional[torch.LongTensor] = None,
  468. decoder_position_ids: Optional[torch.LongTensor] = None,
  469. decoder_attention_mask: Optional[torch.LongTensor] = None,
  470. encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None,
  471. past_key_values: Optional[EncoderDecoderCache] = None,
  472. use_cache: Optional[bool] = None,
  473. output_attentions: Optional[bool] = None,
  474. output_hidden_states: Optional[bool] = None,
  475. cache_position: Optional[torch.LongTensor] = None,
  476. **kwargs,
  477. ) -> Union[tuple, Seq2SeqModelOutput]:
  478. r"""
  479. decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)
  480. or (batch_size, target_sequence_length, num_codebooks)`, *optional*):
  481. 1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where
  482. the audio input codebooks are flattened into the batch dimension. This also aligns with the flat-
  483. tened audio logits which are used to calculate the loss.
  484. 2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of
  485. Dia to calculate embeddings and subsequent steps more efficiently.
  486. If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape
  487. `(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See
  488. [`DiaProcessor.__call__`] for more details.
  489. [What are decoder input IDs?](../glossary#decoder-input-ids)
  490. decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
  491. Indices of positions of each input sequence tokens in the position embeddings.
  492. Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
  493. [What are position IDs?](../glossary#position-ids)
  494. """
  495. if input_ids is None and encoder_outputs is None:
  496. raise ValueError(
  497. "You should either provide text ids or the cached text encodings. Neither has been found."
  498. )
  499. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  500. output_hidden_states = (
  501. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  502. )
  503. use_cache = use_cache if use_cache is not None else self.config.use_cache
  504. if self.is_gradient_checkpointing and self.training:
  505. if use_cache:
  506. logger.warning_once(
  507. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  508. )
  509. use_cache = False
  510. if use_cache and past_key_values is None:
  511. past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  512. if encoder_outputs is None:
  513. encoder_outputs = self.encoder(
  514. input_ids=input_ids,
  515. attention_mask=attention_mask,
  516. output_attentions=output_attentions,
  517. output_hidden_states=output_hidden_states,
  518. **kwargs,
  519. )
  520. # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
  521. elif not isinstance(encoder_outputs, BaseModelOutput):
  522. encoder_outputs = BaseModelOutput(
  523. last_hidden_state=encoder_outputs[0],
  524. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  525. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  526. )
  527. # On default we initialize the decoder with bos tokens if nothing has been provided
  528. bsz, seq_len, channels = (encoder_outputs[0].shape[0], -1, self.config.decoder_config.num_channels)
  529. if decoder_input_ids is None:
  530. decoder_input_ids = torch.full(
  531. size=(bsz, 1, channels), fill_value=self.config.bos_token_id, device=self.device
  532. )
  533. # Ensure 3D
  534. if decoder_input_ids.ndim == 2:
  535. decoder_input_ids = decoder_input_ids.reshape(bsz, channels, seq_len).transpose(1, 2)
  536. decoder_outputs = self.decoder(
  537. input_ids=decoder_input_ids,
  538. position_ids=decoder_position_ids,
  539. attention_mask=decoder_attention_mask,
  540. encoder_hidden_states=encoder_outputs[0],
  541. encoder_attention_mask=attention_mask,
  542. past_key_values=past_key_values,
  543. output_attentions=output_attentions,
  544. output_hidden_states=output_hidden_states,
  545. use_cache=use_cache,
  546. cache_position=cache_position,
  547. **kwargs,
  548. )
  549. return Seq2SeqModelOutput(
  550. last_hidden_state=decoder_outputs.last_hidden_state,
  551. past_key_values=decoder_outputs.past_key_values,
  552. decoder_hidden_states=decoder_outputs.hidden_states,
  553. decoder_attentions=decoder_outputs.attentions,
  554. cross_attentions=decoder_outputs.cross_attentions,
  555. encoder_last_hidden_state=encoder_outputs[0],
  556. encoder_hidden_states=encoder_outputs.hidden_states,
  557. encoder_attentions=encoder_outputs.attentions,
  558. )
  559. @auto_docstring(
  560. custom_intro="""
  561. The Dia model consisting of a (byte) text encoder and audio decoder with a prediction head on top.
  562. """
  563. )
  564. class DiaForConditionalGeneration(DiaPreTrainedModel, DiaGenerationMixin):
  565. base_model_prefix = "model"
  566. def __init__(self, config: DiaConfig):
  567. super().__init__(config)
  568. self.config = config
  569. self.model = DiaModel(config)
  570. self.num_channels = config.decoder_config.num_channels
  571. self.vocab_size = config.decoder_config.vocab_size
  572. self.logits_dense = nn.Linear(
  573. config.decoder_config.hidden_size, (self.num_channels * self.vocab_size), bias=False
  574. )
  575. self.loss_type = "ForMaskedLM"
  576. # Initialize weights and apply final processing
  577. self.post_init()
  578. def get_encoder(self):
  579. return self.model.get_encoder()
  580. def get_decoder(self):
  581. return self.model.get_decoder()
  582. @auto_docstring
  583. @can_return_tuple
  584. def forward(
  585. self,
  586. input_ids: Optional[torch.LongTensor] = None,
  587. attention_mask: Optional[torch.LongTensor] = None,
  588. decoder_input_ids: Optional[torch.LongTensor] = None,
  589. decoder_position_ids: Optional[torch.LongTensor] = None,
  590. decoder_attention_mask: Optional[torch.LongTensor] = None,
  591. encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None,
  592. past_key_values: Optional[EncoderDecoderCache] = None,
  593. use_cache: Optional[bool] = None,
  594. output_attentions: Optional[bool] = None,
  595. output_hidden_states: Optional[bool] = None,
  596. labels: Optional[torch.LongTensor] = None,
  597. cache_position: Optional[torch.LongTensor] = None,
  598. **kwargs,
  599. ) -> Union[tuple, Seq2SeqLMOutput]:
  600. r"""
  601. decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)
  602. or (batch_size, target_sequence_length, num_codebooks)`, *optional*):
  603. 1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where
  604. the audio input codebooks are flattened into the batch dimension. This also aligns with the flat-
  605. tened audio logits which are used to calculate the loss.
  606. 2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of
  607. Dia to calculate embeddings and subsequent steps more efficiently.
  608. If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape
  609. `(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See
  610. [`DiaProcessor.__call__`] for more details.
  611. [What are decoder input IDs?](../glossary#decoder-input-ids)
  612. decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
  613. Indices of positions of each input sequence tokens in the position embeddings.
  614. Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
  615. [What are position IDs?](../glossary#position-ids)
  616. labels (`torch.LongTensor` of shape `(batch_size * num_codebooks,)`, *optional*):
  617. Labels for computing the masked language modeling loss. Indices should either be in
  618. `[0, ..., config.decoder_config.vocab_size - 1]` or -100. Tokens with indices set to `-100`
  619. are ignored (masked).
  620. """
  621. outputs = self.model(
  622. input_ids=input_ids,
  623. attention_mask=attention_mask,
  624. decoder_input_ids=decoder_input_ids,
  625. decoder_position_ids=decoder_position_ids,
  626. decoder_attention_mask=decoder_attention_mask,
  627. encoder_outputs=encoder_outputs,
  628. past_key_values=past_key_values,
  629. use_cache=use_cache,
  630. output_attentions=output_attentions,
  631. output_hidden_states=output_hidden_states,
  632. cache_position=cache_position,
  633. **kwargs,
  634. )
  635. last_hidden_state = outputs[0]
  636. batch_size = last_hidden_state.shape[0]
  637. # 3D <-> 2D makes it necessary to prioritize channel dim
  638. audio_logits = (
  639. self.logits_dense(last_hidden_state)
  640. .view((batch_size, -1, self.num_channels, self.vocab_size))
  641. .transpose(1, 2)
  642. .contiguous()
  643. .view(batch_size * self.num_channels, -1, self.vocab_size)
  644. )
  645. loss = None
  646. if labels is not None:
  647. loss = self.loss_function(logits=audio_logits, labels=labels, vocab_size=self.vocab_size, **kwargs)
  648. return Seq2SeqLMOutput(
  649. loss=loss,
  650. logits=audio_logits,
  651. past_key_values=outputs.past_key_values,
  652. decoder_hidden_states=outputs.decoder_hidden_states,
  653. decoder_attentions=outputs.decoder_attentions,
  654. cross_attentions=outputs.cross_attentions,
  655. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  656. encoder_hidden_states=outputs.encoder_hidden_states,
  657. encoder_attentions=outputs.encoder_attentions,
  658. )
# Public classes exported by this module.
__all__ = ["DiaModel", "DiaPreTrainedModel", "DiaForConditionalGeneration"]