modular_moonshine.py 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921
  1. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Callable, Optional, Union
  15. import torch
  16. import torch.nn as nn
  17. from transformers.utils.generic import OutputRecorder, check_model_inputs
  18. from ...activations import ACT2FN
  19. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  20. from ...configuration_utils import PretrainedConfig
  21. from ...generation import GenerationMixin
  22. from ...masking_utils import create_causal_mask
  23. from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
  24. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  25. from ...modeling_layers import GradientCheckpointingLayer
  26. from ...modeling_outputs import (
  27. BaseModelOutput,
  28. BaseModelOutputWithPast,
  29. BaseModelOutputWithPastAndCrossAttentions,
  30. Seq2SeqLMOutput,
  31. Seq2SeqModelOutput,
  32. )
  33. from ...modeling_rope_utils import rope_config_validation
  34. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  35. from ...processing_utils import Unpack
  36. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  37. from ...utils.deprecation import deprecate_kwarg
  38. from ..glm.modeling_glm import GlmAttention, GlmRotaryEmbedding, apply_rotary_pos_emb
  39. from ..llama.modeling_llama import LlamaDecoderLayer, LlamaModel, eager_attention_forward
  40. from ..whisper.modeling_whisper import WhisperModel, shift_tokens_right
  # Module-level logger obtained through the library's `logging` helper, keyed by this module's name.
  logger = logging.get_logger(__name__)
  class MoonshineConfig(PretrainedConfig):
      r"""
      This is the configuration class to store the configuration of a [`MoonshineModel`]. It is used to instantiate a Moonshine
      model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
      defaults will yield a similar configuration to that of the Moonshine
      [UsefulSensors/moonshine-tiny](https://huggingface.co/UsefulSensors/moonshine-tiny).

      Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
      documentation from [`PretrainedConfig`] for more information.

      Note: the generic attribute names `num_attention_heads`, `num_key_value_heads` and `num_hidden_layers` are aliases
      (through `attribute_map`) for the *encoder* values `encoder_num_attention_heads`, `encoder_num_key_value_heads`
      and `encoder_num_hidden_layers`.

      Args:
          vocab_size (`int`, *optional*, defaults to 32768):
              Vocabulary size of the Moonshine model. Defines the number of different tokens that can be represented by the
              `inputs_ids` passed when calling [`MoonshineModel`].
          hidden_size (`int`, *optional*, defaults to 288):
              Dimension of the hidden representations.
          intermediate_size (`int`, *optional*, defaults to 1152):
              Dimension of the MLP representations.
          encoder_num_hidden_layers (`int`, *optional*, defaults to 6):
              Number of hidden layers in the Transformer encoder.
          decoder_num_hidden_layers (`int`, *optional*, defaults to 6):
              Number of hidden layers in the Transformer decoder.
          encoder_num_attention_heads (`int`, *optional*, defaults to 8):
              Number of attention heads for each attention layer in the Transformer encoder.
          decoder_num_attention_heads (`int`, *optional*, defaults to 8):
              Number of attention heads for each attention layer in the Transformer decoder.
          encoder_num_key_value_heads (`int`, *optional*):
              This is the number of key_value heads that should be used to implement Grouped Query Attention. If
              `encoder_num_key_value_heads=encoder_num_attention_heads`, the model will use Multi Head Attention (MHA), if
              `encoder_num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
              converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
              by meanpooling all the original heads within that group. For more details, check out [this
              paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
              `encoder_num_attention_heads`.
          decoder_num_key_value_heads (`int`, *optional*):
              This is the number of key_value heads that should be used to implement Grouped Query Attention. If
              `decoder_num_key_value_heads=decoder_num_attention_heads`, the model will use Multi Head Attention (MHA), if
              `decoder_num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
              converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
              by meanpooling all the original heads within that group. For more details, check out [this
              paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
              `decoder_num_attention_heads`.
          pad_head_dim_to_multiple_of (`int`, *optional*):
              Pad head dimension in encoder and decoder to the next multiple of this value. Necessary for using certain
              optimized attention implementations.
          encoder_hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
              The non-linear activation function (function or string) in the encoder.
          decoder_hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
              The non-linear activation function (function or string) in the decoder.
          max_position_embeddings (`int`, *optional*, defaults to 512):
              The maximum sequence length that this model might ever be used with.
          initializer_range (`float`, *optional*, defaults to 0.02):
              The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
          decoder_start_token_id (`int`, *optional*, defaults to 1):
              Corresponds to the "<|startoftranscript|>" token, which is automatically used when no `decoder_input_ids`
              are provided to the `generate` function. It is used to guide the model's generation process depending on
              the task.
          use_cache (`bool`, *optional*, defaults to `True`):
              Whether or not the model should return the last key/values attentions (not used by all models).
          rope_theta (`float`, *optional*, defaults to 10000.0):
              The base period of the RoPE embeddings.
          rope_scaling (`dict`, *optional*):
              Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
              and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
              accordingly.
              Expected contents:
                  `rope_type` (`str`):
                      The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                      'llama3'], with 'default' being the original RoPE implementation.
                  `factor` (`float`, *optional*):
                      Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                      most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                      original maximum pre-trained length.
                  `original_max_position_embeddings` (`int`, *optional*):
                      Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                      pretraining.
                  `attention_factor` (`float`, *optional*):
                      Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                      computation. If unspecified, it defaults to value recommended by the implementation, using the
                      `factor` field to infer the suggested value.
                  `beta_fast` (`float`, *optional*):
                      Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                      ramp function. If unspecified, it defaults to 32.
                  `beta_slow` (`float`, *optional*):
                      Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                      ramp function. If unspecified, it defaults to 1.
                  `short_factor` (`list[float]`, *optional*):
                      Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                      `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                      size divided by the number of attention heads divided by 2
                  `long_factor` (`list[float]`, *optional*):
                      Only used with 'longrope'. The scaling factor to be applied to long contexts (>
                      `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                      size divided by the number of attention heads divided by 2
                  `low_freq_factor` (`float`, *optional*):
                      Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                  `high_freq_factor` (`float`, *optional*):
                      Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
          partial_rotary_factor (`float`, *optional*, defaults to 0.9):
              Percentage of the query and keys which will have rotary embedding.
          is_encoder_decoder (`bool`, *optional*, defaults to `True`):
              Whether the model is used as an encoder/decoder or not.
          attention_bias (`bool`, *optional*, defaults to `False`):
              Whether to use a bias in the query, key, value and output projection layers during self-attention.
          attention_dropout (`float`, *optional*, defaults to 0.0):
              The dropout ratio for the attention probabilities.
          bos_token_id (`int`, *optional*, defaults to 1):
              Denotes beginning of sequences token id.
          eos_token_id (`int`, *optional*, defaults to 2):
              Denotes end of sequences token id.

      Example:

      ```python
      >>> from transformers import MoonshineModel, MoonshineConfig

      >>> # Initializing a Moonshine style configuration
      >>> configuration = MoonshineConfig.from_pretrained("UsefulSensors/moonshine-tiny")

      >>> # Initializing a model from the configuration
      >>> model = MoonshineModel(configuration)

      >>> # Accessing the model configuration
      >>> configuration = model.config
      ```"""

      model_type = "moonshine"
      keys_to_ignore_at_inference = ["past_key_values"]
      # Generic attribute names resolve to the encoder-side fields; decoder values are only reachable
      # through their explicit `decoder_*` names.
      attribute_map = {
          "num_key_value_heads": "encoder_num_key_value_heads",
          "num_attention_heads": "encoder_num_attention_heads",
          "num_hidden_layers": "encoder_num_hidden_layers",
      }

      def __init__(
          self,
          vocab_size=32768,
          hidden_size=288,
          intermediate_size=1152,
          encoder_num_hidden_layers=6,
          decoder_num_hidden_layers=6,
          encoder_num_attention_heads=8,
          decoder_num_attention_heads=8,
          encoder_num_key_value_heads=None,
          decoder_num_key_value_heads=None,
          pad_head_dim_to_multiple_of=None,
          encoder_hidden_act="gelu",
          decoder_hidden_act="silu",
          max_position_embeddings=512,
          initializer_range=0.02,
          decoder_start_token_id=1,
          use_cache=True,
          rope_theta=10000.0,
          rope_scaling=None,
          partial_rotary_factor=0.9,
          is_encoder_decoder=True,
          attention_bias=False,
          attention_dropout=0.0,
          bos_token_id=1,
          eos_token_id=2,
          **kwargs,
      ):
          self.vocab_size = vocab_size
          self.hidden_size = hidden_size
          self.intermediate_size = intermediate_size
          self.encoder_num_hidden_layers = encoder_num_hidden_layers
          self.decoder_num_hidden_layers = decoder_num_hidden_layers
          self.encoder_num_attention_heads = encoder_num_attention_heads
          self.decoder_num_attention_heads = decoder_num_attention_heads

          # Unspecified key/value head counts fall back to plain multi-head attention (one KV head per query head).
          if encoder_num_key_value_heads is None:
              encoder_num_key_value_heads = encoder_num_attention_heads
          self.encoder_num_key_value_heads = encoder_num_key_value_heads

          if decoder_num_key_value_heads is None:
              decoder_num_key_value_heads = decoder_num_attention_heads
          self.decoder_num_key_value_heads = decoder_num_key_value_heads

          self.pad_head_dim_to_multiple_of = pad_head_dim_to_multiple_of

          self.encoder_hidden_act = encoder_hidden_act
          self.decoder_hidden_act = decoder_hidden_act
          self.max_position_embeddings = max_position_embeddings
          self.initializer_range = initializer_range
          self.decoder_start_token_id = decoder_start_token_id
          self.use_cache = use_cache
          self.rope_theta = rope_theta
          self.rope_scaling = rope_scaling
          self.partial_rotary_factor = partial_rotary_factor
          self.is_encoder_decoder = is_encoder_decoder
          self.attention_bias = attention_bias
          self.attention_dropout = attention_dropout

          # Validate the correctness of rotary position embeddings parameters
          rope_config_validation(self)

          super().__init__(
              bos_token_id=bos_token_id,
              eos_token_id=eos_token_id,
              is_encoder_decoder=is_encoder_decoder,
              decoder_start_token_id=decoder_start_token_id,
              **kwargs,
          )
  class MoonshineEncoderMLP(nn.Module):
      """Plain feed-forward block used by the encoder: `fc1`, then the activation, then `fc2`.

      Args:
          config: model configuration; `hidden_size` and `intermediate_size` set the widths of the two projections.
          hidden_act: key into `ACT2FN` selecting the activation function.
      """

      def __init__(self, config, hidden_act):
          super().__init__()
          self.config = config
          self.activation_fn = ACT2FN[hidden_act]
          model_dim, inner_dim = config.hidden_size, config.intermediate_size
          self.fc1 = nn.Linear(model_dim, inner_dim)
          self.fc2 = nn.Linear(inner_dim, model_dim)

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          # Expand to the intermediate width, apply the non-linearity, project back to the model width.
          return self.fc2(self.activation_fn(self.fc1(hidden_states)))
  class MoonshineDecoderMLP(nn.Module):
      """Gated feed-forward block used by the decoder.

      `fc1` produces twice the intermediate width. Its output is split into a value half and a gate half; the
      activation is applied to the gate, multiplied with the value half, and `fc2` projects the product back.

      Args:
          config: model configuration; `hidden_size` and `intermediate_size` set the projection widths.
          hidden_act: key into `ACT2FN` selecting the gate activation.
      """

      def __init__(self, config, hidden_act):
          super().__init__()
          self.config = config
          self.activation_fn = ACT2FN[hidden_act]
          inner_dim = config.intermediate_size
          self.fc1 = nn.Linear(config.hidden_size, 2 * inner_dim)
          self.fc2 = nn.Linear(inner_dim, config.hidden_size)

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          # First half of the last dimension is the value, second half is the gate.
          value, gate = self.fc1(hidden_states).chunk(2, dim=-1)
          gated = self.activation_fn(gate) * value
          return self.fc2(gated)
  class MoonshineAttention(GlmAttention):
      """Attention layer shared by the Moonshine encoder and decoder.

      Performs self-attention with rotary position embeddings or, when `key_value_states` is passed,
      cross-attention over those states (no rotary embedding applied). When
      `config.pad_head_dim_to_multiple_of` is set, query/key/value are zero-padded along the head dimension
      around the attention kernel call and the padding is sliced off the output.
      """

      def __init__(
          self,
          config: MoonshineConfig,
          layer_idx: int,
          is_causal: bool,
          num_attention_heads: int,
          num_key_value_heads: int,
      ):
          """
          Args:
              config: model configuration. NOTE: it is mutated in place (see below).
              layer_idx: index of this layer, used to address its slot in the key/value cache.
              is_causal: whether the layer should request causal masking when no explicit mask is given.
              num_attention_heads: number of query heads for this layer.
              num_key_value_heads: number of key/value heads for this layer.
          """
          # NOTE(review): this writes into the *shared* config object. Because `MoonshineConfig.attribute_map`
          # aliases `num_attention_heads`/`num_key_value_heads` to the encoder fields, constructing decoder
          # layers overwrites `encoder_num_attention_heads`/`encoder_num_key_value_heads` whenever the decoder
          # uses different head counts. `forward` therefore never reads head counts back from `self.config`.
          config.update({"num_attention_heads": num_attention_heads, "num_key_value_heads": num_key_value_heads})
          super().__init__(config, layer_idx)
          self.is_causal = is_causal
          self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)

          # Pad head dimension to the next specified multiple.
          if self.config.pad_head_dim_to_multiple_of is not None:
              target_multiple = self.config.pad_head_dim_to_multiple_of
              target_head_dim = target_multiple * ((self.head_dim + target_multiple - 1) // target_multiple)
              self.head_dim_padding = target_head_dim - self.head_dim
          else:
              self.head_dim_padding = 0

      @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
      def forward(
          self,
          hidden_states: torch.Tensor,
          position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
          attention_mask: Optional[torch.Tensor] = None,
          past_key_values: Optional[Cache] = None,
          cache_position: Optional[torch.LongTensor] = None,
          key_value_states: Optional[torch.Tensor] = None,
          **kwargs: Unpack[FlashAttentionKwargs],
      ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
          """
          Args:
              hidden_states: `(batch_size, seq_len, hidden_size)` input from which queries are computed (and keys/values
                  too, for self-attention).
              position_embeddings: `(cos, sin)` pair; required for self-attention, ignored for cross-attention.
              attention_mask: mask forwarded to the attention kernel. When it is `None`, a causal layer asks the
                  kernel for causal masking as long as more than one query position is present.
              past_key_values: encoder-decoder cache exposing `is_updated`, `self_attention_cache` and
                  `cross_attention_cache`. It is updated in place.
              cache_position: positions forwarded to the cache update.
              key_value_states: if given, the layer performs cross-attention over these states.

          Returns:
              `(attn_output, attn_weights)`. `attn_weights` is whatever the selected attention kernel returns.
          """
          bsz, q_len = hidden_states.shape[:-1]

          # Let `view` infer the head count (`-1`) from the projection width instead of reading it from the
          # shared, mutable config. The query projection has `num_attention_heads` heads, which differs from
          # `num_key_value_heads` under GQA/MQA.
          query_states = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

          is_cross_attention = key_value_states is not None
          # Only consulted when a cache is present; initialized so the name is always bound.
          is_updated = False
          if past_key_values is not None:
              is_updated = past_key_values.is_updated.get(self.layer_idx)
              if is_cross_attention:
                  # after the first generated id, we can subsequently re-use all key/value_states from cache
                  past_key_values.is_updated[self.layer_idx] = True
                  past_key_values = past_key_values.cross_attention_cache
              else:
                  past_key_values = past_key_values.self_attention_cache

          # use key_value_states if cross attention
          current_states = key_value_states if is_cross_attention else hidden_states
          if is_cross_attention and past_key_values and is_updated:
              key_states = past_key_values.layers[self.layer_idx].keys
              value_states = past_key_values.layers[self.layer_idx].values
          else:
              kv_len = current_states.shape[1]
              key_states = self.k_proj(current_states).view(bsz, kv_len, -1, self.head_dim).transpose(1, 2)
              value_states = self.v_proj(current_states).view(bsz, kv_len, -1, self.head_dim).transpose(1, 2)
              if is_cross_attention and past_key_values is not None:
                  key_states, value_states = past_key_values.update(
                      key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                  )

          if not is_cross_attention:
              cos, sin = position_embeddings
              query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

              if past_key_values is not None:
                  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
                  key_states, value_states = past_key_values.update(
                      key_states, value_states, self.layer_idx, cache_kwargs
                  )

          attention_interface: Callable = eager_attention_forward
          if self.config._attn_implementation != "eager":
              attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

          # Causal masking is requested from the kernel only when no explicit mask is supplied and there is
          # more than one query position.
          is_causal = self.is_causal and attention_mask is None and q_len > 1

          if self.head_dim_padding > 0:
              query_states = torch.nn.functional.pad(query_states, (0, self.head_dim_padding))
              key_states = torch.nn.functional.pad(key_states, (0, self.head_dim_padding))
              value_states = torch.nn.functional.pad(value_states, (0, self.head_dim_padding))

          attn_output, attn_weights = attention_interface(
              self,
              query_states,
              key_states,
              value_states,
              attention_mask,
              dropout=0.0 if not self.training else self.attention_dropout,
              scaling=self.scaling,
              is_causal=is_causal,
              **kwargs,
          )

          # Drop the zero padding added above before merging heads.
          if self.head_dim_padding > 0:
              attn_output = attn_output[..., : -self.head_dim_padding]

          attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
          attn_output = self.o_proj(attn_output)
          return attn_output, attn_weights
  class MoonshineRotaryEmbedding(GlmRotaryEmbedding):
      """Rotary position embedding for Moonshine; inherits `GlmRotaryEmbedding` unchanged."""
  class MoonshineEncoderLayer(LlamaDecoderLayer):
      """Transformer block of the Moonshine encoder.

      Keeps the forward pass inherited from `LlamaDecoderLayer`, but installs non-causal `MoonshineAttention` with
      the encoder head counts, the encoder MLP, and bias-free `nn.LayerNorm` layers.
      """

      def __init__(self, config: MoonshineConfig, layer_idx: int):
          super().__init__(config, layer_idx)
          # The attributes assigned below override whatever the parent constructor created under the same names.
          self.self_attn = MoonshineAttention(
              config=config,
              layer_idx=layer_idx,
              is_causal=False,
              num_attention_heads=config.encoder_num_attention_heads,
              num_key_value_heads=config.encoder_num_key_value_heads,
          )
          self.mlp = MoonshineEncoderMLP(config, config.encoder_hidden_act)
          width = config.hidden_size
          self.input_layernorm = nn.LayerNorm(width, bias=False)
          self.post_attention_layernorm = nn.LayerNorm(width, bias=False)
  class MoonshineDecoderLayer(GradientCheckpointingLayer):
      """Transformer block of the Moonshine decoder.

      Runs causal self-attention, then (only when encoder states are supplied) cross-attention over the encoder
      output, then the gated decoder MLP. Each sub-block normalizes its input with a bias-free `nn.LayerNorm`
      and adds its output onto a residual connection.
      """

      def __init__(self, config: MoonshineConfig, layer_idx: Optional[int] = None):
          super().__init__()
          self.hidden_size = config.hidden_size

          self.self_attn = MoonshineAttention(
              config=config,
              layer_idx=layer_idx,
              is_causal=True,
              num_attention_heads=config.decoder_num_attention_heads,
              num_key_value_heads=config.decoder_num_key_value_heads,
          )
          self.encoder_attn = MoonshineAttention(
              config=config,
              layer_idx=layer_idx,
              is_causal=False,
              num_attention_heads=config.decoder_num_attention_heads,
              num_key_value_heads=config.decoder_num_key_value_heads,
          )

          self.mlp = MoonshineDecoderMLP(config, config.decoder_hidden_act)
          self.input_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
          # Normalizes the input of the cross-attention sub-block (used only when encoder states are given).
          self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
          self.final_layernorm = nn.LayerNorm(config.hidden_size, bias=False)

      @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.Tensor] = None,
          encoder_hidden_states: Optional[torch.Tensor] = None,
          encoder_attention_mask: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          encoder_position_ids: Optional[torch.LongTensor] = None,
          past_key_values: Optional[Cache] = None,
          use_cache: Optional[bool] = False,
          cache_position: Optional[torch.LongTensor] = None,
          position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
          encoder_position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> torch.Tensor:
          """
          Args:
              hidden_states: decoder hidden states entering this layer.
              attention_mask: mask for the self-attention sub-block.
              encoder_hidden_states: encoder output; when `None` the cross-attention sub-block is skipped.
              encoder_attention_mask: mask for the cross-attention sub-block.
              position_ids, use_cache, **kwargs: forwarded to the self-attention call.
              encoder_position_ids, encoder_position_embeddings: accepted but not used by this layer.
              past_key_values: cache shared by the self- and cross-attention sub-blocks.
              cache_position: positions forwarded to the self-attention cache update.
              position_embeddings: `(cos, sin)` rotary embeddings for self-attention.

          Returns:
              `torch.Tensor`: the updated hidden states (a single tensor, not a tuple).
          """
          residual = hidden_states
          hidden_states = self.input_layernorm(hidden_states)

          hidden_states, _ = self.self_attn(
              hidden_states=hidden_states,
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=past_key_values,
              use_cache=use_cache,
              cache_position=cache_position,
              position_embeddings=position_embeddings,
              **kwargs,
          )
          hidden_states = residual + hidden_states

          if encoder_hidden_states is not None:
              residual = hidden_states
              hidden_states = self.post_attention_layernorm(hidden_states)
              hidden_states, _ = self.encoder_attn(
                  hidden_states=hidden_states,
                  key_value_states=encoder_hidden_states,
                  attention_mask=encoder_attention_mask,
                  past_key_values=past_key_values,
                  use_cache=use_cache,
              )
              hidden_states = residual + hidden_states

          residual = hidden_states
          hidden_states = self.final_layernorm(hidden_states)
          hidden_states = self.mlp(hidden_states)
          hidden_states = residual + hidden_states
          return hidden_states
  @auto_docstring
  class MoonshinePreTrainedModel(PreTrainedModel):
      config: MoonshineConfig
      base_model_prefix = "model"
      main_input_name = "input_values"
      supports_gradient_checkpointing = True
      _no_split_modules = ["MoonshineEncoderLayer", "MoonshineDecoderLayer"]
      _supports_flash_attn = True
      _supports_sdpa = True

      _can_compile_fullgraph = True
      # TODO arthur, how do we separate when it cross / self coming from different layer?

      def _get_feat_extract_output_lengths(self, input_lengths: Union[int, torch.LongTensor]):
          """
          Computes the output length of the convolutional layers.

          Applies `(length - kernel_size) // stride + 1` for the three encoder convolutions
          (kernel/stride 127/64, 7/3 and 3/2).

          Args:
              input_lengths: raw waveform length(s), either a Python `int` or an integer tensor of lengths.

          Returns:
              The number of encoder frames, with the same type as `input_lengths`.
          """
          # Exact integer floor division works element-wise on tensors as well as on ints; the previous
          # `int((x - k) / s + 1)` raised on multi-element tensors. For every length long enough to pass through
          # all three convolutions the result is unchanged (shorter inputs cannot pass the convolutions in the
          # first place).
          for kernel_size, stride in ((127, 64), (7, 3), (3, 2)):
              input_lengths = (input_lengths - kernel_size) // stride + 1
          return input_lengths
  class MoonshineEncoder(MoonshinePreTrainedModel):
      """
      Moonshine audio encoder: a three-stage 1D convolutional front-end over the raw waveform, followed by
      *config.num_hidden_layers* [`MoonshineEncoderLayer`] blocks and a final bias-free layer norm.

      Args:
          config: MoonshineConfig
      """

      main_input_name = "input_values"
      _can_record_outputs = {
          "attentions": MoonshineAttention,
          "hidden_states": MoonshineEncoderLayer,
      }

      def __init__(self, config: MoonshineConfig):
          super().__init__(config)
          self.config = config
          width = config.hidden_size

          # Convolutional front-end. Its strides (64, 3, 2) are mirrored in `_downsample_attention_mask`.
          self.conv1 = nn.Conv1d(1, width, kernel_size=127, stride=64, bias=False)
          self.conv2 = nn.Conv1d(width, 2 * width, kernel_size=7, stride=3)
          self.conv3 = nn.Conv1d(2 * width, width, kernel_size=3, stride=2)
          self.groupnorm = nn.GroupNorm(num_groups=1, num_channels=width, eps=1e-5)

          self.rotary_emb = MoonshineRotaryEmbedding(config=config)
          self.layers = nn.ModuleList(
              MoonshineEncoderLayer(config, layer_idx) for layer_idx in range(config.encoder_num_hidden_layers)
          )
          self.layer_norm = nn.LayerNorm(width, bias=False)

          self.gradient_checkpointing = False
          self.post_init()

      def get_input_embeddings(self) -> nn.Module:
          # The first convolution is exposed as this audio encoder's "input embedding".
          return self.conv1

      def set_input_embeddings(self, value: nn.Module):
          self.conv1 = value

      def _downsample_attention_mask(self, attention_mask: torch.Tensor, dtype: torch.dtype) -> Optional[torch.Tensor]:
          """Bring a waveform-level padding mask down to the encoder frame rate and format it for the attention backend."""
          total_stride = 64 * 3 * 2  # product of the conv strides
          frame_count = self._get_feat_extract_output_lengths(attention_mask.shape[-1])
          frame_mask = attention_mask[..., ::total_stride][..., :frame_count]

          backend = self.config._attn_implementation
          if backend == "flash_attention_2":
              # The 2D mask is kept as-is for flash attention, and dropped entirely when nothing is masked.
              return frame_mask if (frame_mask == 0.0).any() else None
          if backend == "sdpa":
              return _prepare_4d_attention_mask_for_sdpa(frame_mask, dtype)
          return _prepare_4d_attention_mask(frame_mask, dtype)

      @check_model_inputs()
      def forward(
          self,
          input_values: torch.FloatTensor,
          attention_mask: Optional[torch.Tensor] = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> BaseModelOutputWithPast:
          r"""
          Args:
              input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
                  Float values of the raw speech waveform. Raw speech waveform can be
                  obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
                  `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or
                  the soundfile library (`pip install soundfile`). To prepare the array into
                  `input_values`, the [`AutoFeatureExtractor`] should be used for padding
                  and conversion into a tensor of type `torch.FloatTensor`.
              attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                  Mask to avoid performing attention on padding indices in `input_values`. Mask values selected in `[0, 1]`:
                  - 1 for tokens that are **not masked**,
                  - 0 for tokens that are **masked**.
                  [What are attention masks?](../glossary#attention-mask)
          """
          # Conv front-end: add a channel axis, convolve, then move channels last -> (batch, frames, width).
          waveform = input_values.unsqueeze(1)
          features = self.groupnorm(nn.functional.tanh(self.conv1(waveform)))
          features = nn.functional.gelu(self.conv2(features))
          features = nn.functional.gelu(self.conv3(features))
          hidden_states = features.permute(0, 2, 1)

          if attention_mask is not None:
              attention_mask = self._downsample_attention_mask(attention_mask, hidden_states.dtype)

          frame_positions = torch.arange(0, hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)
          position_embeddings = self.rotary_emb(hidden_states, frame_positions)

          for layer in self.layers:
              hidden_states = layer(
                  hidden_states,
                  attention_mask=attention_mask,
                  position_ids=frame_positions,
                  position_embeddings=position_embeddings,
                  **kwargs,
              )

          return BaseModelOutputWithPast(last_hidden_state=self.layer_norm(hidden_states))
  535. class MoonshineDecoder(LlamaModel):
  536. main_input_name = "input_ids"
  537. _can_record_outputs = {
  538. "attentions": OutputRecorder(MoonshineAttention, index=1, layer_name="self_attn"),
  539. "hidden_states": MoonshineDecoderLayer,
  540. "cross_attentions": OutputRecorder(MoonshineAttention, index=1, layer_name="encoder_attn"),
  541. }
  542. def __init__(self, config: MoonshineConfig):
  543. super().__init__(config)
  544. self.norm = nn.LayerNorm(config.hidden_size, bias=False)
  545. self.layers = nn.ModuleList(
  546. [MoonshineDecoderLayer(config, idx) for idx in range(config.decoder_num_hidden_layers)]
  547. )
  548. @check_model_inputs()
  549. def forward(
  550. self,
  551. input_ids: Optional[torch.LongTensor] = None,
  552. attention_mask: Optional[torch.Tensor] = None,
  553. position_ids: Optional[torch.LongTensor] = None,
  554. past_key_values: Optional[Cache] = None,
  555. inputs_embeds: Optional[torch.FloatTensor] = None,
  556. use_cache: Optional[bool] = None,
  557. cache_position: Optional[torch.LongTensor] = None,
  558. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  559. encoder_attention_mask: Optional[torch.Tensor] = None,
  560. **kwargs: Unpack[TransformersKwargs],
  561. ) -> Union[tuple, BaseModelOutputWithPast]:
  562. r"""
  563. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
  564. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
  565. of the decoder.
  566. encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  567. Mask to avoid performing attention on padding indices in `encoder_hidden_states`. Mask values selected in `[0, 1]`:
  568. - 1 for tokens that are **not masked**,
  569. - 0 for tokens that are **masked**.
  570. [What are attention masks?](../glossary#attention-mask)
  571. """
  572. if (input_ids is None) ^ (inputs_embeds is not None):
  573. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  574. if inputs_embeds is None:
  575. inputs_embeds = self.embed_tokens(input_ids)
  576. if use_cache and past_key_values is None:
  577. past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  578. if cache_position is None:
  579. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  580. cache_position = torch.arange(
  581. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  582. )
  583. if position_ids is None:
  584. position_ids = cache_position.unsqueeze(0)
  585. causal_mask = create_causal_mask(
  586. config=self.config,
  587. input_embeds=inputs_embeds,
  588. attention_mask=attention_mask,
  589. cache_position=cache_position,
  590. past_key_values=past_key_values,
  591. position_ids=position_ids,
  592. )
  593. hidden_states = inputs_embeds
  594. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  595. if encoder_attention_mask is not None:
  596. mask_len = encoder_hidden_states.shape[-2]
  597. downsample_stride = 64 * 3 * 2 # conv strides
  598. encoder_attention_mask = encoder_attention_mask[..., ::downsample_stride][..., :mask_len]
  599. if self.config._attn_implementation == "flash_attention_2":
  600. encoder_attention_mask = encoder_attention_mask if (encoder_attention_mask == 0.0).any() else None
  601. elif self.config._attn_implementation == "sdpa":
  602. encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
  603. encoder_attention_mask, hidden_states.dtype, hidden_states.shape[-2]
  604. )
  605. else:
  606. encoder_attention_mask = _prepare_4d_attention_mask(
  607. encoder_attention_mask, hidden_states.dtype, hidden_states.shape[-2]
  608. )
  609. for decoder_layer in self.layers:
  610. hidden_states = decoder_layer(
  611. hidden_states,
  612. causal_mask,
  613. encoder_hidden_states, # as a positional argument for gradient checkpointing
  614. encoder_attention_mask=encoder_attention_mask,
  615. position_ids=position_ids,
  616. past_key_values=past_key_values,
  617. use_cache=use_cache,
  618. cache_position=cache_position,
  619. position_embeddings=position_embeddings,
  620. **kwargs,
  621. )
  622. hidden_states = self.norm(hidden_states)
  623. return BaseModelOutputWithPastAndCrossAttentions(
  624. last_hidden_state=hidden_states,
  625. past_key_values=past_key_values if use_cache else None,
  626. )
  627. class MoonshineModel(WhisperModel):
  628. @can_return_tuple
  629. @auto_docstring
  630. def forward(
  631. self,
  632. input_values: Optional[torch.FloatTensor] = None,
  633. attention_mask: Optional[torch.LongTensor] = None,
  634. decoder_input_ids: Optional[torch.LongTensor] = None,
  635. decoder_attention_mask: Optional[torch.LongTensor] = None,
  636. encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
  637. past_key_values: Optional[Union[EncoderDecoderCache, tuple[torch.FloatTensor]]] = None,
  638. decoder_inputs_embeds: Optional[tuple[torch.FloatTensor]] = None,
  639. decoder_position_ids: Optional[tuple[torch.LongTensor]] = None,
  640. use_cache: Optional[bool] = None,
  641. cache_position: Optional[torch.LongTensor] = None,
  642. **kwargs: Unpack[TransformersKwargs],
  643. ) -> Seq2SeqModelOutput:
  644. r"""
  645. input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
  646. Float values of the raw speech waveform. Raw speech waveform can be
  647. obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
  648. `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or
  649. the soundfile library (`pip install soundfile`). To prepare the array into
  650. `input_values`, the [`AutoFeatureExtractor`] should be used for padding
  651. and conversion into a tensor of type `torch.FloatTensor`.
  652. decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
  653. Indices of positions of each input sequence tokens in the position embeddings.
  654. Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`
  655. Example:
  656. ```python
  657. >>> import torch
  658. >>> from transformers import AutoFeatureExtractor, MoonshineModel
  659. >>> from datasets import load_dataset
  660. >>> model = MoonshineModel.from_pretrained("UsefulSensors/moonshine-tiny")
  661. >>> feature_extractor = AutoFeatureExtractor.from_pretrained("UsefulSensors/moonshine-tiny")
  662. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  663. >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
  664. >>> input_values = inputs.input_values
  665. >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
  666. >>> last_hidden_state = model(input_values, decoder_input_ids=decoder_input_ids).last_hidden_state
  667. >>> list(last_hidden_state.shape)
  668. [1, 2, 288]
  669. ```
  670. """
  671. if encoder_outputs is None:
  672. encoder_outputs: BaseModelOutput = self.encoder(input_values, attention_mask=attention_mask, **kwargs)
  673. decoder_outputs: BaseModelOutputWithPastAndCrossAttentions = self.decoder(
  674. input_ids=decoder_input_ids,
  675. attention_mask=decoder_attention_mask,
  676. encoder_attention_mask=attention_mask,
  677. encoder_hidden_states=encoder_outputs.last_hidden_state,
  678. past_key_values=past_key_values,
  679. inputs_embeds=decoder_inputs_embeds,
  680. position_ids=decoder_position_ids,
  681. use_cache=use_cache,
  682. cache_position=cache_position,
  683. **kwargs,
  684. )
  685. return Seq2SeqModelOutput(
  686. last_hidden_state=decoder_outputs.last_hidden_state,
  687. past_key_values=decoder_outputs.past_key_values,
  688. decoder_hidden_states=decoder_outputs.hidden_states,
  689. decoder_attentions=decoder_outputs.attentions,
  690. cross_attentions=decoder_outputs.cross_attentions,
  691. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  692. encoder_hidden_states=encoder_outputs.hidden_states,
  693. encoder_attentions=encoder_outputs.attentions,
  694. )
  695. @auto_docstring(
  696. custom_intro="""
  697. The Moonshine Model with a language modeling head. Can be used for automatic speech recognition.
  698. """
  699. )
  700. class MoonshineForConditionalGeneration(MoonshinePreTrainedModel, GenerationMixin):
  701. _tied_weights_keys = ["proj_out.weight"]
  702. def __init__(self, config: MoonshineConfig):
  703. super().__init__(config)
  704. self.model = MoonshineModel(config)
  705. self.proj_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  706. # Initialize weights and apply final processing
  707. self.post_init()
  708. def get_encoder(self):
  709. return self.model.get_encoder()
  710. def get_decoder(self):
  711. return self.model.get_decoder()
  712. def get_output_embeddings(self):
  713. return self.proj_out
  714. def set_output_embeddings(self, new_embeddings):
  715. self.proj_out = new_embeddings
  716. def get_input_embeddings(self) -> nn.Module:
  717. return self.model.get_input_embeddings()
  718. @can_return_tuple
  719. @auto_docstring
  720. def forward(
  721. self,
  722. input_values: Optional[torch.FloatTensor] = None,
  723. attention_mask: Optional[torch.LongTensor] = None,
  724. decoder_input_ids: Optional[torch.LongTensor] = None,
  725. decoder_attention_mask: Optional[torch.LongTensor] = None,
  726. encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
  727. past_key_values: Optional[Union[EncoderDecoderCache, tuple[torch.FloatTensor]]] = None,
  728. decoder_inputs_embeds: Optional[tuple[torch.FloatTensor]] = None,
  729. decoder_position_ids: Optional[tuple[torch.LongTensor]] = None,
  730. use_cache: Optional[bool] = None,
  731. cache_position: Optional[torch.LongTensor] = None,
  732. labels: Optional[torch.LongTensor] = None,
  733. **kwargs: Unpack[TransformersKwargs],
  734. ) -> Seq2SeqLMOutput:
  735. r"""
  736. input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
  737. Float values of the raw speech waveform. Raw speech waveform can be
  738. obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
  739. `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or
  740. the soundfile library (`pip install soundfile`). To prepare the array into
  741. `input_values`, the [`AutoFeatureExtractor`] should be used for padding
  742. and conversion into a tensor of type `torch.FloatTensor`.
  743. decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
  744. Indices of positions of each input sequence tokens in the position embeddings.
  745. Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`
  746. Example:
  747. ```python
  748. >>> import torch
  749. >>> from transformers import AutoProcessor, MoonshineForConditionalGeneration
  750. >>> from datasets import load_dataset
  751. >>> processor = AutoProcessor.from_pretrained("UsefulSensors/moonshine-tiny")
  752. >>> model = MoonshineForConditionalGeneration.from_pretrained("UsefulSensors/moonshine-tiny")
  753. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  754. >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
  755. >>> input_values = inputs.input_values
  756. >>> generated_ids = model.generate(input_values, max_new_tokens=100)
  757. >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
  758. >>> transcription
  759. 'Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
  760. ```"""
  761. if labels is not None:
  762. if decoder_input_ids is None and decoder_inputs_embeds is None:
  763. decoder_input_ids = shift_tokens_right(
  764. labels, self.config.pad_token_id, self.config.decoder_start_token_id
  765. )
  766. outputs: Seq2SeqModelOutput = self.model(
  767. input_values,
  768. attention_mask=attention_mask,
  769. decoder_input_ids=decoder_input_ids,
  770. encoder_outputs=encoder_outputs,
  771. decoder_attention_mask=decoder_attention_mask,
  772. past_key_values=past_key_values,
  773. decoder_inputs_embeds=decoder_inputs_embeds,
  774. decoder_position_ids=decoder_position_ids,
  775. use_cache=use_cache,
  776. cache_position=cache_position,
  777. **kwargs,
  778. )
  779. logits = self.proj_out(outputs.last_hidden_state)
  780. loss = None
  781. if labels is not None:
  782. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
  783. return Seq2SeqLMOutput(
  784. loss=loss,
  785. logits=logits,
  786. past_key_values=outputs.past_key_values,
  787. decoder_hidden_states=outputs.decoder_hidden_states,
  788. decoder_attentions=outputs.decoder_attentions,
  789. cross_attentions=outputs.cross_attentions,
  790. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  791. encoder_hidden_states=outputs.encoder_hidden_states,
  792. encoder_attentions=outputs.encoder_attentions,
  793. )
  794. __all__ = [
  795. "MoonshineConfig",
  796. "MoonshineModel",
  797. "MoonshinePreTrainedModel",
  798. "MoonshineForConditionalGeneration",
  799. ]