modeling_dia.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/dia/modular_dia.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_dia.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. from typing import Callable, Optional, Union
  22. import torch
  23. from torch import nn
  24. from ...activations import ACT2FN
  25. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  26. from ...integrations import use_kernel_forward_from_hub
  27. from ...masking_utils import create_causal_mask
  28. from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
  29. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  30. from ...modeling_layers import GradientCheckpointingLayer
  31. from ...modeling_outputs import (
  32. BaseModelOutput,
  33. BaseModelOutputWithPastAndCrossAttentions,
  34. Seq2SeqLMOutput,
  35. Seq2SeqModelOutput,
  36. )
  37. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  38. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  39. from ...processing_utils import Unpack
  40. from ...utils import (
  41. TransformersKwargs,
  42. auto_docstring,
  43. can_return_tuple,
  44. is_torch_flex_attn_available,
  45. is_torchdynamo_compiling,
  46. logging,
  47. )
  48. from ...utils.deprecation import deprecate_kwarg
  49. from .configuration_dia import DiaConfig, DiaDecoderConfig, DiaEncoderConfig
  50. from .generation_dia import DiaGenerationMixin
# `make_flex_block_causal_mask` is only importable when flex attention is available; it is referenced
# solely from the `flex_attention` branches of the mask-preparation methods in this file.
if is_torch_flex_attn_available():
    from ...integrations.flex_attention import make_flex_block_causal_mask

logger = logging.get_logger(__name__)
  54. @auto_docstring
  55. class DiaPreTrainedModel(PreTrainedModel):
  56. config: DiaConfig
  57. base_model_prefix = "model"
  58. supports_gradient_checkpointing = True
  59. _supports_flash_attn = True
  60. _supports_sdpa = True
  61. _supports_flex_attn = True
  62. _can_compile_fullgraph = True
  63. main_input_name = "input_ids"
  64. _no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"]
  65. class DiaMultiChannelEmbedding(nn.Module):
  66. """In order to efficiently compute the audio embedding from the 9 different channels,
  67. we vectorize the embedding process by using a single embedding layer and an offset.
  68. Example:
  69. - num_embeds = 4
  70. - vocab_size = 8
  71. - num_channels = 3
  72. We would have offsets = [0, 8, 16]
  73. If audio_codes = [0, 1, 2, 3], [1, 3, 4, 7], [5, 6, 7, 8],
  74. then tokens = audio_codes + offsets
  75. = [0, 1, 2, 3, 9, 11, 12, 15, 21, 22, 23, 24]
  76. This allows us to use a single embedding layer for all channels.
  77. """
  78. def __init__(self, config: DiaDecoderConfig):
  79. super().__init__()
  80. self.embed = nn.Embedding(config.vocab_size * config.num_channels, config.hidden_size)
  81. self.hidden_size = config.hidden_size
  82. self.num_channels = config.num_channels
  83. offsets = torch.arange(config.num_channels, dtype=torch.long) * config.vocab_size # (C,)
  84. self.register_buffer("offsets", offsets, persistent=False)
  85. def forward(self, audio_codes: torch.Tensor) -> torch.Tensor:
  86. tokens = (audio_codes + self.offsets.to(audio_codes.device)).squeeze(1)
  87. embeds = self.embed(tokens).view(tokens.shape[0], audio_codes.shape[1], -1, self.hidden_size)
  88. return embeds.sum(dim=2)
  89. class DiaMLP(nn.Module):
  90. def __init__(self, config):
  91. super().__init__()
  92. self.config = config
  93. self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
  94. self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
  95. self.activation_fn = ACT2FN[config.hidden_act]
  96. def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
  97. up_states = self.gate_up_proj(hidden_states)
  98. gate, up_states = up_states.chunk(2, dim=-1)
  99. up_states = up_states * self.activation_fn(gate)
  100. return self.down_proj(up_states)
  101. @use_kernel_forward_from_hub("RMSNorm")
  102. class DiaRMSNorm(nn.Module):
  103. def __init__(self, hidden_size, eps=1e-6):
  104. """
  105. DiaRMSNorm is equivalent to T5LayerNorm
  106. """
  107. super().__init__()
  108. self.weight = nn.Parameter(torch.ones(hidden_size))
  109. self.variance_epsilon = eps
  110. def forward(self, hidden_states):
  111. input_dtype = hidden_states.dtype
  112. hidden_states = hidden_states.to(torch.float32)
  113. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  114. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  115. return self.weight * hidden_states.to(input_dtype)
  116. def extra_repr(self):
  117. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  118. class DiaRotaryEmbedding(nn.Module):
  119. inv_freq: torch.Tensor # fix linting for `register_buffer`
  120. def __init__(self, config: DiaConfig, device=None):
  121. super().__init__()
  122. # BC: "rope_type" was originally "type"
  123. if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
  124. self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
  125. else:
  126. self.rope_type = "default"
  127. self.max_seq_len_cached = config.max_position_embeddings
  128. self.original_max_seq_len = config.max_position_embeddings
  129. self.config = config
  130. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  131. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
  132. self.register_buffer("inv_freq", inv_freq, persistent=False)
  133. self.original_inv_freq = self.inv_freq
  134. @torch.no_grad()
  135. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  136. def forward(self, x, position_ids):
  137. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  138. position_ids_expanded = position_ids[:, None, :].float()
  139. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  140. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  141. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  142. emb = torch.cat((freqs, freqs), dim=-1)
  143. cos = emb.cos() * self.attention_scaling
  144. sin = emb.sin() * self.attention_scaling
  145. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  146. def rotate_half(x):
  147. """Rotates half the hidden dims of the input."""
  148. x1 = x[..., : x.shape[-1] // 2]
  149. x2 = x[..., x.shape[-1] // 2 :]
  150. return torch.cat((-x2, x1), dim=-1)
  151. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  152. """Applies Rotary Position Embedding to the query and key tensors.
  153. Args:
  154. q (`torch.Tensor`): The query tensor.
  155. k (`torch.Tensor`): The key tensor.
  156. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  157. sin (`torch.Tensor`): The sine part of the rotary embedding.
  158. position_ids (`torch.Tensor`, *optional*):
  159. Deprecated and unused.
  160. unsqueeze_dim (`int`, *optional*, defaults to 1):
  161. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  162. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  163. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  164. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  165. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  166. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  167. Returns:
  168. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  169. """
  170. cos = cos.unsqueeze(unsqueeze_dim)
  171. sin = sin.unsqueeze(unsqueeze_dim)
  172. q_embed = (q * cos) + (rotate_half(q) * sin)
  173. k_embed = (k * cos) + (rotate_half(k) * sin)
  174. return q_embed, k_embed
  175. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  176. """
  177. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  178. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  179. """
  180. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  181. if n_rep == 1:
  182. return hidden_states
  183. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  184. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  185. def eager_attention_forward(
  186. module: nn.Module,
  187. query: torch.Tensor,
  188. key: torch.Tensor,
  189. value: torch.Tensor,
  190. attention_mask: Optional[torch.Tensor],
  191. scaling: float,
  192. dropout: float = 0.0,
  193. **kwargs: Unpack[TransformersKwargs],
  194. ):
  195. key_states = repeat_kv(key, module.num_key_value_groups)
  196. value_states = repeat_kv(value, module.num_key_value_groups)
  197. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  198. if attention_mask is not None:
  199. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  200. attn_weights = attn_weights + causal_mask
  201. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  202. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  203. attn_output = torch.matmul(attn_weights, value_states)
  204. attn_output = attn_output.transpose(1, 2).contiguous()
  205. return attn_output, attn_weights
  206. class DiaSelfAttention(nn.Module):
  207. """Multi-headed attention from 'Attention Is All You Need' paper"""
  208. def __init__(self, config: Union[DiaEncoderConfig, DiaDecoderConfig], layer_idx: int, is_causal: bool = False):
  209. super().__init__()
  210. self.config = config
  211. self.layer_idx = layer_idx
  212. self.hidden_size = config.hidden_size
  213. self.num_heads = self.config.num_attention_heads
  214. self.num_key_value_heads = self.config.num_key_value_heads or self.num_heads
  215. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  216. self.head_dim = getattr(config, "head_dim", config.hidden_size // self.num_heads)
  217. self.scaling = 1
  218. self.attention_dropout = 0.0
  219. self.is_causal = is_causal
  220. self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
  221. self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  222. self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  223. self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
  224. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  225. def forward(
  226. self,
  227. hidden_states: torch.Tensor,
  228. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  229. attention_mask: Optional[torch.Tensor],
  230. past_key_values: Optional[Cache] = None,
  231. cache_position: Optional[torch.LongTensor] = None,
  232. **kwargs: Unpack[TransformersKwargs],
  233. ) -> tuple[torch.Tensor, torch.Tensor]:
  234. input_shape = hidden_states.shape[:-1]
  235. hidden_shape = (*input_shape, -1, self.head_dim)
  236. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  237. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  238. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  239. cos, sin = position_embeddings
  240. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  241. if past_key_values is not None:
  242. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  243. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  244. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  245. attention_interface: Callable = eager_attention_forward
  246. if self.config._attn_implementation != "eager":
  247. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  248. attn_output, attn_weights = attention_interface(
  249. self,
  250. query_states,
  251. key_states,
  252. value_states,
  253. attention_mask,
  254. dropout=0.0 if not self.training else self.attention_dropout,
  255. scaling=self.scaling,
  256. **kwargs,
  257. )
  258. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  259. attn_output = self.o_proj(attn_output)
  260. return attn_output, attn_weights
  261. class DiaCrossAttention(nn.Module):
  262. """Multi-headed attention from 'Attention Is All You Need' paper"""
  263. def __init__(self, config: DiaDecoderConfig, layer_idx: int):
  264. super().__init__()
  265. self.config = config
  266. self.layer_idx = layer_idx
  267. self.hidden_size = config.hidden_size
  268. self.cross_hidden_size = config.cross_hidden_size
  269. self.num_heads = self.config.cross_num_attention_heads
  270. self.num_key_value_heads = self.config.cross_num_key_value_heads
  271. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  272. self.head_dim = config.cross_head_dim
  273. self.scaling = 1
  274. self.attention_dropout = 0.0
  275. self.is_causal = False
  276. self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
  277. self.k_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  278. self.v_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  279. self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
  280. def forward(
  281. self,
  282. hidden_states: torch.Tensor,
  283. cross_attention_states: torch.Tensor,
  284. attention_mask: Optional[torch.Tensor] = None,
  285. past_key_values: Optional[EncoderDecoderCache] = None,
  286. **kwargs: Unpack[FlashAttentionKwargs],
  287. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  288. input_shape = hidden_states.shape[:-1]
  289. hidden_shape = (*input_shape, -1, self.head_dim)
  290. cross_shape = (*cross_attention_states.shape[:-1], -1, self.head_dim)
  291. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  292. is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
  293. if past_key_values is not None and is_updated:
  294. # reuse k,v, cross_attentions
  295. key_states = past_key_values.cross_attention_cache.layers[self.layer_idx].keys
  296. value_states = past_key_values.cross_attention_cache.layers[self.layer_idx].values
  297. else:
  298. key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
  299. value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
  300. if past_key_values is not None:
  301. # save all states to the cache
  302. key_states, value_states = past_key_values.cross_attention_cache.update(
  303. key_states,
  304. value_states,
  305. self.layer_idx,
  306. )
  307. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  308. past_key_values.is_updated[self.layer_idx] = True
  309. attention_interface: Callable = eager_attention_forward
  310. if self.config._attn_implementation != "eager":
  311. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  312. attn_output, attn_weights = attention_interface(
  313. self,
  314. query_states,
  315. key_states,
  316. value_states,
  317. attention_mask,
  318. scaling=self.scaling,
  319. **kwargs,
  320. )
  321. attn_output = attn_output.reshape((*input_shape, -1)).contiguous()
  322. attn_output = self.o_proj(attn_output)
  323. return attn_output, attn_weights
  324. class DiaEncoderLayer(GradientCheckpointingLayer):
  325. def __init__(self, config: DiaEncoderConfig, layer_idx: int):
  326. super().__init__()
  327. self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
  328. self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=False)
  329. self.post_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
  330. self.mlp = DiaMLP(config)
  331. def forward(
  332. self,
  333. hidden_states: torch.Tensor,
  334. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
  335. attention_mask: Optional[torch.Tensor] = None,
  336. **kwargs: Unpack[FlashAttentionKwargs],
  337. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  338. residual = hidden_states
  339. normed_states = self.pre_sa_norm(hidden_states)
  340. self_attn_output, self_attn_weights = self.self_attention(
  341. normed_states,
  342. position_embeddings=position_embeddings,
  343. attention_mask=attention_mask,
  344. **kwargs,
  345. )
  346. hidden_states = residual + self_attn_output
  347. residual = hidden_states
  348. normed_states = self.post_sa_norm(hidden_states)
  349. mlp_out = self.mlp(normed_states)
  350. hidden_states = residual + mlp_out
  351. return hidden_states, self_attn_weights
class DiaEncoder(DiaPreTrainedModel):
    """Encoder stack: token embedding, `DiaEncoderLayer`s driven by rotary position embeddings, final RMSNorm."""

    def __init__(self, config: DiaEncoderConfig):
        super().__init__(config)
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            [DiaEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
        self.rotary_embeddings = DiaRotaryEmbedding(config)

    # Embeds `input_ids`, runs every encoder layer with the backend-specific full (non-causal) mask,
    # applies the final norm and returns a `BaseModelOutput`. Per-layer hidden states / attention
    # weights are only collected when `output_hidden_states` / `output_attentions` are set.
    @auto_docstring
    @can_return_tuple
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[BaseModelOutput, tuple]:
        hidden_states = self.embedding(input_ids)

        # RoPE
        # Note: We expect right padding and hence always generate
        # the position ids on the fly to reduce preparation overhead
        position_ids = torch.arange(input_ids.shape[-1], device=input_ids.device)[None, :]
        position_embeddings = self.rotary_embeddings(hidden_states, position_ids)

        # Convert the 2D padding mask (if any) into the format expected by the configured attention backend.
        attention_mask = self._update_full_mask(
            attention_mask,
            hidden_states,
        )

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for encoder_layer in self.layers:
            # Record the input of each layer (the final normed output is appended after the loop).
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)

            layer_outputs = encoder_layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=attention_mask,
                **kwargs,
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            encoder_states += (hidden_states,)

        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )

    # Backend-specific conversion of a 2D padding mask (None is passed through):
    # flash_attention_2 keeps the 2D mask only if it contains padding (a 0), sdpa/eager get a 4D
    # additive mask, and flex_attention gets a non-causal block mask.
    # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
    def _update_full_mask(
        self,
        attention_mask: Union[torch.Tensor, None],
        inputs_embeds: torch.Tensor,
    ):
        if attention_mask is not None:
            if self.config._attn_implementation == "flash_attention_2":
                attention_mask = attention_mask if 0 in attention_mask else None
            elif self.config._attn_implementation == "sdpa":
                # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
                # the manual implementation that requires a 4D causal mask in all cases.
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
            elif self.config._attn_implementation == "flex_attention":
                if isinstance(attention_mask, torch.Tensor):
                    attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
            else:
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

        return attention_mask
  423. class DiaDecoderLayer(GradientCheckpointingLayer):
  424. def __init__(self, config: DiaDecoderConfig, layer_idx: int):
  425. super().__init__()
  426. self.embed_dim = config.hidden_size
  427. self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=True)
  428. self.cross_attention = DiaCrossAttention(config, layer_idx)
  429. self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
  430. self.pre_ca_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
  431. self.pre_mlp_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
  432. self.mlp = DiaMLP(config)
  433. def forward(
  434. self,
  435. hidden_states: torch.Tensor,
  436. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
  437. attention_mask: Optional[torch.Tensor] = None,
  438. encoder_hidden_states: Optional[torch.Tensor] = None,
  439. encoder_attention_mask: Optional[torch.Tensor] = None,
  440. past_key_values: Optional[EncoderDecoderCache] = None,
  441. cache_position: Optional[torch.LongTensor] = None,
  442. **kwargs,
  443. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
  444. self_attn_cache = past_key_values
  445. if isinstance(self_attn_cache, EncoderDecoderCache):
  446. self_attn_cache = self_attn_cache.self_attention_cache
  447. residual = hidden_states
  448. normed_states = self.pre_sa_norm(hidden_states)
  449. self_attn_output, self_attn_weights = self.self_attention(
  450. normed_states,
  451. position_embeddings,
  452. attention_mask,
  453. # Needs to be an arg in order to function properly
  454. # on inplace operations to be carried (e.g. compile)
  455. self_attn_cache,
  456. cache_position=cache_position,
  457. **kwargs,
  458. )
  459. hidden_states = residual + self_attn_output
  460. residual = hidden_states
  461. normed_states = self.pre_ca_norm(hidden_states)
  462. cross_states, cross_attn_weights = self.cross_attention(
  463. normed_states,
  464. encoder_hidden_states,
  465. attention_mask=encoder_attention_mask,
  466. past_key_values=past_key_values,
  467. **kwargs,
  468. )
  469. hidden_states = residual + cross_states
  470. residual = hidden_states
  471. normed_states = self.pre_mlp_norm(hidden_states)
  472. mlp_out = self.mlp(normed_states)
  473. hidden_states = residual + mlp_out
  474. return hidden_states, self_attn_weights, cross_attn_weights
class DiaDecoder(DiaPreTrainedModel):
    """Decoder stack with three stages.

    1. Multi-channel audio-code embedding.
    2. `DiaDecoderLayer`s: causal self-attention with RoPE, then cross-attention to the encoder output.
    3. A final RMSNorm.
    """

    def __init__(self, config: DiaDecoderConfig):
        super().__init__(config)
        self.num_channels = config.num_channels
        self.vocab_size = config.vocab_size
        self.embeddings = DiaMultiChannelEmbedding(config)
        self.rotary_embeddings = DiaRotaryEmbedding(config)
        self.layers = nn.ModuleList(
            [DiaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)

    @auto_docstring
    @can_return_tuple
    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.LongTensor] = None,
        past_key_values: Optional[EncoderDecoderCache] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[BaseModelOutputWithPastAndCrossAttentions, tuple]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`):
            The original `decoder_input_ids` in 3D shape to facilitate more efficient computations.
            [What are input IDs?](../glossary#input-ids)
        """

        # input_ids is 3D; dropping the trailing codebook axis leaves (batch, seq).
        batch_size, seq_length = input_ids.size()[:-1]
        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
        # New tokens continue right after whatever is already cached.
        if cache_position is None:
            cache_position = torch.arange(
                past_key_values_length, past_key_values_length + seq_length, device=input_ids.device
            )
        if position_ids is None:
            position_ids = cache_position[None, :]

        # RoPE
        hidden_states = self.embeddings(input_ids)
        position_embeddings = self.rotary_embeddings(hidden_states, position_ids)

        # Outside of torch.compile tracing, default to an all-ones mask spanning cached + new tokens;
        # while compiling, None is passed on to `create_causal_mask` instead.
        if attention_mask is None and not is_torchdynamo_compiling():
            # required mask seq length can be calculated via length of past cache
            mask_seq_length = past_key_values_length + seq_length
            attention_mask = torch.ones(batch_size, mask_seq_length, device=input_ids.device)

        attention_mask = create_causal_mask(
            config=self.config,
            input_embeds=hidden_states,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )
        # Expand the encoder padding mask to the format of the configured attention backend.
        encoder_attention_mask = self._update_cross_attn_mask(
            encoder_hidden_states,
            encoder_attention_mask,
            hidden_states.shape[:2],
            hidden_states,
        )

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        # Cross-attention weights only exist when there are encoder states to attend to.
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = layer(
                hidden_states,
                position_embeddings,
                attention_mask,
                encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                past_key_values=past_key_values,
                cache_position=cache_position,
                **kwargs,
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns = all_self_attns + (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )

    # Backend-specific conversion of the 2D encoder padding mask (only when both encoder states and a
    # mask are given). The query length used is `input_shape[-1]`.
    # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
    def _update_cross_attn_mask(
        self,
        encoder_hidden_states: Union[torch.Tensor, None],
        encoder_attention_mask: Union[torch.Tensor, None],
        input_shape: torch.Size,
        inputs_embeds: torch.Tensor,
    ):
        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            if self.config._attn_implementation == "flash_attention_2":
                encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
            elif self.config._attn_implementation == "sdpa":
                # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
                # the manual implementation that requires a 4D causal mask in all cases.
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                    encoder_attention_mask,
                    inputs_embeds.dtype,
                    tgt_len=input_shape[-1],
                )
            elif self.config._attn_implementation == "flex_attention":
                if isinstance(encoder_attention_mask, torch.Tensor):
                    encoder_attention_mask = make_flex_block_causal_mask(
                        encoder_attention_mask,
                        query_length=input_shape[-1],
                        is_causal=False,
                    )
            else:
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                encoder_attention_mask = _prepare_4d_attention_mask(
                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
                )

        return encoder_attention_mask
  @auto_docstring(
      custom_intro="""
      The bare Dia model outputting raw hidden-states without any specific head on top.
      """
  )
  class DiaModel(DiaPreTrainedModel):
      def __init__(self, config: DiaConfig):
          super().__init__(config)
          self.config = config
          # Text encoder and audio decoder, each built from its own sub-config.
          self.encoder = DiaEncoder(config.encoder_config)
          self.decoder = DiaDecoder(config.decoder_config)
          self.post_init()

      def get_encoder(self):
          """Return the text encoder (`DiaEncoder`)."""
          return self.encoder

      def get_decoder(self):
          """Return the audio decoder (`DiaDecoder`).

          Defined explicitly, mirroring `get_encoder`, so that wrappers which delegate here
          (e.g. `DiaForConditionalGeneration.get_decoder`) do not depend on a base-class fallback.
          """
          return self.decoder

      @auto_docstring
      @can_return_tuple
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.LongTensor] = None,
          decoder_input_ids: Optional[torch.LongTensor] = None,
          decoder_position_ids: Optional[torch.LongTensor] = None,
          decoder_attention_mask: Optional[torch.LongTensor] = None,
          encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None,
          past_key_values: Optional[EncoderDecoderCache] = None,
          use_cache: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          cache_position: Optional[torch.LongTensor] = None,
          **kwargs,
      ) -> Union[tuple, Seq2SeqModelOutput]:
          r"""
          decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)
          or (batch_size, target_sequence_length, num_codebooks)`, *optional*):
              1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where
              the audio input codebooks are flattened into the batch dimension. This also aligns with the
              flattened audio logits which are used to calculate the loss.
              2. (batch_size, target_sequence_length, num_codebooks): corresponds to the internally used shape of
              Dia to calculate embeddings and subsequent steps more efficiently.
              If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape
              `(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See
              [`DiaProcessor.__call__`] for more details.

              [What are decoder input IDs?](../glossary#decoder-input-ids)
          decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
              Indices of positions of each input sequence tokens in the position embeddings.
              Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.

              [What are position IDs?](../glossary#position-ids)
          """
          if input_ids is None and encoder_outputs is None:
              raise ValueError(
                  "You should either provide text ids or the cached text encodings. Neither has been found."
              )

          output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
          output_hidden_states = (
              output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
          )
          use_cache = use_cache if use_cache is not None else self.config.use_cache

          if self.is_gradient_checkpointing and self.training:
              if use_cache:
                  logger.warning_once(
                      "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                  )
                  use_cache = False

          # Lazily create an (self-attention, cross-attention) cache pair when caching is requested.
          if use_cache and past_key_values is None:
              past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))

          if encoder_outputs is None:
              encoder_outputs = self.encoder(
                  input_ids=input_ids,
                  attention_mask=attention_mask,
                  output_attentions=output_attentions,
                  output_hidden_states=output_hidden_states,
                  **kwargs,
              )
          # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
          elif not isinstance(encoder_outputs, BaseModelOutput):
              encoder_outputs = BaseModelOutput(
                  last_hidden_state=encoder_outputs[0],
                  hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                  attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
              )

          batch_size = encoder_outputs[0].shape[0]
          num_channels = self.config.decoder_config.num_channels

          # By default, seed the decoder with a single step of BOS tokens on every channel.
          if decoder_input_ids is None:
              decoder_input_ids = torch.full(
                  size=(batch_size, 1, num_channels), fill_value=self.config.bos_token_id, device=self.device
              )

          # Flattened 2D layout (batch_size * num_channels, seq_len) -> internal 3D layout
          # (batch_size, seq_len, num_channels); channels are the slower-varying factor of the flattened dim.
          if decoder_input_ids.ndim == 2:
              decoder_input_ids = decoder_input_ids.reshape(batch_size, num_channels, -1).transpose(1, 2)

          decoder_outputs = self.decoder(
              input_ids=decoder_input_ids,
              position_ids=decoder_position_ids,
              attention_mask=decoder_attention_mask,
              encoder_hidden_states=encoder_outputs[0],
              encoder_attention_mask=attention_mask,
              past_key_values=past_key_values,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              use_cache=use_cache,
              cache_position=cache_position,
              **kwargs,
          )

          return Seq2SeqModelOutput(
              last_hidden_state=decoder_outputs.last_hidden_state,
              past_key_values=decoder_outputs.past_key_values,
              decoder_hidden_states=decoder_outputs.hidden_states,
              decoder_attentions=decoder_outputs.attentions,
              cross_attentions=decoder_outputs.cross_attentions,
              encoder_last_hidden_state=encoder_outputs[0],
              encoder_hidden_states=encoder_outputs.hidden_states,
              encoder_attentions=encoder_outputs.attentions,
          )
  @auto_docstring(
      custom_intro="""
      The Dia model consisting of a (byte) text encoder and audio decoder with a prediction head on top.
      """
  )
  class DiaForConditionalGeneration(DiaPreTrainedModel, DiaGenerationMixin):
      base_model_prefix = "model"

      def __init__(self, config: DiaConfig):
          super().__init__(config)
          self.config = config
          self.model = DiaModel(config)
          self.num_channels = config.decoder_config.num_channels
          self.vocab_size = config.decoder_config.vocab_size
          # One projection producing logits for every channel at once: hidden -> num_channels * vocab_size.
          self.logits_dense = nn.Linear(
              config.decoder_config.hidden_size, (self.num_channels * self.vocab_size), bias=False
          )
          self.loss_type = "ForMaskedLM"

          # Initialize weights and apply final processing
          self.post_init()

      def get_encoder(self):
          """Return the text encoder of the wrapped `DiaModel`."""
          return self.model.get_encoder()

      def get_decoder(self):
          """Return the audio decoder of the wrapped `DiaModel`."""
          return self.model.get_decoder()

      @auto_docstring
      @can_return_tuple
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.LongTensor] = None,
          decoder_input_ids: Optional[torch.LongTensor] = None,
          decoder_position_ids: Optional[torch.LongTensor] = None,
          decoder_attention_mask: Optional[torch.LongTensor] = None,
          encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None,
          past_key_values: Optional[EncoderDecoderCache] = None,
          use_cache: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          labels: Optional[torch.LongTensor] = None,
          cache_position: Optional[torch.LongTensor] = None,
          **kwargs,
      ) -> Union[tuple, Seq2SeqLMOutput]:
          r"""
          decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)
          or (batch_size, target_sequence_length, num_codebooks)`, *optional*):
              1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where
              the audio input codebooks are flattened into the batch dimension. This also aligns with the
              flattened audio logits which are used to calculate the loss.
              2. (batch_size, target_sequence_length, num_codebooks): corresponds to the internally used shape of
              Dia to calculate embeddings and subsequent steps more efficiently.
              If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape
              `(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See
              [`DiaProcessor.__call__`] for more details.

              [What are decoder input IDs?](../glossary#decoder-input-ids)
          decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
              Indices of positions of each input sequence tokens in the position embeddings.
              Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.

              [What are position IDs?](../glossary#position-ids)
          labels (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)`, *optional*):
              Labels for computing the masked language modeling loss, in the same flattened layout as the
              returned logits (codebooks folded into the batch dimension). Indices should either be in
              `[0, ..., config.decoder_config.vocab_size - 1]` or -100. Tokens with indices set to `-100`
              are ignored (masked).
          """
          outputs = self.model(
              input_ids=input_ids,
              attention_mask=attention_mask,
              decoder_input_ids=decoder_input_ids,
              decoder_position_ids=decoder_position_ids,
              decoder_attention_mask=decoder_attention_mask,
              encoder_outputs=encoder_outputs,
              past_key_values=past_key_values,
              use_cache=use_cache,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              cache_position=cache_position,
              **kwargs,
          )

          last_hidden_state = outputs[0]
          batch_size = last_hidden_state.shape[0]
          # 3D <-> 2D makes it necessary to prioritize channel dim:
          # (batch, seq, num_channels * vocab) -> (batch, seq, num_channels, vocab)
          # -> (batch, num_channels, seq, vocab) -> (batch * num_channels, seq, vocab),
          # i.e. the same channel-major flattening as 2D `decoder_input_ids`.
          audio_logits = (
              self.logits_dense(last_hidden_state)
              .view((batch_size, -1, self.num_channels, self.vocab_size))
              .transpose(1, 2)
              .contiguous()
              .view(batch_size * self.num_channels, -1, self.vocab_size)
          )

          loss = None
          if labels is not None:
              loss = self.loss_function(logits=audio_logits, labels=labels, vocab_size=self.vocab_size, **kwargs)

          return Seq2SeqLMOutput(
              loss=loss,
              logits=audio_logits,
              past_key_values=outputs.past_key_values,
              decoder_hidden_states=outputs.decoder_hidden_states,
              decoder_attentions=outputs.decoder_attentions,
              cross_attentions=outputs.cross_attentions,
              encoder_last_hidden_state=outputs.encoder_last_hidden_state,
              encoder_hidden_states=outputs.encoder_hidden_states,
              encoder_attentions=outputs.encoder_attentions,
          )
  # Public API exported by this module.
  __all__ = [
      "DiaModel",
      "DiaPreTrainedModel",
      "DiaForConditionalGeneration",
  ]