modeling_moshi.py 121 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485
  1. # coding=utf-8
  2. # Copyright 2024 Kyutai and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch Moshi model."""
  16. import math
  17. from dataclasses import dataclass
  18. from typing import Any, Optional, Union
  19. import torch
  20. import torch.nn as nn
  21. from torch.nn import CrossEntropyLoss
  22. from ...activations import ACT2FN
  23. from ...cache_utils import Cache, DynamicCache, StaticCache
  24. from ...generation import GenerationConfig, GenerationMixin
  25. from ...modeling_attn_mask_utils import AttentionMaskConverter
  26. from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
  27. from ...modeling_layers import GradientCheckpointingLayer
  28. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput, Seq2SeqLMOutput
  29. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  30. from ...modeling_utils import PreTrainedModel
  31. from ...utils import auto_docstring, is_torch_flex_attn_available, logging
  32. from ...utils.deprecation import deprecate_kwarg
  33. from ..auto.modeling_auto import AutoModel
  34. from .configuration_moshi import MoshiConfig, MoshiDepthConfig
  35. if is_flash_attn_available():
  36. from ...modeling_flash_attention_utils import _flash_attention_forward
  37. if is_torch_flex_attn_available():
  38. from torch.nn.attention.flex_attention import BlockMask
  39. from ...integrations.flex_attention import make_flex_block_causal_mask
  40. logger = logging.get_logger(__name__)
  41. @dataclass
  42. @auto_docstring(
  43. custom_intro="""
  44. Outputs of [`MoshiForConditionalConditionalGeneration.generate`].
  45. """
  46. )
  47. class MoshiConditionalGenerationGenerateOutput(ModelOutput):
  48. r"""
  49. audio_sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, 1, sequence_length)`, *optional*):
  50. The generated audio waveforms.
  51. sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
  52. The generated text sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
  53. if all batches finished early due to the `eos_token_id`.
  54. sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True`):
  55. Final beam scores of the generated `sequences`.
  56. scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
  57. Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting
  58. of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
  59. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
  60. with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
  61. logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
  62. Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
  63. at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
  64. each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
  65. beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True`):
  66. Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
  67. `(batch_size*num_return_sequences, sequence_length)`.
  68. attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
  69. Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
  70. `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
  71. hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
  72. Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
  73. `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
  74. past_key_values (`Cache`, *optional*, returned when `use_cache=True`):
  75. Contains the model cache, used to speed up decoding. Different models have a different cache format, check
  76. the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
  77. audio_codes (`torch.LongTensor` of shape `(batch_size*num_return_sequences, num_codeooks, sequence_length)`, *optional*):
  78. The generated audio codes. Returned if `return_audio_codes=True`. Intermediate audio "tokens" which transforms to `audio_sequences` once passed through the audio decoder.
  79. """
  80. audio_sequences: Optional[torch.Tensor] = None
  81. sequences: Optional[torch.LongTensor] = None
  82. sequences_scores: Optional[torch.FloatTensor] = None
  83. scores: Optional[tuple[torch.FloatTensor]] = None
  84. logits: Optional[tuple[torch.FloatTensor]] = None
  85. beam_indices: Optional[torch.LongTensor] = None
  86. attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
  87. hidden_states: Optional[tuple[tuple[torch.FloatTensor]]] = None
  88. past_key_values: Optional[Cache] = None
  89. audio_codes: Optional[torch.LongTensor] = None
  90. @dataclass
  91. @auto_docstring(
  92. custom_intro="""
  93. `MoshiForCausalLM` outputs.
  94. """
  95. )
  96. class MoshiCausalLMOutputWithPast(ModelOutput):
  97. r"""
  98. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  99. Language modeling loss (for next-token prediction).
  100. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  101. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  102. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  103. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  104. Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
  105. `past_key_values` input) to speed up sequential decoding.
  106. """
  107. loss: Optional[torch.FloatTensor] = None
  108. logits: Optional[torch.FloatTensor] = None
  109. last_hidden_state: Optional[torch.FloatTensor] = None
  110. past_key_values: Optional[Cache] = None
  111. hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  112. attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  113. @dataclass
  114. @auto_docstring(
  115. custom_intro="""
  116. `MoshiForConditionalGeneration` outputs.
  117. """
  118. )
  119. class MoshiConditionalGenerationOutputWithPast(ModelOutput):
  120. r"""
  121. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `text_labels` is provided):
  122. Text language modeling loss (for next-token prediction).
  123. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  124. Prediction scores of the text language modeling head (scores for each vocabulary token before SoftMax).
  125. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  126. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  127. Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
  128. `past_key_values` input) to speed up sequential decoding.
  129. depth_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `audio_labels` is provided):
  130. Audio language modeling loss (for next-token prediction).
  131. audio_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  132. Prediction scores of the audio language modeling heads.
  133. depth_past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  134. Past key-values of the depth decoder.
  135. depth_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  136. Hidden states of the depth decoder
  137. depth_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  138. Depth decoder's Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
  139. heads.
  140. """
  141. loss: Optional[torch.FloatTensor] = None
  142. logits: Optional[torch.FloatTensor] = None
  143. last_hidden_state: Optional[torch.FloatTensor] = None
  144. past_key_values: Optional[Cache] = None
  145. hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  146. attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  147. depth_loss: Optional[torch.FloatTensor] = None
  148. audio_logits: Optional[torch.FloatTensor] = None
  149. depth_past_key_values: Optional[Cache] = None
  150. depth_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  151. depth_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  152. @dataclass
  153. @auto_docstring
  154. class MoshiUnconditionalInput(ModelOutput):
  155. r"""
  156. input_ids (`torch.Tensor `of shape `(batch_size, sequence_length), *optional*):
  157. The sequence used as a text prompt for the generation.
  158. user_audio_codes (`torch.Tensor `of shape `(batch_size, num_codebooks, sequence_length), *optional*):
  159. The audio codes used as audio user prompt for the generation. Has priority over `user_input_values` and represents the audio "tokens" of `user_input_values` once passed through the audio encoder.
  160. moshi_audio_codes (`torch.Tensor `of shape `(batch_size, num_codebooks, sequence_length), *optional*):
  161. The audio codes used as audio Moshi prompt for the generation. Has priority over `moshi_input_values` and represents the audio "tokens" of `moshi_input_values` once passed through the audio encoder.
  162. attention_mask (`torch.LongTensor`) of shape `(batch_size, sequence_length)`, *optional*):
  163. Attention mask to avoid performing attention on padding token indices. Mask values selected in `[0,
  164. 1]`: 1 for tokens that are **not masked**, 0 for tokens that are **masked**.
  165. """
  166. input_ids: Optional[torch.LongTensor] = None
  167. user_audio_codes: Optional[torch.Tensor] = None
  168. moshi_audio_codes: Optional[torch.Tensor] = None
  169. attention_mask: Optional[torch.LongTensor] = None
  170. # Copied from transformers.models.gemma.modeling_gemma.GemmaRMSNorm with Gemma->Moshi
  171. class MoshiRMSNorm(nn.Module):
  172. def __init__(self, dim: int, eps: float = 1e-6):
  173. super().__init__()
  174. self.eps = eps
  175. self.weight = nn.Parameter(torch.ones(dim)) # Ignore copy
  176. def _norm(self, x):
  177. return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
  178. # Ignore copy
  179. def forward(self, x):
  180. output = self._norm(x.float())
  181. output = output * self.weight.float()
  182. return output.type_as(x)
  183. def extra_repr(self):
  184. return f"{tuple(self.weight.shape)}, eps={self.eps}"
  185. class MoshiFlexibleLinear(nn.Module):
  186. def __init__(self, input_size, output_size, num_layers):
  187. super().__init__()
  188. # Stack the weights for N layers into a single tensor (num_layers, output_size, input_size)
  189. self.weight = nn.Parameter(torch.randn(num_layers, output_size, input_size))
  190. def forward(self, x, layer_idx=None):
  191. """
  192. `MoshiFlexibleLinear` creates one linear layer per codebook. There's multiple ways to use it.
  193. In the default case, `sequence_length=num_layers`, so each element of the sequence will be matmul to the weights corresponding to its index on the sequence.
  194. For more advanced cases, one can specify which codebook's layer(s) to use with `layer_idx`.
  195. If `layer_idx` indicates a single integer, all of the element of the sequence will be matmul to this single codebook's layer.
  196. But if `layer_idx` is a tensor of shape `(seq_length,)`, it will matmul each i-th element of the input sequence to the corresponding layer `weight[i]`.
  197. Args:
  198. x (`torch.FloatTensor): input to the layer of shape `(batch, num_layers, embed_dim)` or of shape `(batch, seq_length, embed_dim)`
  199. layer_idx (`torch.Tensor`, *optional*):
  200. Can be used to specify which codebook's layers(s) to use.
  201. If it's a tensor of shape `(seq_length,)`, will matmul each element of the sequence to the corresponding weights.
  202. But if `layer_idx` is a tensor of shape `(seq_length,)`, it will matmul each i-th element of the input sequence to the corresponding layer `weight[i]`.
  203. """
  204. # Use torch.gather to select the corresponding weights for each sample
  205. # (codebooks, output_size, hidden_size)
  206. selected_weights = torch.index_select(self.weight, 0, layer_idx) if layer_idx is not None else self.weight
  207. # (1, codebooks, hidden_size, output_size)
  208. selected_weights = selected_weights.transpose(1, 2)[None, :, :, :]
  209. # (batch_size, codebooks, 1, hidden_size) x (1, codebooks, hidden_size, output_size)
  210. # -> (batch_size, codebooks, 1, output_size)
  211. x = torch.matmul(x[:, :, None, :], selected_weights)
  212. # (batch_size, codebooks, output_size)
  213. return x.squeeze(2)
  214. class MoshiLinear(nn.Module):
  215. def __init__(self, input_dim, output_dim, num_codebooks, use_flexible_linear=False):
  216. super().__init__()
  217. self.use_flexible_linear = use_flexible_linear
  218. if not use_flexible_linear:
  219. self.linear = nn.Linear(input_dim, output_dim, bias=False)
  220. else:
  221. self.linear = MoshiFlexibleLinear(input_dim, output_dim, num_layers=num_codebooks)
  222. def forward(self, x, layer_idx=None):
  223. if self.use_flexible_linear:
  224. return self.linear(x, layer_idx)
  225. else:
  226. return self.linear(x)
  227. # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Moshi
  228. class MoshiRotaryEmbedding(nn.Module):
  229. inv_freq: torch.Tensor # fix linting for `register_buffer`
  230. def __init__(self, config: MoshiConfig, device=None):
  231. super().__init__()
  232. # BC: "rope_type" was originally "type"
  233. if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
  234. self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
  235. else:
  236. self.rope_type = "default"
  237. self.max_seq_len_cached = config.max_position_embeddings
  238. self.original_max_seq_len = config.max_position_embeddings
  239. self.config = config
  240. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  241. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
  242. self.register_buffer("inv_freq", inv_freq, persistent=False)
  243. self.original_inv_freq = self.inv_freq
  244. @torch.no_grad()
  245. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  246. def forward(self, x, position_ids):
  247. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  248. position_ids_expanded = position_ids[:, None, :].float()
  249. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  250. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  251. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  252. emb = torch.cat((freqs, freqs), dim=-1)
  253. cos = emb.cos() * self.attention_scaling
  254. sin = emb.sin() * self.attention_scaling
  255. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  256. # Copied from transformers.models.llama.modeling_llama.rotate_half
  257. def rotate_half(x):
  258. """Rotates half the hidden dims of the input."""
  259. x1 = x[..., : x.shape[-1] // 2]
  260. x2 = x[..., x.shape[-1] // 2 :]
  261. return torch.cat((-x2, x1), dim=-1)
  262. # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
  263. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  264. """Applies Rotary Position Embedding to the query and key tensors.
  265. Args:
  266. q (`torch.Tensor`): The query tensor.
  267. k (`torch.Tensor`): The key tensor.
  268. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  269. sin (`torch.Tensor`): The sine part of the rotary embedding.
  270. position_ids (`torch.Tensor`, *optional*):
  271. Deprecated and unused.
  272. unsqueeze_dim (`int`, *optional*, defaults to 1):
  273. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  274. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  275. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  276. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  277. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  278. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  279. Returns:
  280. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  281. """
  282. cos = cos.unsqueeze(unsqueeze_dim)
  283. sin = sin.unsqueeze(unsqueeze_dim)
  284. q_embed = (q * cos) + (rotate_half(q) * sin)
  285. k_embed = (k * cos) + (rotate_half(k) * sin)
  286. return q_embed, k_embed
  287. class MoshiGatingMLP(nn.Module):
  288. def __init__(self, config, use_flexible_linear=False):
  289. super().__init__()
  290. self.activation_fn = ACT2FN[config.hidden_act]
  291. ffn_dim = config.ffn_dim
  292. hidden_size = config.hidden_size
  293. num_layers = config.num_codebooks if use_flexible_linear else 1
  294. if num_layers == 1:
  295. self.fc1 = nn.Linear(hidden_size, ffn_dim, bias=False)
  296. self.fc2 = nn.Linear(ffn_dim // 2, hidden_size, bias=False)
  297. else:
  298. self.fc1 = MoshiFlexibleLinear(hidden_size, ffn_dim, num_layers)
  299. self.fc2 = MoshiFlexibleLinear(ffn_dim // 2, hidden_size, num_layers)
  300. def forward(self, hidden_states: torch.Tensor, layer_idx: Optional[int] = None) -> torch.Tensor:
  301. hidden_states = self.fc1(hidden_states) if layer_idx is None else self.fc1(hidden_states, layer_idx)
  302. batch_size, sequence_length, _ = hidden_states.shape
  303. hidden_states = hidden_states.view(batch_size, sequence_length, 2, -1)
  304. hidden_states = self.activation_fn(hidden_states[..., 0, :]) * hidden_states[..., 1, :]
  305. hidden_states = self.fc2(hidden_states) if layer_idx is None else self.fc2(hidden_states, layer_idx)
  306. return hidden_states
  307. # Copied from transformers.models.llama.modeling_llama.repeat_kv
  308. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  309. """
  310. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  311. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  312. """
  313. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  314. if n_rep == 1:
  315. return hidden_states
  316. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  317. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  318. class MoshiAttention(nn.Module):
  319. """Multi-headed attention from 'Attention Is All You Need' paper"""
  320. def __init__(self, config: MoshiConfig, layer_idx: Optional[int] = None, use_flexible_linear=False, use_rope=True):
  321. super().__init__()
  322. self.config = config
  323. self.layer_idx = layer_idx
  324. if layer_idx is None:
  325. logger.warning_once(
  326. f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
  327. "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
  328. "when creating this class."
  329. )
  330. self.attention_dropout = config.attention_dropout
  331. self.hidden_size = config.hidden_size
  332. self.num_heads = config.num_attention_heads
  333. self.head_dim = config.head_dim
  334. self.num_key_value_heads = config.num_key_value_heads
  335. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  336. self.max_position_embeddings = config.max_position_embeddings
  337. self.is_causal = True
  338. self.scaling = 1 / math.sqrt(self.head_dim)
  339. if self.hidden_size % self.num_heads != 0:
  340. raise ValueError(
  341. f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
  342. f" and `num_heads`: {self.num_heads})."
  343. )
  344. self.q_proj = MoshiLinear(
  345. self.hidden_size, self.num_heads * self.head_dim, config.num_codebooks, use_flexible_linear
  346. )
  347. self.k_proj = MoshiLinear(
  348. self.hidden_size, self.num_key_value_heads * self.head_dim, config.num_codebooks, use_flexible_linear
  349. )
  350. self.v_proj = MoshiLinear(
  351. self.hidden_size, self.num_key_value_heads * self.head_dim, config.num_codebooks, use_flexible_linear
  352. )
  353. self.o_proj = MoshiLinear(
  354. self.num_heads * self.head_dim, self.hidden_size, config.num_codebooks, use_flexible_linear
  355. )
  356. # rotary embeddings are not used in the depth decoder
  357. self.rotary_emb = None
  358. if use_rope:
  359. self.rope_theta = config.rope_theta
  360. self.rotary_emb = MoshiRotaryEmbedding(config)
  361. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  362. def forward(
  363. self,
  364. hidden_states: torch.Tensor,
  365. attention_mask: Optional[torch.Tensor] = None,
  366. position_ids: Optional[torch.LongTensor] = None,
  367. past_key_values: Optional[Cache] = None,
  368. output_attentions: bool = False,
  369. use_cache: bool = False,
  370. cache_position: Optional[torch.LongTensor] = None,
  371. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  372. bsz, q_len, _ = hidden_states.size()
  373. query_states = self.q_proj(hidden_states, cache_position) # Ignore copy
  374. key_states = self.k_proj(hidden_states, cache_position) # Ignore copy
  375. value_states = self.v_proj(hidden_states, cache_position) # Ignore copy
  376. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  377. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  378. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  379. if self.rotary_emb is not None: # Ignore copy
  380. cos, sin = self.rotary_emb(value_states, position_ids) # Ignore copy
  381. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # Ignore copy
  382. if past_key_values is not None:
  383. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  384. cache_kwargs = (
  385. {"sin": sin, "cos": cos, "cache_position": cache_position}
  386. if self.rotary_emb is not None
  387. else {"cache_position": cache_position}
  388. ) # Ignore copy
  389. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  390. key_states = repeat_kv(key_states, self.num_key_value_groups)
  391. value_states = repeat_kv(value_states, self.num_key_value_groups)
  392. attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
  393. if attention_mask is not None: # no matter the length, we just slice it
  394. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  395. attn_weights = attn_weights + causal_mask
  396. # upcast attention to fp32
  397. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
  398. attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
  399. attn_output = torch.matmul(attn_weights, value_states)
  400. if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
  401. raise ValueError(
  402. f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
  403. f" {attn_output.size()}"
  404. )
  405. attn_output = attn_output.transpose(1, 2).contiguous()
  406. attn_output = attn_output.view(bsz, q_len, -1)
  407. attn_output = self.o_proj(attn_output, cache_position) # Ignore copy
  408. if not output_attentions:
  409. attn_weights = None
  410. return attn_output, attn_weights
  411. # NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Moshi
  412. # TODO cyril: modular
  413. class MoshiFlashAttention2(MoshiAttention):
  414. """
  415. Moshi flash attention module. This module inherits from `MoshiAttention` as the weights of the module stays
  416. untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
  417. flash attention and deal with padding tokens in case the input contains any of them.
  418. """
  419. def __init__(self, *args, **kwargs):
  420. super().__init__(*args, **kwargs)
  421. # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
  422. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
  423. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
  424. self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
  425. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  426. def forward(
  427. self,
  428. hidden_states: torch.Tensor,
  429. attention_mask: Optional[torch.LongTensor] = None,
  430. position_ids: Optional[torch.LongTensor] = None,
  431. past_key_values: Optional[Cache] = None,
  432. output_attentions: bool = False,
  433. use_cache: bool = False,
  434. cache_position: Optional[torch.LongTensor] = None,
  435. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  436. if isinstance(past_key_values, StaticCache):
  437. raise ValueError(
  438. "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
  439. "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
  440. )
  441. output_attentions = False
  442. bsz, q_len, _ = hidden_states.size()
  443. query_states = self.q_proj(hidden_states, cache_position) # Ignore copy
  444. key_states = self.k_proj(hidden_states, cache_position) # Ignore copy
  445. value_states = self.v_proj(hidden_states, cache_position) # Ignore copy
  446. # Flash attention requires the input to have the shape
  447. # batch_size x seq_length x head_dim x hidden_dim
  448. # therefore we just need to keep the original shape
  449. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  450. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  451. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  452. if self.rotary_emb is not None: # Ignore copy
  453. cos, sin = self.rotary_emb(value_states, position_ids) # Ignore copy
  454. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # Ignore copy
  455. if past_key_values is not None:
  456. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  457. cache_kwargs = (
  458. {"sin": sin, "cos": cos, "cache_position": cache_position}
  459. if self.rotary_emb is not None
  460. else {"cache_position": cache_position}
  461. ) # Ignore copy
  462. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  463. # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
  464. # to be able to avoid many of these transpose/reshape/view.
  465. query_states = query_states.transpose(1, 2)
  466. key_states = key_states.transpose(1, 2)
  467. value_states = value_states.transpose(1, 2)
  468. dropout_rate = self.attention_dropout if self.training else 0.0
  469. # In PEFT, usually we cast the layer norms in float32 for training stability reasons
  470. # therefore the input hidden states gets silently casted in float32. Hence, we need
  471. # cast them back in the correct dtype just to be sure everything works as expected.
  472. # This might slowdown training & inference so it is recommended to not cast the LayerNorms
  473. # in fp32. (MoshiRMSNorm handles it correctly)
  474. input_dtype = query_states.dtype
  475. device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
  476. if input_dtype == torch.float32:
  477. if torch.is_autocast_enabled():
  478. target_dtype = (
  479. torch.get_autocast_dtype(device_type)
  480. if hasattr(torch, "get_autocast_dtype")
  481. else torch.get_autocast_gpu_dtype()
  482. )
  483. # Handle the case where the model is quantized
  484. elif hasattr(self.config, "_pre_quantization_dtype"):
  485. target_dtype = self.config._pre_quantization_dtype
  486. else:
  487. target_dtype = self.q_proj.weight.dtype
  488. logger.warning_once(
  489. f"The input hidden states seems to be silently casted in float32, this might be related to"
  490. f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
  491. f" {target_dtype}."
  492. )
  493. query_states = query_states.to(target_dtype)
  494. key_states = key_states.to(target_dtype)
  495. value_states = value_states.to(target_dtype)
  496. attn_output = _flash_attention_forward(
  497. query_states,
  498. key_states,
  499. value_states,
  500. attention_mask,
  501. q_len,
  502. position_ids=position_ids,
  503. dropout=dropout_rate,
  504. sliding_window=getattr(self, "sliding_window", None),
  505. is_causal=self.is_causal,
  506. use_top_left_mask=self._flash_attn_uses_top_left_mask,
  507. )
  508. attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
  509. attn_output = self.o_proj(attn_output, cache_position) # Ignore copy
  510. if not output_attentions:
  511. attn_weights = None
  512. return attn_output, attn_weights
  513. # NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->Moshi
  514. # TODO cyril: modular
  515. class MoshiSdpaAttention(MoshiAttention):
  516. """
  517. Moshi attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
  518. `MoshiAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
  519. SDPA API.
  520. """
  521. # Adapted from MoshiAttention.forward
  522. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  523. def forward(
  524. self,
  525. hidden_states: torch.Tensor,
  526. attention_mask: Optional[torch.Tensor] = None,
  527. position_ids: Optional[torch.LongTensor] = None,
  528. past_key_values: Optional[Cache] = None,
  529. output_attentions: bool = False,
  530. use_cache: bool = False,
  531. cache_position: Optional[torch.LongTensor] = None,
  532. **kwargs,
  533. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  534. if output_attentions:
  535. # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
  536. logger.warning_once(
  537. "MoshiModel is using MoshiSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
  538. 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
  539. )
  540. return super().forward(
  541. hidden_states=hidden_states,
  542. attention_mask=attention_mask,
  543. position_ids=position_ids,
  544. past_key_values=past_key_values,
  545. output_attentions=output_attentions,
  546. use_cache=use_cache,
  547. cache_position=cache_position,
  548. )
  549. bsz, q_len, _ = hidden_states.size()
  550. query_states = self.q_proj(hidden_states, cache_position) # Ignore copy
  551. key_states = self.k_proj(hidden_states, cache_position) # Ignore copy
  552. value_states = self.v_proj(hidden_states, cache_position) # Ignore copy
  553. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  554. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  555. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  556. if self.rotary_emb is not None: # Ignore copy
  557. cos, sin = self.rotary_emb(value_states, position_ids) # Ignore copy
  558. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # Ignore copy
  559. if past_key_values is not None:
  560. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  561. cache_kwargs = (
  562. {"sin": sin, "cos": cos, "cache_position": cache_position}
  563. if self.rotary_emb is not None
  564. else {"cache_position": cache_position}
  565. ) # Ignore copy
  566. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  567. key_states = repeat_kv(key_states, self.num_key_value_groups)
  568. value_states = repeat_kv(value_states, self.num_key_value_groups)
  569. causal_mask = attention_mask
  570. if attention_mask is not None:
  571. causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
  572. # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
  573. # Reference: https://github.com/pytorch/pytorch/issues/112577.
  574. if query_states.device.type == "cuda" and causal_mask is not None:
  575. query_states = query_states.contiguous()
  576. key_states = key_states.contiguous()
  577. value_states = value_states.contiguous()
  578. # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
  579. # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
  580. is_causal = causal_mask is None and q_len > 1
  581. attn_output = torch.nn.functional.scaled_dot_product_attention(
  582. query_states,
  583. key_states,
  584. value_states,
  585. attn_mask=causal_mask,
  586. dropout_p=self.attention_dropout if self.training else 0.0,
  587. is_causal=is_causal,
  588. )
  589. attn_output = attn_output.transpose(1, 2).contiguous()
  590. attn_output = attn_output.view(bsz, q_len, -1)
  591. attn_output = self.o_proj(attn_output, cache_position) # Ignore copy
  592. return attn_output, None
  593. MOSHI_ATTENTION_CLASSES = {
  594. "eager": MoshiAttention,
  595. "flash_attention_2": MoshiFlashAttention2,
  596. "sdpa": MoshiSdpaAttention,
  597. }
  598. class MoshiDecoderLayer(GradientCheckpointingLayer):
  599. def __init__(self, config: MoshiConfig, layer_idx: int, use_flexible_linear: bool, use_rope=True):
  600. super().__init__()
  601. self.hidden_size = config.hidden_size
  602. self.use_flexible_linear = use_flexible_linear
  603. self.self_attn = MOSHI_ATTENTION_CLASSES[config._attn_implementation](
  604. config=config, layer_idx=layer_idx, use_flexible_linear=use_flexible_linear, use_rope=use_rope
  605. )
  606. self.mlp = MoshiGatingMLP(config, use_flexible_linear)
  607. self.input_layernorm = MoshiRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  608. self.post_attention_layernorm = MoshiRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  609. self.sliding_window = config.sliding_window
  610. self._attn_implementation = config._attn_implementation
  611. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  612. def forward(
  613. self,
  614. hidden_states: torch.Tensor,
  615. attention_mask: Optional[torch.Tensor] = None,
  616. position_ids: Optional[torch.LongTensor] = None,
  617. past_key_values: Optional[Cache] = None,
  618. output_attentions: Optional[bool] = False,
  619. use_cache: Optional[bool] = False,
  620. cache_position: Optional[torch.LongTensor] = None,
  621. **kwargs,
  622. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  623. """
  624. Args:
  625. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  626. attention_mask (`torch.FloatTensor`, *optional*):
  627. attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
  628. query_sequence_length, key_sequence_length)` if default attention is used.
  629. output_attentions (`bool`, *optional*):
  630. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  631. returned tensors for more detail.
  632. use_cache (`bool`, *optional*):
  633. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  634. (see `past_key_values`).
  635. past_key_values (`Cache`, *optional*): cached past key and value projection states
  636. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  637. Indices depicting the position of the input sequence tokens in the sequence
  638. kwargs (`dict`, *optional*):
  639. Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
  640. into the model
  641. """
  642. residual = hidden_states
  643. hidden_states = self.input_layernorm(hidden_states)
  644. # Self Attention
  645. hidden_states, self_attn_weights = self.self_attn(
  646. hidden_states=hidden_states,
  647. attention_mask=attention_mask,
  648. position_ids=position_ids,
  649. past_key_values=past_key_values,
  650. output_attentions=output_attentions,
  651. use_cache=use_cache,
  652. cache_position=cache_position,
  653. **kwargs,
  654. )
  655. hidden_states = residual + hidden_states
  656. # Fully Connected
  657. residual = hidden_states
  658. hidden_states = self.post_attention_layernorm(hidden_states)
  659. hidden_states = (
  660. self.mlp(hidden_states) if not self.use_flexible_linear else self.mlp(hidden_states, cache_position)
  661. )
  662. hidden_states = residual + hidden_states
  663. outputs = (hidden_states,)
  664. if output_attentions:
  665. outputs += (self_attn_weights,)
  666. return outputs
  667. @auto_docstring
  668. class MoshiPreTrainedModel(PreTrainedModel):
  669. config: MoshiConfig
  670. base_model_prefix = "model"
  671. supports_gradient_checkpointing = True
  672. _no_split_modules = ["MoshiDecoderLayer", "MimiTransformerLayer"]
  673. _supports_flash_attn = True
  674. _supports_sdpa = True
  675. main_input_name = "input_ids"
  676. def _init_weights(self, module):
  677. std = self.config.initializer_range
  678. if isinstance(module, nn.Linear):
  679. module.weight.data.normal_(mean=0.0, std=std)
  680. if module.bias is not None:
  681. module.bias.data.zero_()
  682. elif isinstance(module, MoshiFlexibleLinear):
  683. module.weight.data.normal_()
  684. elif isinstance(module, nn.Embedding):
  685. module.weight.data.normal_(mean=0.0, std=std)
  686. if module.padding_idx is not None:
  687. module.weight.data[module.padding_idx].zero_()
  688. elif isinstance(module, MoshiRMSNorm):
  689. module.weight.data.fill_(1.0)
  690. class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin):
  691. """
    Transformer depth decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MoshiDecoderLayer`].

    Args:
        config: MoshiDepthConfig
  695. """
  696. config: MoshiDepthConfig
  697. def __init__(self, config: MoshiDepthConfig):
  698. super().__init__(config)
  699. self.text_embed_tokens = nn.Embedding(config.vocab_size + 1, config.hidden_size)
  700. # the last codebook is never used as input
  701. self.embed_tokens = nn.ModuleList(
  702. [nn.Embedding(config.audio_vocab_size + 1, config.hidden_size) for _ in range(config.num_codebooks - 1)]
  703. )
  704. self.input_projections = MoshiFlexibleLinear(config.input_size, config.hidden_size, config.num_codebooks)
  705. self.layers = nn.ModuleList(
  706. [
  707. MoshiDecoderLayer(config, layer_idx, use_flexible_linear=True, use_rope=False)
  708. for layer_idx in range(config.num_hidden_layers)
  709. ]
  710. )
  711. self.lm_heads = MoshiFlexibleLinear(config.hidden_size, config.audio_vocab_size, config.num_codebooks)
  712. self._attn_implementation = config._attn_implementation
  713. self.gradient_checkpointing = False
  714. self.config = config
  715. def forward(
  716. self,
  717. input_ids: Optional[torch.LongTensor] = None,
  718. last_hidden_state: Optional[torch.LongTensor] = None,
  719. attention_mask: Optional[torch.BoolTensor] = None,
  720. past_key_values: Optional[Cache] = None,
  721. inputs_embeds: Optional[torch.FloatTensor] = None,
  722. use_cache: Optional[bool] = None,
  723. output_attentions: Optional[bool] = None,
  724. output_hidden_states: Optional[bool] = None,
  725. return_dict: Optional[bool] = None,
  726. position_ids: Optional[torch.LongTensor] = None,
  727. labels: Optional[torch.LongTensor] = None,
  728. cache_position: Optional[torch.LongTensor] = None,
  729. ) -> Union[tuple, BaseModelOutputWithPast]:
  730. """
  731. Args:
  732. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  733. Indices of input sequence tokens. The first element of the sequence must the text token associated to the audio codebooks.
  734. The rest of the elements must be flatten audio codebooks. The `cache_position` argument can be used to indicate to which index is associated each token.
  735. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  736. Sequence of hidden-states at the output of the last layer of the main decoder. Used to contextualize `input_ids`
  737. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  738. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  739. - 1 for tokens that are **not masked**,
  740. - 0 for tokens that are **masked**.
  741. [What are attention masks?](../glossary#attention-mask)
  742. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  743. [`PreTrainedTokenizer.__call__`] for details.
  744. If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
  745. `past_key_values`).
  746. If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
  747. and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
  748. information on the default strategy.
  749. - 1 indicates the head is **not masked**,
  750. - 0 indicates the head is **masked**.
  751. past_key_values (`Cache`, *optional*):
  752. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  753. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
  754. have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
  755. of shape `(batch_size, sequence_length)`.
  756. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  757. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  758. is useful if you want more control over how to convert the inputs into associated vectors than the
  759. model's internal embedding lookup matrix.
  760. use_cache (`bool`, *optional*):
  761. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
  762. `past_key_values`).
  763. output_attentions (`bool`, *optional*):
  764. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  765. tensors for more detail.
  766. output_hidden_states (`bool`, *optional*):
  767. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  768. more detail.
  769. return_dict (`bool`, *optional*):
  770. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  771. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  772. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  773. config.n_positions - 1]`.
  774. [What are position IDs?](../glossary#position-ids)
  775. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  776. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  777. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  778. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  779. cache_position (`torch.Tensor`):
  780. Indices depicting the position of the input sequence tokens in the sequence.
  781. """
  782. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  783. output_hidden_states = (
  784. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  785. )
  786. use_cache = use_cache if use_cache is not None else self.config.use_cache
  787. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  788. if self.gradient_checkpointing and self.training and use_cache:
  789. logger.warning_once(
  790. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
  791. )
  792. use_cache = False
  793. if use_cache and past_key_values is None and not self.training:
  794. past_key_values = DynamicCache.from_legacy_cache(past_key_values)
  795. past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
  796. if cache_position is None:
  797. cache_position = torch.arange(
  798. past_seen_tokens, past_seen_tokens + input_ids.shape[1], device=input_ids.device
  799. )
  800. if position_ids is None:
  801. position_ids = cache_position.unsqueeze(0)
  802. # If inputs_embeds is provided, it has the priority over input_ids, which won't be used
  803. if inputs_embeds is None:
  804. inputs_embeds = []
  805. for position_idx in cache_position:
  806. position_idx = position_idx.item()
  807. if position_idx == 0:
  808. inputs_embeds.append(self.text_embed_tokens(input_ids[:, [position_idx]]))
  809. else:
  810. inputs_embeds.append(
  811. self.embed_tokens[(position_idx - 1)](input_ids[:, [position_idx - past_seen_tokens]])
  812. )
  813. inputs_embeds = torch.cat(inputs_embeds, dim=1)
  814. inputs_embeds += self.input_projections(last_hidden_state, cache_position)
  815. causal_mask = None
  816. if attention_mask is not None:
  817. causal_mask = self._update_causal_mask(
  818. attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
  819. )
  820. # decoder layers
  821. all_hidden_states = () if output_hidden_states else None
  822. all_self_attns = () if output_attentions else None
  823. hidden_states = inputs_embeds
  824. for decoder_layer in self.layers:
  825. if output_hidden_states:
  826. all_hidden_states += (hidden_states,)
  827. layer_outputs = decoder_layer(
  828. hidden_states,
  829. attention_mask=causal_mask,
  830. position_ids=position_ids,
  831. past_key_values=past_key_values,
  832. output_attentions=output_attentions,
  833. use_cache=use_cache,
  834. cache_position=cache_position,
  835. )
  836. hidden_states = layer_outputs[0]
  837. if output_attentions:
  838. all_self_attns += (layer_outputs[1],)
  839. # add hidden states from the last decoder layer
  840. if output_hidden_states:
  841. all_hidden_states += (hidden_states,)
  842. logits = self.lm_heads(hidden_states, cache_position)
  843. loss = None
  844. if labels is not None:
  845. # Upcast to float if we need to compute the loss to avoid potential precision issues
  846. logits = logits.float()
  847. loss_fct = CrossEntropyLoss()
  848. labels = labels.masked_fill(labels == self.config.audio_vocab_size, -100).reshape(-1)
  849. # Enable model parallelism
  850. labels = labels.to(logits.device)
  851. loss = loss_fct(logits.reshape(-1, self.config.audio_vocab_size), labels)
  852. if not return_dict:
  853. return tuple(
  854. v for v in [loss, logits, past_key_values, all_hidden_states, all_self_attns] if v is not None
  855. )
  856. return CausalLMOutputWithPast(
  857. loss=loss,
  858. logits=logits,
  859. past_key_values=past_key_values,
  860. hidden_states=past_key_values,
  861. attentions=all_self_attns,
  862. )
    # Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._update_causal_mask with Phimoe->Moshi
    def _update_causal_mask(
        self,
        attention_mask: Union[torch.Tensor, "BlockMask"],
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
    ):
        """
        Build the attention mask expected by the configured attention implementation.

        Args:
            attention_mask: the user-provided mask (a tensor; for flex attention possibly an already-built block mask).
            input_tensor: the input embeddings; only its batch size (dim 0), sequence length (dim 1) and dtype are read.
            cache_position: positions of the current tokens, forwarded to the 4D mask builder.
            past_key_values: the cache; used for the number of already-seen tokens and to detect a `StaticCache`.
            output_attentions: when True, the SDPA shortcuts below are skipped (SDPA falls back to eager attention).

        Returns:
            - `flash_attention_2`: `attention_mask` unchanged if it contains a 0, otherwise `None`.
            - `flex_attention`: the mask converted by `make_flex_block_causal_mask` when it is a tensor, else as given.
            - otherwise: `None` when SDPA can rely on its `is_causal` flag, else a 4D additive mask built in
              `input_tensor.dtype`.

        Raises:
            ValueError: with `flash_attention_2` and a cache, when the batch looks right-padded.
        """
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and past_key_values is not None:
                # NOTE(review): assumes a 0/1 mask; the last column summing to less than the batch size means some
                # row ends with padding, i.e. right padding.
                is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
                if is_padding_right:
                    raise ValueError(
                        "You are attempting to perform batched generation with padding_side='right'"
                        " this may lead to unexpected behaviour for Flash Attention version of Moshi. Make sure to "
                        " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                    )
            # Flash attention consumes the 2D padding mask directly; an all-ones mask is dropped entirely.
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None
        if self.config._attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = make_flex_block_causal_mask(attention_mask)
            return attention_mask

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                sliding_window=self.config.sliding_window,
                is_training=self.training,
            ):
                return None

        dtype = input_tensor.dtype
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        # StaticCache
        if using_static_cache:
            # The mask must span the whole preallocated cache, not just the filled part.
            target_length = past_key_values.get_max_cache_shape()
        # DynamicCache or no cache
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
            config=self.config,
            past_key_values=past_key_values,
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu", "npu"]
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask
  938. @staticmethod
  939. # Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._prepare_4d_causal_attention_mask_with_cache_position with Phimoe->MoshiDepth
  940. def _prepare_4d_causal_attention_mask_with_cache_position(
  941. attention_mask: torch.Tensor,
  942. sequence_length: int,
  943. target_length: int,
  944. dtype: torch.dtype,
  945. cache_position: torch.Tensor,
  946. batch_size: int,
  947. config: MoshiDepthConfig,
  948. past_key_values: Cache,
  949. ):
  950. """
  951. Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
  952. `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
  953. Args:
  954. attention_mask (`torch.Tensor`):
  955. A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
  956. sequence_length (`int`):
  957. The sequence length being processed.
  958. target_length (`int`):
  959. The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
  960. dtype (`torch.dtype`):
  961. The dtype to use for the 4D attention mask.
  962. cache_position (`torch.Tensor`):
  963. Indices depicting the position of the input sequence tokens in the sequence.
  964. batch_size (`torch.Tensor`):
  965. Batch size.
  966. config (`MoshiDepthConfig`):
  967. The model's configuration class
  968. past_key_values (`Cache`):
  969. The cache class that is being used currently to generate
  970. """
  971. if attention_mask is not None and attention_mask.dim() == 4:
  972. # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
  973. causal_mask = attention_mask
  974. else:
  975. min_dtype = torch.finfo(dtype).min
  976. causal_mask = torch.full(
  977. (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
  978. )
  979. diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
  980. -1, 1
  981. )
  982. text_config = config.get_text_config()
  983. if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
  984. # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
  985. # the check is needed to verify is current checkpoint was trained with sliding window or not
  986. is_static_sliding_cache = isinstance(past_key_values, StaticCache) and all(past_key_values.is_sliding)
  987. if not is_static_sliding_cache or sequence_length > target_length:
  988. sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
  989. cache_position.reshape(-1, 1) - text_config.sliding_window
  990. )
  991. diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
  992. causal_mask *= diagonal_attend_mask
  993. causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
  994. if attention_mask is not None:
  995. causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
  996. if attention_mask.shape[-1] > target_length:
  997. attention_mask = attention_mask[:, :target_length]
  998. mask_length = attention_mask.shape[-1]
  999. padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
  1000. causal_mask.device
  1001. )
  1002. padding_mask = padding_mask == 0
  1003. causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
  1004. padding_mask, min_dtype
  1005. )
  1006. return causal_mask
  1007. @auto_docstring
  1008. class MoshiModel(MoshiPreTrainedModel):
  1009. def __init__(self, config: MoshiConfig):
  1010. super().__init__(config)
  1011. self.padding_idx = config.pad_token_id
  1012. self.vocab_size = config.vocab_size
  1013. self.embed_tokens = nn.Embedding(config.vocab_size + 1, config.hidden_size, self.padding_idx)
  1014. self.layers = nn.ModuleList(
  1015. [
  1016. MoshiDecoderLayer(config, layer_idx, use_flexible_linear=False)
  1017. for layer_idx in range(config.num_hidden_layers)
  1018. ]
  1019. )
  1020. self.norm = MoshiRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  1021. self.gradient_checkpointing = False
  1022. # Initialize weights and apply final processing
  1023. self.post_init()
  1024. @auto_docstring
  1025. def forward(
  1026. self,
  1027. input_ids: Optional[torch.LongTensor] = None,
  1028. attention_mask: Optional[torch.Tensor] = None,
  1029. position_ids: Optional[torch.LongTensor] = None,
  1030. past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
  1031. inputs_embeds: Optional[torch.FloatTensor] = None,
  1032. use_cache: Optional[bool] = None,
  1033. output_attentions: Optional[bool] = None,
  1034. output_hidden_states: Optional[bool] = None,
  1035. return_dict: Optional[bool] = None,
  1036. cache_position: Optional[torch.LongTensor] = None,
  1037. ) -> Union[tuple, BaseModelOutputWithPast]:
  1038. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1039. output_hidden_states = (
  1040. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1041. )
  1042. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1043. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1044. if self.gradient_checkpointing and self.training and use_cache:
  1045. logger.warning_once(
  1046. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
  1047. )
  1048. use_cache = False
  1049. if inputs_embeds is None:
  1050. inputs_embeds = self.embed_tokens(input_ids)
  1051. if cache_position is None:
  1052. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  1053. cache_position = torch.arange(
  1054. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  1055. )
  1056. if position_ids is None:
  1057. position_ids = cache_position.unsqueeze(0)
  1058. causal_mask = None
  1059. if attention_mask is not None:
  1060. causal_mask = self._update_causal_mask(
  1061. attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
  1062. )
  1063. # embed positions
  1064. hidden_states = inputs_embeds
  1065. if use_cache and past_key_values is None:
  1066. past_key_values = DynamicCache(config=self.config)
  1067. # decoder layers
  1068. all_hidden_states = () if output_hidden_states else None
  1069. all_self_attns = () if output_attentions else None
  1070. for decoder_layer in self.layers:
  1071. if output_hidden_states:
  1072. all_hidden_states += (hidden_states,)
  1073. layer_outputs = decoder_layer(
  1074. hidden_states,
  1075. attention_mask=causal_mask,
  1076. position_ids=position_ids,
  1077. past_key_values=past_key_values,
  1078. output_attentions=output_attentions,
  1079. use_cache=use_cache,
  1080. cache_position=cache_position,
  1081. )
  1082. hidden_states = layer_outputs[0]
  1083. if output_attentions:
  1084. all_self_attns += (layer_outputs[1],)
  1085. hidden_states = self.norm(hidden_states)
  1086. # add hidden states from the last decoder layer
  1087. if output_hidden_states:
  1088. all_hidden_states += (hidden_states,)
  1089. if not return_dict:
  1090. return tuple(
  1091. v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None
  1092. )
  1093. return BaseModelOutputWithPast(
  1094. last_hidden_state=hidden_states,
  1095. past_key_values=past_key_values,
  1096. hidden_states=all_hidden_states,
  1097. attentions=all_self_attns,
  1098. )
  1099. # Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._update_causal_mask with Phimoe->Moshi
  1100. def _update_causal_mask(
  1101. self,
  1102. attention_mask: Union[torch.Tensor, "BlockMask"],
  1103. input_tensor: torch.Tensor,
  1104. cache_position: torch.Tensor,
  1105. past_key_values: Cache,
  1106. output_attentions: bool = False,
  1107. ):
  1108. if self.config._attn_implementation == "flash_attention_2":
  1109. if attention_mask is not None and past_key_values is not None:
  1110. is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
  1111. if is_padding_right:
  1112. raise ValueError(
  1113. "You are attempting to perform batched generation with padding_side='right'"
  1114. " this may lead to unexpected behaviour for Flash Attention version of Moshi. Make sure to "
  1115. " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
  1116. )
  1117. if attention_mask is not None and 0.0 in attention_mask:
  1118. return attention_mask
  1119. return None
  1120. if self.config._attn_implementation == "flex_attention":
  1121. if isinstance(attention_mask, torch.Tensor):
  1122. attention_mask = make_flex_block_causal_mask(attention_mask)
  1123. return attention_mask
  1124. # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
  1125. # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
  1126. # to infer the attention mask.
  1127. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  1128. using_static_cache = isinstance(past_key_values, StaticCache)
  1129. # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
  1130. if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
  1131. if AttentionMaskConverter._ignore_causal_mask_sdpa(
  1132. attention_mask,
  1133. inputs_embeds=input_tensor,
  1134. past_key_values_length=past_seen_tokens,
  1135. sliding_window=self.config.sliding_window,
  1136. is_training=self.training,
  1137. ):
  1138. return None
  1139. dtype = input_tensor.dtype
  1140. min_dtype = torch.finfo(dtype).min
  1141. sequence_length = input_tensor.shape[1]
  1142. # StaticCache
  1143. if using_static_cache:
  1144. target_length = past_key_values.get_max_cache_shape()
  1145. # DynamicCache or no cache
  1146. else:
  1147. target_length = (
  1148. attention_mask.shape[-1]
  1149. if isinstance(attention_mask, torch.Tensor)
  1150. else past_seen_tokens + sequence_length + 1
  1151. )
  1152. # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
  1153. causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
  1154. attention_mask,
  1155. sequence_length=sequence_length,
  1156. target_length=target_length,
  1157. dtype=dtype,
  1158. cache_position=cache_position,
  1159. batch_size=input_tensor.shape[0],
  1160. config=self.config,
  1161. past_key_values=past_key_values,
  1162. )
  1163. if (
  1164. self.config._attn_implementation == "sdpa"
  1165. and attention_mask is not None
  1166. and attention_mask.device.type in ["cuda", "xpu", "npu"]
  1167. and not output_attentions
  1168. ):
  1169. # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
  1170. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
  1171. # Details: https://github.com/pytorch/pytorch/issues/110213
  1172. causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
  1173. return causal_mask
  1174. @staticmethod
  1175. # Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._prepare_4d_causal_attention_mask_with_cache_position with Phimoe->Moshi
  1176. def _prepare_4d_causal_attention_mask_with_cache_position(
  1177. attention_mask: torch.Tensor,
  1178. sequence_length: int,
  1179. target_length: int,
  1180. dtype: torch.dtype,
  1181. cache_position: torch.Tensor,
  1182. batch_size: int,
  1183. config: MoshiConfig,
  1184. past_key_values: Cache,
  1185. ):
  1186. """
  1187. Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
  1188. `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
  1189. Args:
  1190. attention_mask (`torch.Tensor`):
  1191. A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
  1192. sequence_length (`int`):
  1193. The sequence length being processed.
  1194. target_length (`int`):
  1195. The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
  1196. dtype (`torch.dtype`):
  1197. The dtype to use for the 4D attention mask.
  1198. cache_position (`torch.Tensor`):
  1199. Indices depicting the position of the input sequence tokens in the sequence.
  1200. batch_size (`torch.Tensor`):
  1201. Batch size.
  1202. config (`MoshiConfig`):
  1203. The model's configuration class
  1204. past_key_values (`Cache`):
  1205. The cache class that is being used currently to generate
  1206. """
  1207. if attention_mask is not None and attention_mask.dim() == 4:
  1208. # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
  1209. causal_mask = attention_mask
  1210. else:
  1211. min_dtype = torch.finfo(dtype).min
  1212. causal_mask = torch.full(
  1213. (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
  1214. )
  1215. diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
  1216. -1, 1
  1217. )
  1218. text_config = config.get_text_config()
  1219. if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
  1220. # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
  1221. # the check is needed to verify is current checkpoint was trained with sliding window or not
  1222. is_static_sliding_cache = isinstance(past_key_values, StaticCache) and all(past_key_values.is_sliding)
  1223. if not is_static_sliding_cache or sequence_length > target_length:
  1224. sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
  1225. cache_position.reshape(-1, 1) - text_config.sliding_window
  1226. )
  1227. diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
  1228. causal_mask *= diagonal_attend_mask
  1229. causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
  1230. if attention_mask is not None:
  1231. causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
  1232. if attention_mask.shape[-1] > target_length:
  1233. attention_mask = attention_mask[:, :target_length]
  1234. mask_length = attention_mask.shape[-1]
  1235. padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
  1236. causal_mask.device
  1237. )
  1238. padding_mask = padding_mask == 0
  1239. causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
  1240. padding_mask, min_dtype
  1241. )
  1242. return causal_mask
  1243. @auto_docstring(
  1244. custom_intro="""
  1245. The Moshi decoder model with a text language modelling head on top. Only usable for text.
  1246. """
  1247. )
  1248. class MoshiForCausalLM(MoshiPreTrainedModel, GenerationMixin):
  1249. _tied_weights_keys = ["model.embed_tokens.weight", "lm_head.weight"]
  1250. # Copied from transformers.models.gemma.modeling_gemma.GemmaForCausalLM.__init__ with Gemma->Moshi
  1251. def __init__(self, config):
  1252. super().__init__(config)
  1253. self.model = MoshiModel(config)
  1254. self.vocab_size = config.vocab_size
  1255. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  1256. # Initialize weights and apply final processing
  1257. self.post_init()
  1258. @auto_docstring
  1259. def forward(
  1260. self,
  1261. input_ids: Optional[torch.LongTensor] = None,
  1262. attention_mask: Optional[torch.Tensor] = None,
  1263. position_ids: Optional[torch.LongTensor] = None,
  1264. past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
  1265. inputs_embeds: Optional[torch.FloatTensor] = None,
  1266. use_cache: Optional[bool] = None,
  1267. output_attentions: Optional[bool] = None,
  1268. output_hidden_states: Optional[bool] = None,
  1269. return_dict: Optional[bool] = None,
  1270. cache_position: Optional[torch.LongTensor] = None,
  1271. labels: Optional[torch.LongTensor] = None,
  1272. logits_to_keep: Union[int, torch.Tensor] = 0,
  1273. **kwargs,
  1274. ) -> Union[tuple, MoshiCausalLMOutputWithPast]:
  1275. r"""
  1276. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1277. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1278. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  1279. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  1280. Example:
  1281. ```python
  1282. >>> from transformers import AutoTokenizer, MoshiForCausalLM
  1283. >>> model = MoshiForCausalLM.from_pretrained("kmhf/hf-moshiko")
  1284. >>> tokenizer = AutoTokenizer.from_pretrained("kmhf/hf-moshiko")
  1285. >>> prompt = "What is your favorite condiment?"
  1286. >>> inputs = tokenizer(prompt, return_tensors="pt")
  1287. >>> # Generate
  1288. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  1289. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  1290. "What is your favorite condiment?"
  1291. ```"""
  1292. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1293. output_hidden_states = (
  1294. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1295. )
  1296. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1297. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  1298. outputs = self.model(
  1299. input_ids=input_ids,
  1300. attention_mask=attention_mask,
  1301. position_ids=position_ids,
  1302. past_key_values=past_key_values,
  1303. inputs_embeds=inputs_embeds,
  1304. use_cache=use_cache,
  1305. output_attentions=output_attentions,
  1306. output_hidden_states=output_hidden_states,
  1307. return_dict=return_dict,
  1308. cache_position=cache_position,
  1309. )
  1310. hidden_states = outputs[0]
  1311. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  1312. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  1313. logits = self.lm_head(hidden_states[:, slice_indices, :])
  1314. loss = None
  1315. if labels is not None:
  1316. # Upcast to float if we need to compute the loss to avoid potential precision issues
  1317. logits = logits.float()
  1318. # Shift so that tokens < n predict n
  1319. shift_logits = logits[..., :-1, :].contiguous()
  1320. shift_labels = labels[..., 1:].contiguous()
  1321. # Flatten the tokens
  1322. shift_logits = shift_logits.view(-1, self.config.vocab_size)
  1323. shift_labels = shift_labels.view(-1)
  1324. # Enable model parallelism
  1325. shift_labels = shift_labels.to(shift_logits.device)
  1326. loss = self.loss_function(
  1327. shift_logits,
  1328. shift_labels,
  1329. vocab_size=self.config.vocab_size,
  1330. **kwargs,
  1331. )
  1332. if not return_dict:
  1333. output = (
  1334. logits,
  1335. hidden_states,
  1336. ) + outputs[1:]
  1337. return (loss,) + output if loss is not None else output
  1338. return MoshiCausalLMOutputWithPast(
  1339. loss=loss,
  1340. logits=logits,
  1341. last_hidden_state=hidden_states, # Ignore copy
  1342. past_key_values=outputs.past_key_values,
  1343. hidden_states=outputs.hidden_states,
  1344. attentions=outputs.attentions,
  1345. )
  1346. @auto_docstring(
  1347. custom_intro="""
  1348. The original Moshi model with an audio encoder, a Moshi depth decoder and a Moshi decoder, for speech-to-speech.
  1349. """
  1350. )
  1351. class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin):
  1352. _tied_weights_keys = ["decoder.model.embed_tokens.weight", "decoder.lm_head.weight"]
  1353. config: MoshiConfig
  1354. main_input_name = "input_ids"
  1355. supports_gradient_checkpointing = True
  1356. _supports_flash_attn = True
  1357. _supports_sdpa = True
  1358. def __init__(self, config: MoshiConfig):
  1359. super().__init__(config)
  1360. # We have 2 * num_codebooks audio embedding layers because we have the user input channel and the model output channel.
  1361. self.embed_tokens = nn.ModuleList(
  1362. [nn.Embedding(config.audio_vocab_size + 1, config.hidden_size) for _ in range(2 * config.num_codebooks)]
  1363. )
  1364. self.audio_encoder = AutoModel.from_config(config.audio_encoder_config)
  1365. self.decoder = MoshiForCausalLM(config)
  1366. self.depth_decoder = MoshiDepthDecoder._from_config(config.depth_decoder_config)
  1367. self.num_codebooks = config.num_codebooks
  1368. self.post_init()
  1369. def get_audio_encoder(self):
  1370. return self.audio_encoder
  1371. def get_depth_decoder(self):
  1372. return self.depth_decoder
  1373. @auto_docstring
  1374. def forward(
  1375. self,
  1376. input_ids: Optional[torch.LongTensor] = None,
  1377. attention_mask: Optional[torch.BoolTensor] = None,
  1378. user_input_values: Optional[torch.FloatTensor] = None,
  1379. user_audio_codes: Optional[torch.Tensor] = None,
  1380. moshi_input_values: Optional[torch.FloatTensor] = None,
  1381. moshi_audio_codes: Optional[torch.Tensor] = None,
  1382. past_key_values: Optional[Cache] = None,
  1383. inputs_embeds: Optional[torch.FloatTensor] = None,
  1384. text_labels: Optional[torch.LongTensor] = None,
  1385. audio_labels: Optional[torch.LongTensor] = None,
  1386. use_cache: Optional[bool] = None,
  1387. output_attentions: Optional[bool] = None,
  1388. output_hidden_states: Optional[bool] = None,
  1389. return_dict: Optional[bool] = None,
  1390. **kwargs,
  1391. ) -> Union[tuple, Seq2SeqLMOutput]:
  1392. r"""
  1393. user_input_values (`torch.Tensor `of shape `(batch_size, 1, audio_sequence_length), *optional*):
  1394. The audio waveforms used as audio user prompt for the generation.
  1395. user_audio_codes (`torch.Tensor `of shape `(batch_size, num_codebooks, sequence_length), *optional*):
  1396. The audio codes used as audio user prompt for the generation. Has priority over `user_input_values` and represents the audio "tokens" of `user_input_values` once passed through the audio encoder.
  1397. moshi_input_values (`torch.Tensor `of shape `(batch_size, 1, audio_sequence_length), *optional*):
  1398. The audio waveforms used as audio Moshi prompt for the generation.
  1399. moshi_audio_codes (`torch.Tensor `of shape `(batch_size, num_codebooks, sequence_length), *optional*):
  1400. The audio codes used as audio Moshi prompt for the generation. Has priority over `moshi_input_values` and represents the audio "tokens" of `moshi_input_values` once passed through the audio encoder.
  1401. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  1402. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded
  1403. representation. If `past_key_values` is used, optionally only the last `inputs_embeds` have to be
  1404. input (see `past_key_values`). This is useful if you want more control over how to convert
  1405. `input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
  1406. If `input_ids` and `inputs_embeds` are both unset, `inputs_embeds` takes the value
  1407. of `inputs_embeds`.
  1408. text_labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1409. Labels for text language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  1410. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  1411. are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  1412. audio_labels (`torch.LongTensor` of shape `(batch_size, num_codebooks, sequence_length)`, *optional*):
  1413. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  1414. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  1415. are ignored (masked), the loss is only computed for labels in `[0, ..., config.audio_vocab_size]`
  1416. Examples:
  1417. ```python
  1418. >>> from transformers import MoshiForConditionalGeneration
  1419. >>> import torch
  1420. >>> model = MoshiForConditionalGeneration.from_pretrained("kmhf/hf-moshiko")
  1421. >>> inputs = moshi.get_unconditional_inputs()
  1422. >>> logits = model(**inputs, ).logits
  1423. >>> logits.shape # (bsz, seq_len, text_vocab_size)
  1424. torch.Size([1, 1, 32000])
  1425. ```"""
  1426. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1427. kwargs_audio_encoder = {
  1428. argument[len("audio_encoder_")]: value
  1429. for argument, value in kwargs.items()
  1430. if argument.startswith("audio_encoder_")
  1431. }
  1432. kwargs_decoder = {
  1433. argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
  1434. }
  1435. kwargs_depth_decoder = {
  1436. argument[len("depth_decoder_") :]: value
  1437. for argument, value in kwargs.items()
  1438. if argument.startswith("depth_decoder_")
  1439. }
  1440. # If inputs_embeds is provided, it has the priority over input_ids and audio_codes, which won't be used
  1441. if inputs_embeds is None:
  1442. if user_input_values is not None and user_audio_codes is None:
  1443. user_audio_codes = self.audio_encoder.encode(
  1444. user_input_values, num_quantizers=self.num_codebooks, **kwargs_audio_encoder
  1445. )[0]
  1446. if moshi_input_values is not None and moshi_audio_codes is None:
  1447. moshi_audio_codes = self.audio_encoder.encode(
  1448. moshi_input_values, num_quantizers=self.num_codebooks, **kwargs_audio_encoder
  1449. )[0]
  1450. audio_codes = torch.cat([moshi_audio_codes, user_audio_codes], dim=1)
  1451. if input_ids is None and audio_codes is None:
  1452. raise ValueError(
  1453. "You must provide at least one of `input_ids`, `inputs_embeds`, `input_values` and `audio_codes`."
  1454. )
  1455. if input_ids is not None:
  1456. inputs_embeds = self.decoder.model.embed_tokens(input_ids)
  1457. if audio_codes is not None:
  1458. audio_inputs_embeds = sum(
  1459. self.embed_tokens[codebook](audio_codes[:, codebook]) for codebook in range(audio_codes.shape[1])
  1460. )
  1461. inputs_embeds = (
  1462. audio_inputs_embeds
  1463. if inputs_embeds is None
  1464. else audio_inputs_embeds + inputs_embeds.to(audio_inputs_embeds.device)
  1465. )
  1466. # Decode
  1467. decoder_outputs = self.decoder(
  1468. attention_mask=attention_mask,
  1469. inputs_embeds=inputs_embeds,
  1470. output_attentions=output_attentions,
  1471. output_hidden_states=output_hidden_states,
  1472. use_cache=use_cache,
  1473. past_key_values=past_key_values,
  1474. return_dict=True,
  1475. labels=text_labels,
  1476. **kwargs_decoder,
  1477. )
  1478. decoder_last_hidden_state = decoder_outputs.last_hidden_state
  1479. depth_decoder_outputs = None
  1480. final_loss = decoder_outputs.loss
  1481. if text_labels is not None and audio_labels is not None:
  1482. # To use depth decoder forward here, we actually need oracle input ids since we're supposed to pass the true input ids
  1483. audio_labels = self.build_delay_pattern_mask(
  1484. audio_labels,
  1485. bos_token_id=self.config.audio_vocab_size,
  1486. pad_token_id=self.config.audio_vocab_size,
  1487. max_length=audio_labels.shape[-1] + 1,
  1488. )[0]
  1489. # (batch_size, sequence_length) -> (batch_size * sequence_length, 1)
  1490. text_labels = text_labels.view(-1, 1)
  1491. # (batch_size, num_codebooks, sequence_length) -> (batch_size * sequence_length, num_codebooks)
  1492. audio_labels = audio_labels.transpose(1, 2).reshape(-1, audio_labels.shape[1])
  1493. depth_input_ids = torch.cat([text_labels, audio_labels], dim=1)
  1494. # keep the last codebook out of input_ids
  1495. depth_input_ids = depth_input_ids[:, :-1]
  1496. # (batch_size, sequence_length, dim) -> (batch_size * sequence_length, 1, dim)
  1497. decoder_last_hidden_state = decoder_last_hidden_state.view(-1, 1, decoder_last_hidden_state.shape[-1])
  1498. depth_decoder_outputs = self.depth_decoder(
  1499. last_hidden_state=decoder_last_hidden_state,
  1500. input_ids=depth_input_ids,
  1501. attention_mask=attention_mask,
  1502. labels=audio_labels,
  1503. **kwargs_depth_decoder,
  1504. )
  1505. final_loss += depth_decoder_outputs.loss
  1506. if not return_dict:
  1507. outputs = decoder_outputs.to_tuple()
  1508. if depth_decoder_outputs is not None:
  1509. outputs += depth_decoder_outputs.to_tuple()
  1510. return outputs
  1511. return MoshiConditionalGenerationOutputWithPast(
  1512. loss=decoder_outputs.loss,
  1513. logits=decoder_outputs.logits,
  1514. last_hidden_state=decoder_last_hidden_state,
  1515. past_key_values=decoder_outputs.past_key_values,
  1516. hidden_states=decoder_outputs.hidden_states,
  1517. attentions=decoder_outputs.attentions,
  1518. depth_loss=None if depth_decoder_outputs is None else depth_decoder_outputs.loss,
  1519. audio_logits=None if depth_decoder_outputs is None else depth_decoder_outputs.logits,
  1520. depth_past_key_values=None if decoder_outputs is None else decoder_outputs.past_key_values,
  1521. depth_hidden_states=None if decoder_outputs is None else decoder_outputs.hidden_states,
  1522. depth_attentions=None if decoder_outputs is None else decoder_outputs.attentions,
  1523. )
  1524. def _prepare_attention_mask_for_generation(
  1525. self,
  1526. input_ids: torch.LongTensor,
  1527. generation_config: GenerationConfig,
  1528. kwargs: dict[str, Any],
  1529. ) -> torch.LongTensor:
  1530. pad_token_id = generation_config.pad_token_id
  1531. eos_token_id = generation_config.eos_token_id
  1532. default_attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device)
  1533. if pad_token_id is None:
  1534. return default_attention_mask
  1535. is_pad_token_in_inputs = (pad_token_id is not None) and torch.isin(input_ids, pad_token_id).any()
  1536. is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~torch.isin(
  1537. eos_token_id, pad_token_id
  1538. ).any()
  1539. can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
  1540. attention_mask_from_padding = input_ids.ne(pad_token_id).long()
  1541. attention_mask = (
  1542. attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
  1543. )
  1544. return attention_mask
  def _prepare_inputs_embeds_for_generation(
      self,
      input_ids: Optional[torch.LongTensor] = None,
      user_input_values: Optional[torch.FloatTensor] = None,
      user_audio_codes: Optional[torch.Tensor] = None,
      moshi_input_values: Optional[torch.FloatTensor] = None,
      moshi_audio_codes: Optional[torch.Tensor] = None,
      inputs_embeds: Optional[torch.FloatTensor] = None,
      attention_mask: Optional[torch.Tensor] = None,
      generation_config: Optional[GenerationConfig] = None,
      apply_delay_pattern_mask: bool = False,
      concat_unconditional_inputs: bool = False,
  ):
      """
      Build decoder input embeddings from text tokens and Moshi/user audio codes.

      Waveforms (`*_input_values`) are only encoded with `self.audio_encoder` when the matching audio codes are not
      given. When `inputs_embeds` is `None` and `concat_unconditional_inputs` is set, the inputs returned by
      `get_unconditional_inputs` are prepended (and to `attention_mask` when one is given). When
      `apply_delay_pattern_mask` is set, delay pattern masks are built for the available audio codes, even if
      `inputs_embeds` is given.

      Args:
          generation_config: its `max_length` sizes the delay pattern masks. When `None`, `max_length=None` is passed
              to `build_delay_pattern_mask`, which then falls back to `self.generation_config.max_length`.
          apply_delay_pattern_mask: whether to build (and apply) the delay pattern masks.
          concat_unconditional_inputs: whether to prepend the unconditional inputs (ignored if `inputs_embeds` is
              given).

      Returns:
          `(inputs_embeds, input_ids, user_audio_codes, moshi_audio_codes, user_delay_pattern_mask,
          moshi_delay_pattern_mask, attention_mask)`. A provided `inputs_embeds` is returned unchanged; the masks are
          `None` unless built.

      Raises:
          ValueError: if none of the text, audio or embedding inputs is provided.
      """
      user_delay_pattern_mask = None
      moshi_delay_pattern_mask = None

      if (
          inputs_embeds is None
          and input_ids is None
          and user_input_values is None
          and user_audio_codes is None
          and moshi_input_values is None
          and moshi_audio_codes is None
      ):
          raise ValueError(
              "You must provide at least one of `input_ids`, `user_input_values`, `moshi_input_values`, `user_audio_codes`, `moshi_audio_codes` or `inputs_embeds`."
          )

      # in case inputs_embeds is passed, we might still need to create delay pattern masks
      if inputs_embeds is None or apply_delay_pattern_mask:
          if user_input_values is not None and user_audio_codes is None:
              user_audio_codes = self.audio_encoder.encode(user_input_values, num_quantizers=self.num_codebooks)[0]

          if moshi_input_values is not None and moshi_audio_codes is None:
              moshi_audio_codes = self.audio_encoder.encode(moshi_input_values, num_quantizers=self.num_codebooks)[0]

      if inputs_embeds is None and concat_unconditional_inputs:
          unconditional_inputs = self.get_unconditional_inputs(num_samples=user_audio_codes.shape[0])
          moshi_audio_codes = torch.cat([unconditional_inputs.moshi_audio_codes, moshi_audio_codes], dim=2)
          user_audio_codes = torch.cat([unconditional_inputs.user_audio_codes, user_audio_codes], dim=2)
          input_ids = torch.cat([unconditional_inputs.input_ids, input_ids], dim=1)
          if attention_mask is not None:
              attention_mask = torch.cat([unconditional_inputs.attention_mask, attention_mask], dim=1)

      if inputs_embeds is None or apply_delay_pattern_mask:
          # `generation_config` defaults to None; let `build_delay_pattern_mask` apply its own fallback then
          max_length = generation_config.max_length if generation_config is not None else None

          if apply_delay_pattern_mask and user_audio_codes is not None:
              user_audio_codes, user_delay_pattern_mask = self.build_delay_pattern_mask(
                  user_audio_codes,
                  bos_token_id=self.config.audio_vocab_size,
                  pad_token_id=self.config.audio_vocab_size,
                  max_length=max_length,
              )

          if apply_delay_pattern_mask and moshi_audio_codes is not None:
              moshi_audio_codes, moshi_delay_pattern_mask = self.build_delay_pattern_mask(
                  moshi_audio_codes,
                  bos_token_id=self.config.audio_vocab_size,
                  pad_token_id=self.config.audio_vocab_size,
                  max_length=max_length,
              )

      # If inputs_embeds is provided, it has the priority over input_ids and audio_codes, which won't be used
      if inputs_embeds is None:
          audio_inputs_embeds = None
          if user_audio_codes is not None and moshi_audio_codes is not None:
              # Moshi codebooks come first and user codebooks after them, each with its own entry of `embed_tokens`
              audio_codes = torch.cat([moshi_audio_codes, user_audio_codes], dim=1)
              audio_inputs_embeds = sum(
                  self.embed_tokens[codebook](audio_codes[:, codebook]) for codebook in range(audio_codes.shape[1])
              )
          elif moshi_audio_codes is not None:
              audio_inputs_embeds = sum(
                  self.embed_tokens[codebook](moshi_audio_codes[:, codebook])
                  for codebook in range(moshi_audio_codes.shape[1])
              )
          elif user_audio_codes is not None:
              # Same layout as the concatenated case above: user codebook `i` uses `embed_tokens[i + num_codebooks]`,
              # but `user_audio_codes` itself is indexed from 0 (indexing it at `i + num_codebooks` would be out of
              # range). Assumes Moshi streams carry `self.num_codebooks` codebooks, as encoded above.
              audio_inputs_embeds = sum(
                  self.embed_tokens[codebook + self.num_codebooks](user_audio_codes[:, codebook])
                  for codebook in range(user_audio_codes.shape[1])
              )

          if input_ids is not None:
              inputs_embeds = self.decoder.model.embed_tokens(input_ids)

          if audio_inputs_embeds is not None:
              inputs_embeds = (
                  audio_inputs_embeds
                  if inputs_embeds is None
                  else audio_inputs_embeds + inputs_embeds.to(audio_inputs_embeds.device)
              )

      return (
          inputs_embeds,
          input_ids,
          user_audio_codes,
          moshi_audio_codes,
          user_delay_pattern_mask,
          moshi_delay_pattern_mask,
          attention_mask,
      )
  @torch.no_grad()
  def generate(
      self,
      input_ids: Optional[torch.LongTensor] = None,
      user_input_values: Optional[torch.FloatTensor] = None,
      user_audio_codes: Optional[torch.Tensor] = None,
      moshi_input_values: Optional[torch.FloatTensor] = None,
      moshi_audio_codes: Optional[torch.Tensor] = None,
      inputs_embeds: Optional[torch.FloatTensor] = None,
      return_audio_waveforms: Optional[bool] = True,
      return_audio_codes: Optional[bool] = None,
      concat_unconditional_inputs: Optional[bool] = True,
      **kwargs,
  ) -> torch.LongTensor:
      """
      Generates sequences of text token ids and audio tokens ids.

      Parameters:
          input_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
              The sequence used as a text prompt for the generation.
          user_input_values (`torch.Tensor` of shape `(batch_size, 1, audio_sequence_length)`, *optional*):
              The audio waveforms used as audio user prompt for the generation.
          user_audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, sequence_length)`, *optional*):
              The audio codes used as audio user prompt for the generation. Has priority over `user_input_values` and represents the audio "tokens" of `user_input_values` once passed through the audio encoder.
          moshi_input_values (`torch.Tensor` of shape `(batch_size, 1, audio_sequence_length)`, *optional*):
              The audio waveforms used as audio Moshi prompt for the generation.
          moshi_audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, sequence_length)`, *optional*):
              The audio codes used as audio Moshi prompt for the generation. Has priority over `moshi_input_values` and represents the audio "tokens" of `moshi_input_values` once passed through the audio encoder.
          inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
              Optionally, instead of passing `input_ids` and the audio inputs you can choose to directly pass an embedded representation. This
              is useful if you want more control over how to convert the inputs into associated vectors than the
              model's internal embedding lookup matrix. Its prompt length is read from dimension 1 (`sequence_length`).
          return_audio_waveforms (`bool`, *optional*, defaults to `True`):
              If `False`, won't generate the audio waveforms.
          return_audio_codes (`bool`, *optional*):
              If `True`, will also return the generated audio codes, i.e the intermediate audio "tokens" which transforms to `audio_sequences` once passed through the audio decoder.
          concat_unconditional_inputs (`bool`, *optional*, defaults to `True`):
              If `False`, won't concatenate initial audio and text tokens.
          kwargs (`dict[str, Any]`, *optional*):
              Remaining dictionary of keyword arguments that are passed to the `generate` method. Refers to the
              original [`generate` docstrings](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate)
              for more information on how to use them.
              Note that keywords with a *depth_decoder_* prefix will be input for the `generate` method of the
              depth decoder. Otherwise, the latter will use its default generation config.
      Return:
          [`MoshiConditionalGenerationGenerateOutput`] when audio waveforms or audio codes are requested. Otherwise,
          the text sequences, or the parent `generate` output object when `return_dict_in_generate=True`.
      """
      # multiple generate -> need to create/update device map
      if hasattr(self, "hf_device_map") and not hasattr(self.depth_decoder, "hf_device_map"):
          self.depth_decoder.hf_device_map = {}
          if "" in self.hf_device_map:
              self.depth_decoder.hf_device_map = self.hf_device_map
          else:
              main_device = [d for d in self.hf_device_map.values() if d not in ["cpu", "disk"]][0]
              self.depth_decoder.hf_device_map = {
                  key[len("depth_decoder") :]: main_device if value in ["cpu", "disk"] else value
                  for key, value in self.hf_device_map.items()
                  if key.startswith("depth_decoder")
              }
          # need to remove depth_decoder from the top device_map so that we assign correctly the device for each layer idx in the cache
          self.hf_device_map = {
              key: value for key, value in self.hf_device_map.items() if not key.startswith("depth_decoder")
          }

      # retrieve depth decoder kwargs
      depth_decoder_kwargs_keys = {argument for argument in kwargs if argument.startswith("depth_decoder_")}
      kwargs_depth_decoder = {
          argument[len("depth_decoder_") :]: kwargs.pop(argument) for argument in depth_decoder_kwargs_keys
      }

      # needs to prepare generation config, even though it'll be done again in `generate`
      generation_config, kwargs = self._prepare_generation_config(kwargs.pop("generation_config", None), **kwargs)

      input_ids, user_audio_codes, moshi_audio_codes, concat_unconditional_inputs = (
          self._check_and_maybe_initialize_inputs(
              input_ids=input_ids,
              user_input_values=user_input_values,
              user_audio_codes=user_audio_codes,
              moshi_input_values=moshi_input_values,
              moshi_audio_codes=moshi_audio_codes,
              inputs_embeds=inputs_embeds,
              concat_unconditional_inputs=concat_unconditional_inputs,
          )
      )

      inputs = inputs_embeds if input_ids is None else input_ids

      # `inputs` is either `input_ids` (batch, seq) or `inputs_embeds` (batch, seq, hidden): the prompt length is
      # dimension 1 in both cases (the last dimension of `inputs_embeds` is the hidden size, not the length).
      # The unconditional inputs add one extra step when they are prepended.
      prompt_length = inputs.shape[1]
      input_ids_length = prompt_length + 1 if concat_unconditional_inputs else prompt_length
      has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
      has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
      generation_config = self._prepare_generated_length(
          generation_config=generation_config,
          has_default_max_length=has_default_max_length,
          has_default_min_length=has_default_min_length,
          model_input_name="inputs_embeds" if input_ids is None else "input_ids",
          inputs_tensor=inputs,
          input_ids_length=input_ids_length,
      )

      # retrieve depth decoder generation config if it exists
      if hasattr(generation_config, "depth_decoder_config"):
          depth_decoder_generation_config = generation_config.depth_decoder_config
      else:
          # we need to control the number of tokens generated by the depth decoder
          depth_decoder_generation_config = {
              "min_length": self.num_codebooks + 1,
              "max_length": self.num_codebooks + 1,
              "cache_implementation": "static",
          }
      # update kwargs_depth_decoder: kwargs_depth_decoder have priority over depth_decoder_generation_config
      depth_decoder_generation_config.update(kwargs_depth_decoder)
      kwargs_depth_decoder = depth_decoder_generation_config

      attention_mask = kwargs.pop("attention_mask", None)
      if attention_mask is None:
          attention_mask = self._prepare_attention_mask_for_generation(
              input_ids=input_ids,
              generation_config=generation_config,
              kwargs=kwargs,
          )
      (
          inputs_embeds,
          input_ids,
          user_audio_codes,
          moshi_audio_codes,
          user_delay_pattern_mask,
          moshi_delay_pattern_mask,
          attention_mask,
      ) = self._prepare_inputs_embeds_for_generation(
          input_ids=input_ids,
          user_input_values=user_input_values,
          user_audio_codes=user_audio_codes,
          moshi_input_values=moshi_input_values,
          moshi_audio_codes=moshi_audio_codes,
          inputs_embeds=inputs_embeds,
          attention_mask=attention_mask,
          generation_config=generation_config,
          apply_delay_pattern_mask=True,
          concat_unconditional_inputs=concat_unconditional_inputs,
      )

      # create blank user inputs - moshi needs a constant stream of user inputs
      blank_input_values = torch.zeros(
          (inputs_embeds.shape[0], 1, int(self.config.sampling_rate / self.config.audio_encoder_config.frame_rate)),
          dtype=self.dtype,
          device=self.device,
      )
      blank_user_audio_codes = self.audio_encoder.encode(blank_input_values, num_quantizers=self.num_codebooks)[0]

      # set delay pattern mask for the rest of the generation
      kwargs["user_delay_pattern_mask"] = (
          user_delay_pattern_mask if user_delay_pattern_mask is not None else kwargs.get("user_delay_pattern_mask")
      )
      kwargs["moshi_delay_pattern_mask"] = (
          moshi_delay_pattern_mask
          if moshi_delay_pattern_mask is not None
          else kwargs.get("moshi_delay_pattern_mask")
      )

      # audio codes accumulated step by step by `prepare_inputs_for_generation`, one row per beam/returned sequence
      self.generated_audio_codes = torch.repeat_interleave(
          moshi_audio_codes, max(generation_config.num_beams, generation_config.num_return_sequences), dim=0
      )

      # beam search needs `beam_indices` (and scores) to re-align the audio codes afterwards
      return_dict_in_generate = generation_config.num_beams > 1 or generation_config.return_dict_in_generate
      output_scores = generation_config.num_beams > 1 or generation_config.output_scores
      outputs = super().generate(
          inputs_embeds=inputs_embeds,
          input_ids=input_ids,
          generation_config=generation_config,
          blank_user_audio_codes=blank_user_audio_codes,
          kwargs_depth_decoder=kwargs_depth_decoder,
          return_dict_in_generate=return_dict_in_generate,
          output_scores=output_scores,
          attention_mask=attention_mask,
          **kwargs,
      )

      if not return_audio_waveforms and not return_audio_codes:
          if return_dict_in_generate and not generation_config.return_dict_in_generate:
              return outputs.sequences
          return outputs

      # check if outputs is a dict or tokens
      if not return_dict_in_generate:
          output_text_ids = outputs
      else:
          output_text_ids = outputs.sequences

      if generation_config.num_return_sequences > 1:
          moshi_delay_pattern_mask = torch.repeat_interleave(
              moshi_delay_pattern_mask, generation_config.num_return_sequences, dim=0
          )

      if generation_config.num_beams > 1:
          # we need to reorganize self.last_hidden_states and generated audio codes according to the beam_indices

          # Beam indices are of shape `input_length + number_generated_tokens` but actually starts
          # indexing indices at index 0 instead of index `input_length-1`.
          # We thus discard the last `input_length` indices that are never used.
          beam_indices = outputs.beam_indices[:, : -moshi_audio_codes.shape[-1]]

          generated_audio_codes = self.generated_audio_codes[:, :, moshi_audio_codes.shape[-1] :]

          # we've generated audio tokens `number_generated_tokens-1` times, so we use the corresponding beam indices to
          # retrieve the right audio tokens
          expanded_beam_indices = beam_indices[:, :-1].unsqueeze(1).expand(-1, self.num_codebooks, -1)
          generated_audio_codes = torch.gather(generated_audio_codes, dim=0, index=expanded_beam_indices)

          # now, rebuild generated audio codes, this time with the right beam tracking
          moshi_audio_codes = torch.repeat_interleave(
              moshi_audio_codes, generation_config.num_return_sequences, dim=0
          )
          self.generated_audio_codes = torch.cat((moshi_audio_codes, generated_audio_codes), dim=2)

          # use the last beam indice to retrieve the right self.last_hidden_state
          self.last_hidden_state = torch.index_select(self.last_hidden_state, dim=0, index=beam_indices[:, -1])

      # we need to make a last generation with the latest generated tokens
      last_hidden_state = self.last_hidden_state.view(-1, 1, self.last_hidden_state.shape[-1])

      last_generated_audio_codes = self.depth_decoder.generate(
          last_hidden_state=last_hidden_state,
          input_ids=output_text_ids[:, -1:].view(-1, 1),
          **kwargs_depth_decoder,
      )

      # drop the first depth-decoder position (the text token) and add a time axis
      last_generated_audio_codes = last_generated_audio_codes[:, 1:].unsqueeze(2)

      self.generated_audio_codes = torch.cat([self.generated_audio_codes, last_generated_audio_codes], dim=2)

      # apply the pattern mask to the final audio ids
      output_audio_codes = self.apply_delay_pattern_mask(self.generated_audio_codes, moshi_delay_pattern_mask)

      # revert the pattern delay mask by filtering the pad token id and bos token ids
      mask = moshi_delay_pattern_mask != self.config.audio_vocab_size

      output_audio_codes = output_audio_codes[mask].reshape(mask.shape[0], self.num_codebooks, -1)

      output_values = None
      if return_audio_waveforms:
          output_values = self.audio_encoder.decode(
              output_audio_codes,
          ).audio_values

      output_audio_codes = output_audio_codes if return_audio_codes else None

      if generation_config.return_dict_in_generate:
          return MoshiConditionalGenerationGenerateOutput(
              audio_sequences=output_values, audio_codes=output_audio_codes, **outputs
          )

      return MoshiConditionalGenerationGenerateOutput(
          audio_sequences=output_values, sequences=output_text_ids, audio_codes=output_audio_codes
      )
  def prepare_inputs_for_generation(
      self,
      input_ids,
      past_key_values=None,
      attention_mask=None,
      inputs_embeds=None,
      cache_position=None,
      position_ids=None,
      use_cache=True,
      logits_to_keep=None,
      user_delay_pattern_mask=None,
      moshi_delay_pattern_mask=None,
      kwargs_depth_decoder=None,
      blank_user_audio_codes: Optional[torch.FloatTensor] = None,
      **kwargs,
  ):
      """
      Prepare the main decoder inputs for one generation step.

      On the first step with `inputs_embeds`, those embeddings are forwarded as-is. On later steps, the text token just
      produced is passed to `self.depth_decoder.generate` (together with `kwargs["last_hidden_state"]`) to produce
      the next Moshi audio codes. They are appended to `self.generated_audio_codes` and the step's `inputs_embeds`
      are rebuilt from text + Moshi audio + blank user audio codes.

      Returns:
          dict of model inputs; kwargs not already set are forwarded unchanged.
      """
      # Overwritten -- Moshi has custom post-processing on the prepared inputs.

      # 1. If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
      # Exception 1: when passing input_embeds, input_ids may be missing entries
      # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
      # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
      # (we can't check exception 3 while compiling)
      if past_key_values is not None:
          if (
              inputs_embeds is not None  # Exception 1
              or cache_position[-1] >= input_ids.shape[1]  # Exception 3
          ):
              input_ids = input_ids[:, -cache_position.shape[0] :]
          elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
              input_ids = input_ids[:, cache_position]

      # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
      if inputs_embeds is not None and cache_position[0] == 0:
          model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
      else:
          model_inputs = {"input_ids": input_ids, "inputs_embeds": None}

      # a 2D mask must be expanded to 4D for a static cache; without a mask there is nothing to expand
      # (previously `attention_mask.ndim` was read even when `attention_mask` was None)
      if attention_mask is not None and isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
          if model_inputs["inputs_embeds"] is not None:
              batch_size, sequence_length, _ = inputs_embeds.shape
              device = inputs_embeds.device
          else:
              batch_size, sequence_length = input_ids.shape
              device = input_ids.device

          attention_mask = self.decoder.model._prepare_4d_causal_attention_mask_with_cache_position(
              attention_mask,
              sequence_length=sequence_length,
              target_length=past_key_values.get_max_cache_shape(),
              dtype=self.decoder.lm_head.weight.dtype,
              device=device,
              cache_position=cache_position,
              batch_size=batch_size,
              config=self.config,
              past_key_values=past_key_values,
          )

      model_inputs.update(
          {
              "position_ids": position_ids,
              "past_key_values": past_key_values,
              "use_cache": use_cache,
              "attention_mask": attention_mask,
              "cache_position": cache_position,
          }
      )

      # 2. Now that everything is prepared, generate audio_codes using the depth decoder

      # we want to do it after a first token has been generated
      if model_inputs["input_ids"] is not None:
          last_hidden_state = kwargs.pop("last_hidden_state")
          # (batch_size, sequence_length, dim) -> (batch_size * sequence_length, 1, dim)
          last_hidden_state = last_hidden_state.view(-1, 1, last_hidden_state.shape[-1])

          input_ids = model_inputs.pop("input_ids")

          generated_audio_codes = self.depth_decoder.generate(
              last_hidden_state=last_hidden_state,
              input_ids=input_ids.view(-1, 1),
              **kwargs_depth_decoder,
          )

          # the first tokens are text tokens
          generated_audio_codes = generated_audio_codes[:, 1:].unsqueeze(2)

          user_audio_codes = self.apply_delay_pattern_mask(
              torch.cat(
                  [self.generated_audio_codes, blank_user_audio_codes.to(self.generated_audio_codes.device)], dim=2
              ),
              user_delay_pattern_mask,
          )[:, :, -1:]
          self.generated_audio_codes = self.apply_delay_pattern_mask(
              torch.cat([self.generated_audio_codes, generated_audio_codes], dim=2), moshi_delay_pattern_mask
          )

          inputs_embeds, _, _, _, _, _, _ = self._prepare_inputs_embeds_for_generation(
              input_ids, moshi_audio_codes=self.generated_audio_codes[:, :, -1:], user_audio_codes=user_audio_codes
          )

          model_inputs["input_ids"] = None
          model_inputs["inputs_embeds"] = inputs_embeds

      # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
      for key, value in kwargs.items():
          if key not in model_inputs:
              model_inputs[key] = value

      return model_inputs
  def _update_model_kwargs_for_generation(
      self,
      outputs: ModelOutput,
      model_kwargs: dict[str, Any],
      is_encoder_decoder: bool = False,
      num_new_tokens: int = 1,
  ) -> dict[str, Any]:
      """
      Run the parent kwargs update, then carry the decoder's last-position hidden state forward.

      The slice `outputs["last_hidden_state"][:, -1:]` is stored under `model_kwargs["last_hidden_state"]` (consumed by
      `prepare_inputs_for_generation`) and also on `self.last_hidden_state`.
      """
      updated_kwargs = super()._update_model_kwargs_for_generation(
          outputs, model_kwargs, is_encoder_decoder, num_new_tokens
      )

      newest_hidden_state = outputs.get("last_hidden_state")[:, -1:]
      updated_kwargs["last_hidden_state"] = newest_hidden_state
      # also kept on the instance: `generate` runs one more depth-decoder step after the main loop finishes
      self.last_hidden_state = newest_hidden_state

      return updated_kwargs
  def get_input_embeddings(self):
      """Return the input embedding module, as exposed by the text decoder."""
      text_decoder = self.decoder
      return text_decoder.get_input_embeddings()
  def set_input_embeddings(self, value):
      """Delegate replacing the input embedding module to the text decoder."""
      text_decoder = self.decoder
      text_decoder.set_input_embeddings(value)
  def get_output_embeddings(self):
      """Return the output embedding module, as exposed by the text decoder."""
      text_decoder = self.decoder
      return text_decoder.get_output_embeddings()
  def set_output_embeddings(self, new_embeddings):
      """Delegate replacing the output embedding module to the text decoder."""
      text_decoder = self.decoder
      text_decoder.set_output_embeddings(new_embeddings)
  def freeze_audio_encoder(self):
      """
      Stop gradient updates for the audio encoder.

      Every parameter of `self.audio_encoder` gets `requires_grad = False`, and the module is flagged with
      `_requires_grad = False`.
      """
      audio_encoder = self.audio_encoder
      for weight in audio_encoder.parameters():
          weight.requires_grad = False
      audio_encoder._requires_grad = False
  def freeze_depth_decoder(self):
      """
      Freeze the depth decoder weights.

      Every parameter of `self.depth_decoder` gets `requires_grad = False`, and the module is flagged with
      `_requires_grad = False`.
      """
      for param in self.depth_decoder.parameters():
          param.requires_grad = False
      self.depth_decoder._requires_grad = False
  @staticmethod
  # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenForCausalLM.apply_delay_pattern_mask
  def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
      """Apply a delay pattern mask to the decoder input ids, only preserving predictions where
      the mask is set to -1, and otherwise setting to the value detailed in the mask.

      Args:
          input_ids: audio codes whose last axis is the time axis.
          decoder_pad_token_mask: delay pattern mask (e.g. from `build_delay_pattern_mask`); only its first
              `input_ids.shape[-1]` positions along the last axis are used. Presumably at least that long and
              shape-compatible with `input_ids` after cropping — TODO confirm for all callers.

      Returns:
          A new tensor (from `torch.where`) with `input_ids` where the mask is -1 and the mask value elsewhere.
      """
      seq_len = input_ids.shape[-1]
      # the mask may be longer than the codes produced so far: crop it to the current length
      decoder_pad_token_mask = decoder_pad_token_mask[..., :seq_len]
      # keep predictions where the mask is -1, force the mask's BOS/PAD/prompt value everywhere else
      input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask)
      return input_ids
  def build_delay_pattern_mask(
      self, input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: Optional[int] = None
  ):
      """Build the delay pattern mask for a prompt of audio codes.

      The first codebook is kept in place; every other codebook is delayed by one step, so it starts with BOS. The
      last slot of the first codebook is PAD. With 4 codebooks and a max length of 6, the mask of shape
      `(codebooks, seq_len)` is:

      - [-1, -1, -1, -1, -1, P]
      - [ B, -1, -1, -1, -1, -1]
      - [ B, -1, -1, -1, -1, -1]
      - [ B, -1, -1, -1, -1, -1]

      where B is the beginning-of-sentence token, P the special padding token id and -1 marks positions still to be
      predicted. With a prompt, its values are written in (delayed for all but the first codebook):

      - [ a0, a1, -1, -1, -1, P]
      - [ B, b0, b1, -1, -1, -1]
      - [ B, c0, c1, -1, -1, -1]
      - [ B, d0, d1, -1, -1, -1]

      where a-d are the codebook channels and 0/1 the time steps. Only the -1 positions are later overwritten by
      predictions.

      Args:
          input_ids: prompt codes of shape `(batch_size, num_codebooks, seq_len)`.
          bos_token_id: value placed at step 0 of every delayed codebook.
          pad_token_id: value placed at the last step of the first codebook.
          max_length: length of the mask; defaults to `self.generation_config.max_length`.

      Returns:
          `(input_ids, pattern_mask)`: `pattern_mask` has shape `(batch_size, num_codebooks, max_length)` and
          `input_ids` is its first `min(seq_len, max_length - 1)` steps (a view of the same tensor).
      """
      batch_size, n_codebooks, prompt_len = input_ids.shape
      if max_length is None:
          max_length = self.generation_config.max_length

      # start from an all "-1" (still to be predicted) canvas spanning the whole generation length
      pattern_mask = torch.full(
          (batch_size, n_codebooks, max_length), -1, dtype=torch.long, device=input_ids.device
      )

      # the final step of the first codebook is reserved for PAD, so at most max_length - 1 prompt steps fit
      kept_steps = min(prompt_len, max_length - 1)

      # first codebook: prompt copied as-is; remaining codebooks: prompt delayed by one step
      pattern_mask[:, 0, :kept_steps] = input_ids[:, 0, :kept_steps]
      pattern_mask[:, 1:, 1 : kept_steps + 1] = input_ids[:, 1:, :kept_steps]

      # BOS opens every delayed codebook, PAD closes the first one
      pattern_mask[:, 1:, 0] = bos_token_id
      pattern_mask[:, 0, -1] = pad_token_id

      return pattern_mask[..., :kept_steps], pattern_mask
  def get_unconditional_inputs(self, num_samples=1):
      """
      Helper function to get null inputs for unconditional generation, enabling the model to be used without the
      feature extractor or tokenizer.

      The text prompt is a single token filled with `config.vocab_size`. Both audio prompts are a single step filled
      with `config.audio_vocab_size` for each of the `num_codebooks` codebooks, and the attention mask is all ones.

      Args:
          num_samples (`int`, *optional*, defaults to 1):
              Number of audio samples to unconditionally generate, i.e. the batch size of the returned tensors.

      Returns:
          `MoshiUnconditionalInput` with `input_ids` and `attention_mask` of shape `(num_samples, 1)` and
          `user_audio_codes` / `moshi_audio_codes` of shape `(num_samples, num_codebooks, 1)`, all on `self.device`.
          Use `max_new_tokens` in `generate` to control the length of the generated samples.

      Example:
      ```python
      >>> from transformers import MoshiForConditionalGeneration

      >>> model = MoshiForConditionalGeneration.from_pretrained("kmhf/hf-moshiko-pytorch-bf16")

      >>> # get the unconditional (or 'null') inputs for the model
      >>> unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
      >>> audio_samples = model.generate(**unconditional_inputs, max_new_tokens=256)
      ```"""
      input_ids = torch.full((num_samples, 1), self.config.vocab_size, device=self.device, dtype=torch.int64)

      # user and Moshi audio prompts share the same one-step "null" codes, one per codebook
      audio_codes_shape = (num_samples, self.num_codebooks, 1)
      user_audio_codes = torch.full(
          audio_codes_shape, self.config.audio_vocab_size, device=self.device, dtype=torch.int64
      )
      moshi_audio_codes = torch.full(
          audio_codes_shape, self.config.audio_vocab_size, device=self.device, dtype=torch.int64
      )

      attention_mask = torch.ones((num_samples, 1), device=self.device, dtype=torch.long)

      return MoshiUnconditionalInput(
          input_ids=input_ids,
          user_audio_codes=user_audio_codes,
          moshi_audio_codes=moshi_audio_codes,
          attention_mask=attention_mask,
      )
  def _check_and_maybe_initialize_inputs(
      self,
      input_ids=None,
      user_input_values=None,
      user_audio_codes=None,
      moshi_input_values=None,
      moshi_audio_codes=None,
      inputs_embeds=None,
      concat_unconditional_inputs=None,
  ):
      """
      Validate the generation inputs, falling back to unconditional inputs when none is given.

      There are two valid cases:
      - Nothing is passed. The inputs from `self.get_unconditional_inputs()` are used, and
        `concat_unconditional_inputs` becomes `False`.
      - Text (`input_ids` or `inputs_embeds`), user audio and Moshi audio are all passed, with equal sequence lengths.

      Audio codes take priority over waveforms, as in the generation helpers. Only waveform lengths are converted
      from samples to frames, with the ratio `audio_encoder_config.frame_rate / sampling_rate`.

      Returns:
          `(input_ids, user_audio_codes, moshi_audio_codes, concat_unconditional_inputs)`.
          `concat_unconditional_inputs` is also forced to a falsy value when only `inputs_embeds` is given.

      Raises:
          ValueError: if only some of the three inputs are given, or if their sequence lengths differ.
      """
      inputs = input_ids if inputs_embeds is None else inputs_embeds
      # codes have priority over waveforms: if both are given, the codes are the ones actually used downstream,
      # so their (frame) length is the one to validate
      user_input = user_input_values if user_audio_codes is None else user_audio_codes
      moshi_input = moshi_input_values if moshi_audio_codes is None else moshi_audio_codes

      one_input_has_been_passed = (user_input is not None) or (moshi_input is not None) or (inputs is not None)

      # concat_unconditional_inputs will be False if inputs_embeds is used
      concat_unconditional_inputs = concat_unconditional_inputs and not (
          inputs_embeds is not None and input_ids is None
      )

      # if one or two of the three required inputs have been passed, throws an error
      if one_input_has_been_passed and (user_input is None):
          raise ValueError(
              "No user audio inputs have been passed alongside the other inputs. Make sure either `user_input_values` or `user_audio_codes` is passed or use `MoshiForConditionalGeneration.get_unconditional_inputs`. Check the `MoshiForConditionalGeneration` docstrings for more information."
          )
      elif one_input_has_been_passed and (moshi_input is None):
          raise ValueError(
              "No Moshi audio inputs have been passed alongside the other inputs. Make sure either `moshi_input_values` or `moshi_audio_codes` is passed or use `MoshiForConditionalGeneration.get_unconditional_inputs`. Check the `MoshiForConditionalGeneration` docstrings for more information."
          )
      elif one_input_has_been_passed and (inputs is None):
          raise ValueError(
              "No `input_ids` or `inputs_embeds` have been passed alongside the other inputs. Make sure `input_ids` is passed or use `MoshiForConditionalGeneration.get_unconditional_inputs`. Check the `MoshiForConditionalGeneration` docstrings for more information."
          )
      elif not one_input_has_been_passed:
          # if no inputs have been passed, use default values
          unconditional_inputs = self.get_unconditional_inputs()
          input_ids = unconditional_inputs.input_ids
          user_audio_codes = unconditional_inputs.user_audio_codes
          moshi_audio_codes = unconditional_inputs.moshi_audio_codes

          # in that case, no need to concat unconditional inputs
          concat_unconditional_inputs = False
      else:
          # check if same sequence length
          user_seq_length = user_input.shape[-1]
          moshi_seq_length = moshi_input.shape[-1]
          tokens_seq_length = inputs.shape[1]

          # waveform lengths are in samples: convert them to audio frames before comparing
          ratio = self.config.audio_encoder_config.frame_rate / self.config.sampling_rate
          moshi_seq_length = math.ceil(moshi_seq_length * ratio) if moshi_audio_codes is None else moshi_seq_length
          user_seq_length = math.ceil(user_seq_length * ratio) if user_audio_codes is None else user_seq_length

          if tokens_seq_length != moshi_seq_length or tokens_seq_length != user_seq_length:
              raise ValueError(
                  "At least one of the 3 inputs of `MoshiForConditionalGeneration` doesn't have the same sequence length as the others. "
                  "Make sure that they all have the same sequence length. Check the `MoshiForConditionalGeneration` docstrings for more information."
              )

      return input_ids, user_audio_codes, moshi_audio_codes, concat_unconditional_inputs
  # Public API of this module.
  __all__ = [
      "MoshiForCausalLM",
      "MoshiForConditionalGeneration",
      "MoshiModel",
      "MoshiPreTrainedModel",
  ]