modeling_jamba.py 67 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462
  1. # coding=utf-8
  2. # Copyright 2024 AI21 Labs Ltd. and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  5. # and OPT implementations in this library. It has been modified from its
  6. # original forms to accommodate minor architectural differences compared
  7. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. """PyTorch Jamba model."""
  21. import math
  22. from typing import Any, Optional, Union
  23. import torch
  24. import torch.nn.functional as F
  25. from torch import nn
  26. from ...activations import ACT2FN
  27. from ...generation import GenerationMixin
  28. from ...modeling_attn_mask_utils import AttentionMaskConverter
  29. from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
  30. from ...modeling_layers import (
  31. GenericForSequenceClassification,
  32. GradientCheckpointingLayer,
  33. )
  34. from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
  35. from ...modeling_utils import PreTrainedModel
  36. from ...processing_utils import Unpack
  37. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  38. from ...utils.deprecation import deprecate_kwarg
  39. from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available
  40. from .configuration_jamba import JambaConfig
  41. if is_flash_attn_available():
  42. from ...modeling_flash_attention_utils import _flash_attention_forward
  43. if is_mamba_ssm_available():
  44. from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
  45. from mamba_ssm.ops.triton.selective_state_update import selective_state_update
  46. else:
  47. selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
  48. if is_causal_conv1d_available():
  49. from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
  50. else:
  51. causal_conv1d_update, causal_conv1d_fn = None, None
  52. is_fast_path_available = all(
  53. (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
  54. )
  55. logger = logging.get_logger(__name__)
  56. # Copied from transformers.models.qwen2_moe.modeling_qwen2_moe.load_balancing_loss_func with gate->router
  57. def load_balancing_loss_func(
  58. router_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
  59. num_experts: Optional[int] = None,
  60. top_k=2,
  61. attention_mask: Optional[torch.Tensor] = None,
  62. ) -> Union[torch.Tensor, int]:
  63. r"""
  64. Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
  65. See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
  66. function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
  67. experts is too unbalanced.
  68. Args:
  69. router_logits:
  70. Logits from the `router`, should be a tuple of model.config.num_hidden_layers tensors of
  71. shape [batch_size X sequence_length, num_experts].
  72. num_experts:
  73. Number of experts
  74. top_k:
  75. The number of experts to route per-token, can be also interpreted as the `top-k` routing
  76. parameter.
  77. attention_mask (`torch.Tensor`, *optional*):
  78. The attention_mask used in forward function
  79. shape [batch_size X sequence_length] if not None.
  80. Returns:
  81. The auxiliary loss.
  82. """
  83. if router_logits is None or not isinstance(router_logits, tuple):
  84. return 0
  85. if isinstance(router_logits, tuple):
  86. compute_device = router_logits[0].device
  87. concatenated_router_logits = torch.cat(
  88. [layer_router.to(compute_device) for layer_router in router_logits], dim=0
  89. )
  90. routing_weights = torch.nn.functional.softmax(concatenated_router_logits, dim=-1)
  91. _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
  92. expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
  93. if attention_mask is None:
  94. # Compute the percentage of tokens routed to each experts
  95. tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
  96. # Compute the average probability of routing to these experts
  97. router_prob_per_expert = torch.mean(routing_weights, dim=0)
  98. else:
  99. batch_size, sequence_length = attention_mask.shape
  100. num_hidden_layers = concatenated_router_logits.shape[0] // (batch_size * sequence_length)
  101. # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
  102. expert_attention_mask = (
  103. attention_mask[None, :, :, None, None]
  104. .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
  105. .reshape(-1, top_k, num_experts)
  106. .to(compute_device)
  107. )
  108. # Compute the percentage of tokens routed to each experts
  109. tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
  110. expert_attention_mask, dim=0
  111. )
  112. # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
  113. router_per_expert_attention_mask = (
  114. attention_mask[None, :, :, None]
  115. .expand((num_hidden_layers, batch_size, sequence_length, routing_weights.shape[1]))
  116. .reshape(-1, routing_weights.shape[1])
  117. .to(compute_device)
  118. )
  119. # Compute the average probability of routing to these experts
  120. router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
  121. router_per_expert_attention_mask, dim=0
  122. )
  123. device_index = routing_weights.device.index if routing_weights.device.index is not None else 0
  124. rank = routing_weights.shape[1] * int(device_index)
  125. overall_loss = torch.sum(
  126. tokens_per_expert[:, rank : rank + routing_weights.shape[1]] * router_prob_per_expert.unsqueeze(0)
  127. )
  128. return overall_loss * num_experts
  129. # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Jamba
  130. class JambaRMSNorm(nn.Module):
  131. def __init__(self, hidden_size, eps=1e-6):
  132. """
  133. JambaRMSNorm is equivalent to T5LayerNorm
  134. """
  135. super().__init__()
  136. self.weight = nn.Parameter(torch.ones(hidden_size))
  137. self.variance_epsilon = eps
  138. def forward(self, hidden_states):
  139. input_dtype = hidden_states.dtype
  140. hidden_states = hidden_states.to(torch.float32)
  141. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  142. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  143. return self.weight * hidden_states.to(input_dtype)
  144. def extra_repr(self):
  145. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  146. # Copied from transformers.models.llama.modeling_llama.repeat_kv
  147. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  148. """
  149. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  150. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  151. """
  152. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  153. if n_rep == 1:
  154. return hidden_states
  155. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  156. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  157. class HybridMambaAttentionDynamicCache:
  158. """
  159. A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
  160. (which has a constant shape regardless of seq_len).
  161. This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
  162. and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor
  163. For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
  164. while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
  165. For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
  166. while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`,
  167. and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`.
  168. """
  169. is_compileable = False
  170. def __init__(self, config, batch_size, dtype=torch.float16, device=None):
  171. self.dtype = dtype
  172. self.layers_block_type = config.layers_block_type
  173. self.has_previous_state = False # only used by mamba
  174. intermediate_size = config.mamba_expand * config.hidden_size
  175. ssm_state_size = config.mamba_d_state
  176. conv_kernel_size = config.mamba_d_conv
  177. self.conv_states = []
  178. self.ssm_states = []
  179. self.transformer_layers = []
  180. for i in range(config.num_hidden_layers):
  181. if self.layers_block_type[i] == "mamba":
  182. self.conv_states += [
  183. torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
  184. ]
  185. self.ssm_states += [
  186. torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype)
  187. ]
  188. else:
  189. self.conv_states += [torch.tensor([[]] * batch_size, device=device)]
  190. self.ssm_states += [torch.tensor([[]] * batch_size, device=device)]
  191. self.transformer_layers.append(i)
  192. self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
  193. self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
  194. def update(
  195. self,
  196. key_states: torch.Tensor,
  197. value_states: torch.Tensor,
  198. layer_idx: int,
  199. cache_kwargs: Optional[dict[str, Any]] = None,
  200. ) -> tuple[torch.Tensor, torch.Tensor]:
  201. # Update the cache
  202. if self.key_cache[layer_idx].shape[-1] == 0:
  203. self.key_cache[layer_idx] = key_states
  204. self.value_cache[layer_idx] = value_states
  205. else:
  206. self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2)
  207. self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2)
  208. return self.key_cache[layer_idx], self.value_cache[layer_idx]
  209. def reorder_cache(self, beam_idx: torch.LongTensor):
  210. """Reorders the cache for beam search, given the selected beam indices."""
  211. for layer_idx in range(len(self.key_cache)):
  212. device = self.key_cache[layer_idx].device
  213. self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
  214. device = self.value_cache[layer_idx].device
  215. self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
  216. device = self.conv_states[layer_idx].device
  217. self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device))
  218. device = self.ssm_states[layer_idx].device
  219. self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device))
  220. def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
  221. """Returns the sequence length of the cached states. A layer index can be optionally passed."""
  222. # take any layer that contains cache and not empty tensor
  223. layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx
  224. if len(self.key_cache) <= layer_idx:
  225. return 0
  226. return self.key_cache[layer_idx].shape[-2]
  227. # Adapted from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Jamba
  228. class JambaAttention(nn.Module):
  229. """
  230. Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
  231. and "Generating Long Sequences with Sparse Transformers".
  232. """
  233. def __init__(self, config: JambaConfig, layer_idx: Optional[int] = None):
  234. super().__init__()
  235. self.config = config
  236. self.layer_idx = layer_idx
  237. if layer_idx is None:
  238. logger.warning_once(
  239. f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
  240. "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
  241. "when creating this class."
  242. )
  243. self.hidden_size = config.hidden_size
  244. self.num_heads = config.num_attention_heads
  245. self.head_dim = self.hidden_size // self.num_heads
  246. self.num_key_value_heads = config.num_key_value_heads
  247. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  248. self.is_causal = True
  249. self.attention_dropout = config.attention_dropout
  250. if (self.head_dim * self.num_heads) != self.hidden_size:
  251. raise ValueError(
  252. f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
  253. f" and `num_heads`: {self.num_heads})."
  254. )
  255. self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
  256. self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  257. self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  258. self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
  259. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  260. def forward(
  261. self,
  262. hidden_states: torch.Tensor,
  263. attention_mask: Optional[torch.Tensor] = None,
  264. position_ids: Optional[torch.LongTensor] = None,
  265. past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
  266. output_attentions: bool = False,
  267. use_cache: bool = False,
  268. cache_position: Optional[torch.LongTensor] = None,
  269. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  270. bsz, q_len, _ = hidden_states.size()
  271. query_states = self.q_proj(hidden_states)
  272. key_states = self.k_proj(hidden_states)
  273. value_states = self.v_proj(hidden_states)
  274. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  275. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  276. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  277. if past_key_values is not None:
  278. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  279. # repeat k/v heads if n_kv_heads < n_heads
  280. key_states = repeat_kv(key_states, self.num_key_value_groups)
  281. value_states = repeat_kv(value_states, self.num_key_value_groups)
  282. attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
  283. if attention_mask is not None: # no matter the length, we just slice it
  284. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  285. attn_weights = attn_weights + causal_mask
  286. # upcast attention to fp32
  287. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
  288. attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
  289. attn_output = torch.matmul(attn_weights, value_states)
  290. if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
  291. raise ValueError(
  292. f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
  293. f" {attn_output.size()}"
  294. )
  295. attn_output = attn_output.transpose(1, 2).contiguous()
  296. attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
  297. attn_output = self.o_proj(attn_output)
  298. if not output_attentions:
  299. attn_weights = None
  300. return attn_output, attn_weights, past_key_values
  301. # Adapted from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Jamba
  302. class JambaFlashAttention2(JambaAttention):
  303. """
  304. Jamba flash attention module. This module inherits from `JambaAttention` as the weights of the module stays
  305. untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
  306. flash attention and deal with padding tokens in case the input contains any of them.
  307. """
  308. def __init__(self, *args, **kwargs):
  309. super().__init__(*args, **kwargs)
  310. # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
  311. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
  312. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
  313. self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
  314. def forward(
  315. self,
  316. hidden_states: torch.Tensor,
  317. attention_mask: Optional[torch.Tensor] = None,
  318. position_ids: Optional[torch.LongTensor] = None,
  319. past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
  320. output_attentions: bool = False,
  321. use_cache: bool = False,
  322. cache_position: Optional[torch.LongTensor] = None,
  323. **kwargs,
  324. ):
  325. bsz, q_len, _ = hidden_states.size()
  326. query_states = self.q_proj(hidden_states)
  327. key_states = self.k_proj(hidden_states)
  328. value_states = self.v_proj(hidden_states)
  329. # Flash attention requires the input to have the shape
  330. # batch_size x seq_length x head_dim x hidden_dim
  331. # therefore we just need to keep the original shape
  332. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
  333. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  334. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  335. if past_key_values is not None:
  336. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  337. # repeat k/v heads if n_kv_heads < n_heads
  338. key_states = repeat_kv(key_states, self.num_key_value_groups)
  339. value_states = repeat_kv(value_states, self.num_key_value_groups)
  340. dropout_rate = 0.0 if not self.training else self.attention_dropout
  341. # In PEFT, usually we cast the layer norms in float32 for training stability reasons
  342. # therefore the input hidden states gets silently casted in float32. Hence, we need
  343. # cast them back in float16 just to be sure everything works as expected.
  344. input_dtype = query_states.dtype
  345. device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
  346. if input_dtype == torch.float32:
  347. if torch.is_autocast_enabled():
  348. target_dtype = (
  349. torch.get_autocast_dtype(device_type)
  350. if hasattr(torch, "get_autocast_dtype")
  351. else torch.get_autocast_gpu_dtype()
  352. )
  353. # Handle the case where the model is quantized
  354. elif hasattr(self.config, "_pre_quantization_dtype"):
  355. target_dtype = self.config._pre_quantization_dtype
  356. else:
  357. target_dtype = self.q_proj.weight.dtype
  358. logger.warning_once(
  359. f"The input hidden states seems to be silently casted in float32, this might be related to"
  360. f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
  361. f" {target_dtype}."
  362. )
  363. query_states = query_states.to(target_dtype)
  364. key_states = key_states.to(target_dtype)
  365. value_states = value_states.to(target_dtype)
  366. # Reashape to the expected shape for Flash Attention
  367. key_states = key_states.transpose(1, 2)
  368. value_states = value_states.transpose(1, 2)
  369. attn_output = _flash_attention_forward(
  370. query_states,
  371. key_states,
  372. value_states,
  373. attention_mask,
  374. q_len,
  375. dropout=dropout_rate,
  376. sliding_window=getattr(self.config, "sliding_window", None),
  377. is_causal=self.is_causal,
  378. use_top_left_mask=self._flash_attn_uses_top_left_mask,
  379. )
  380. attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
  381. attn_output = self.o_proj(attn_output)
  382. if not output_attentions:
  383. attn_weights = None
  384. return attn_output, attn_weights, past_key_values
  385. # Adapted from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Jamba
  386. class JambaSdpaAttention(JambaAttention):
  387. """
  388. Jamba attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
  389. `JambaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
  390. SDPA API.
  391. """
  392. # Adapted from JambaAttention.forward
  393. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  394. def forward(
  395. self,
  396. hidden_states: torch.Tensor,
  397. attention_mask: Optional[torch.Tensor] = None,
  398. position_ids: Optional[torch.LongTensor] = None,
  399. past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
  400. output_attentions: bool = False,
  401. use_cache: bool = False,
  402. cache_position: Optional[torch.LongTensor] = None,
  403. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  404. if output_attentions:
  405. # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
  406. logger.warning_once(
  407. "JambaModel is using JambaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
  408. 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
  409. )
  410. return super().forward(
  411. hidden_states=hidden_states,
  412. attention_mask=attention_mask,
  413. position_ids=position_ids,
  414. past_key_values=past_key_values,
  415. output_attentions=output_attentions,
  416. use_cache=use_cache,
  417. )
  418. bsz, q_len, _ = hidden_states.size()
  419. query_states = self.q_proj(hidden_states)
  420. key_states = self.k_proj(hidden_states)
  421. value_states = self.v_proj(hidden_states)
  422. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  423. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  424. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  425. if past_key_values is not None:
  426. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  427. key_states = repeat_kv(key_states, self.num_key_value_groups)
  428. value_states = repeat_kv(value_states, self.num_key_value_groups)
  429. causal_mask = attention_mask
  430. if attention_mask is not None:
  431. causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
  432. # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
  433. # Reference: https://github.com/pytorch/pytorch/issues/112577.
  434. if query_states.device.type == "cuda" and attention_mask is not None:
  435. query_states = query_states.contiguous()
  436. key_states = key_states.contiguous()
  437. value_states = value_states.contiguous()
  438. # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
  439. # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
  440. # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
  441. is_causal = self.is_causal and causal_mask is None and q_len > 1
  442. attn_output = torch.nn.functional.scaled_dot_product_attention(
  443. query_states,
  444. key_states,
  445. value_states,
  446. attn_mask=causal_mask,
  447. dropout_p=self.attention_dropout if self.training else 0.0,
  448. is_causal=is_causal,
  449. )
  450. attn_output = attn_output.transpose(1, 2).contiguous()
  451. attn_output = attn_output.view(bsz, q_len, self.hidden_size)
  452. attn_output = self.o_proj(attn_output)
  453. return attn_output, None, past_key_values
  454. JAMBA_ATTENTION_CLASSES = {
  455. "eager": JambaAttention,
  456. "flash_attention_2": JambaFlashAttention2,
  457. "sdpa": JambaSdpaAttention,
  458. }
  459. # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
  460. class JambaMambaMixer(nn.Module):
  461. """
  462. Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
  463. A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
  464. ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
  465. and is why Mamba is called **selective** state spaces)
  466. """
  467. def __init__(self, config: JambaConfig, layer_idx):
  468. super().__init__()
  469. self.config = config
  470. self.layer_idx = layer_idx
  471. self.hidden_size = config.hidden_size
  472. self.ssm_state_size = config.mamba_d_state
  473. self.conv_kernel_size = config.mamba_d_conv
  474. self.intermediate_size = config.mamba_expand * config.hidden_size
  475. self.time_step_rank = config.mamba_dt_rank
  476. self.use_conv_bias = config.mamba_conv_bias
  477. self.use_bias = config.mamba_proj_bias
  478. self.conv1d = nn.Conv1d(
  479. in_channels=self.intermediate_size,
  480. out_channels=self.intermediate_size,
  481. bias=self.use_conv_bias,
  482. kernel_size=self.conv_kernel_size,
  483. groups=self.intermediate_size,
  484. padding=self.conv_kernel_size - 1,
  485. )
  486. self.activation = config.hidden_act
  487. self.act = ACT2FN[config.hidden_act]
  488. self.use_fast_kernels = config.use_mamba_kernels
  489. # projection of the input hidden states
  490. self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=self.use_bias)
  491. # selective projection used to make dt, B and C input dependent
  492. self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
  493. # time step projection (discretization)
  494. self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
  495. # S4D real initialization. These are not discretized!
  496. # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
  497. A = torch.arange(1, self.ssm_state_size + 1)[None, :]
  498. A = A.expand(self.intermediate_size, -1).contiguous()
  499. self.A_log = nn.Parameter(torch.log(A))
  500. self.D = nn.Parameter(torch.ones(self.intermediate_size))
  501. self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)
  502. self.dt_layernorm = JambaRMSNorm(self.time_step_rank, eps=config.rms_norm_eps)
  503. self.b_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
  504. self.c_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
  505. if not is_fast_path_available:
  506. logger.warning_once(
  507. "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
  508. " is None. To install follow https://github.com/state-spaces/mamba/#installation and"
  509. " https://github.com/Dao-AILab/causal-conv1d. If you want to use the naive implementation, set `use_mamba_kernels=False` in the model config"
  510. )
  511. def cuda_kernels_forward(
  512. self,
  513. hidden_states: torch.Tensor,
  514. cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
  515. attention_mask: Optional[torch.LongTensor] = None,
  516. ):
  517. batch_size, seq_len, _ = hidden_states.shape
  518. use_precomputed_states = (
  519. cache_params is not None
  520. and cache_params.has_previous_state
  521. and seq_len == 1
  522. and cache_params.conv_states[self.layer_idx].shape[0]
  523. == cache_params.ssm_states[self.layer_idx].shape[0]
  524. == batch_size
  525. )
  526. # 1. Gated MLP's linear projection
  527. projected_states = self.in_proj(hidden_states).transpose(1, 2)
  528. # We can't use `mamba_inner_fn` even if in training and without cache params because we have the
  529. # inner layernorms which isn't supported by this fused kernel
  530. hidden_states, gate = projected_states.chunk(2, dim=1)
  531. if attention_mask is not None:
  532. hidden_states = hidden_states * attention_mask.unsqueeze(1)
  533. # 2. Convolution sequence transformation
  534. conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
  535. if use_precomputed_states:
  536. hidden_states = causal_conv1d_update(
  537. hidden_states.squeeze(-1),
  538. cache_params.conv_states[self.layer_idx],
  539. conv_weights,
  540. self.conv1d.bias,
  541. self.activation,
  542. )
  543. hidden_states = hidden_states.unsqueeze(-1)
  544. else:
  545. if cache_params is not None:
  546. conv_states = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0))
  547. cache_params.conv_states[self.layer_idx].copy_(conv_states)
  548. hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv1d.bias, activation=self.activation)
  549. if attention_mask is not None:
  550. hidden_states = hidden_states * attention_mask.unsqueeze(1)
  551. # 3. State Space Model sequence transformation
  552. # 3.a. input varying initialization of time_step, B and C
  553. ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
  554. time_step, B, C = torch.split(
  555. ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
  556. )
  557. time_step = self.dt_layernorm(time_step)
  558. B = self.b_layernorm(B)
  559. C = self.c_layernorm(C)
  560. # Here we need to apply dt_proj without the bias, as the bias is added in the selective scan kernel.
  561. # This is a hack to apply dt_proj while still using the forward pass of `torch.nn.Linear`, which is needed
  562. # in order to make quantization work. Quantization code replaces `torch.nn.Linear` layers with quantized
  563. # linear layers, and requires to call the forward pass directly.
  564. # Quantized model can't work with the original code:
  565. # ```discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)```
  566. time_proj_bias = self.dt_proj.bias.data
  567. with torch.no_grad():
  568. self.dt_proj.bias.data = torch.zeros_like(self.dt_proj.bias.data)
  569. discrete_time_step = self.dt_proj(time_step).transpose(1, 2)
  570. with torch.no_grad():
  571. self.dt_proj.bias.data = time_proj_bias
  572. A = -torch.exp(self.A_log.float())
  573. # 3.c perform the recurrence y ← SSM(A, B, C)(x)
  574. time_proj_bias = time_proj_bias.float() if time_proj_bias is not None else None
  575. if use_precomputed_states:
  576. scan_outputs = selective_state_update(
  577. cache_params.ssm_states[self.layer_idx],
  578. hidden_states[..., 0],
  579. discrete_time_step[..., 0],
  580. A,
  581. B[:, 0],
  582. C[:, 0],
  583. self.D,
  584. gate[..., 0],
  585. time_proj_bias,
  586. dt_softplus=True,
  587. ).unsqueeze(-1)
  588. else:
  589. scan_outputs, ssm_state = selective_scan_fn(
  590. hidden_states,
  591. discrete_time_step,
  592. A,
  593. B.transpose(1, 2),
  594. C.transpose(1, 2),
  595. self.D.float(),
  596. gate,
  597. time_proj_bias,
  598. delta_softplus=True,
  599. return_last_state=True,
  600. )
  601. if ssm_state is not None and cache_params is not None:
  602. cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
  603. # 4. Final linear projection
  604. contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
  605. return contextualized_states
  606. # fmt: off
  607. def slow_forward(self, input_states, cache_params: Optional[HybridMambaAttentionDynamicCache] = None, attention_mask: Optional[torch.LongTensor] = None):
  608. batch_size, seq_len, _ = input_states.shape
  609. dtype = input_states.dtype
  610. # 1. Gated MLP's linear projection
  611. projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
  612. hidden_states, gate = projected_states.chunk(2, dim=1)
  613. if attention_mask is not None:
  614. hidden_states = hidden_states * attention_mask.unsqueeze(1)
  615. use_cache = isinstance(cache_params, HybridMambaAttentionDynamicCache)
  616. # 2. Convolution sequence transformation
  617. if use_cache and cache_params.ssm_states[self.layer_idx].shape[0] == batch_size:
  618. if self.training:
  619. # In training mode, we don't want to perform in-place operations on ssm_state so we can compute the backwards pass
  620. ssm_state = cache_params.ssm_states[self.layer_idx].clone()
  621. else:
  622. ssm_state = cache_params.ssm_states[self.layer_idx]
  623. ssm_state = ssm_state.to(hidden_states.device)
  624. if cache_params.has_previous_state and seq_len == 1 and \
  625. cache_params.conv_states[self.layer_idx].shape[0] == batch_size:
  626. conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size]
  627. conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
  628. conv_state[:, :, -1] = hidden_states[:, :, 0]
  629. cache_params.conv_states[self.layer_idx] = conv_state
  630. hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
  631. if self.use_conv_bias:
  632. hidden_states += self.conv1d.bias
  633. hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1) # [batch, intermediate_size, 1] : decoding
  634. else:
  635. conv_state = nn.functional.pad(
  636. hidden_states,
  637. (self.conv_kernel_size - hidden_states.shape[-1], 0)
  638. )
  639. cache_params.conv_states[self.layer_idx] = conv_state
  640. hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
  641. else:
  642. ssm_state = torch.zeros(
  643. (batch_size, self.intermediate_size, self.ssm_state_size),
  644. device=hidden_states.device, dtype=dtype
  645. )
  646. hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
  647. if attention_mask is not None:
  648. hidden_states = hidden_states * attention_mask.unsqueeze(1)
  649. # 3. State Space Model sequence transformation
  650. # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
  651. ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
  652. time_step, B, C = torch.split(
  653. ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
  654. )
  655. time_step = self.dt_layernorm(time_step)
  656. B = self.b_layernorm(B)
  657. C = self.c_layernorm(C)
  658. discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size]
  659. discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) # [batch, intermediate_size, seq_len]
  660. # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
  661. A = -torch.exp(self.A_log.float()) # [intermediate_size, ssm_state_size]
  662. discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None]) # [batch, intermediate_size, seq_len, ssm_state_size]
  663. discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float() # [batch, intermediate_size, seq_len, ssm_state_size]
  664. deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
  665. # 3.c perform the recurrence y ← SSM(A, B, C)(x)
  666. scan_outputs = []
  667. for i in range(seq_len):
  668. ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediate_size, ssm_state]
  669. scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediate_size, 1]
  670. scan_outputs.append(scan_output[:, :, 0])
  671. scan_output = torch.stack(scan_outputs, dim=-1) # [batch, intermediate_size, seq_len]
  672. scan_output = scan_output + (hidden_states * self.D[None, :, None])
  673. scan_output = (scan_output * self.act(gate))
  674. if use_cache:
  675. cache_params.ssm_states[self.layer_idx] = ssm_state
  676. # 4. Final linear projection
  677. contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size]
  678. return contextualized_states
  679. # fmt: on
  680. def forward(
  681. self,
  682. hidden_states,
  683. cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
  684. attention_mask: Optional[torch.LongTensor] = None,
  685. ):
  686. if self.use_fast_kernels:
  687. if not is_fast_path_available or "cuda" not in self.x_proj.weight.device.type:
  688. raise ValueError(
  689. "Fast Mamba kernels are not available. Make sure to they are installed and that the mamba module is on a CUDA device"
  690. )
  691. return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask)
  692. return self.slow_forward(hidden_states, cache_params, attention_mask)
  693. # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Jamba
  694. class JambaMLP(nn.Module):
  695. def __init__(self, config):
  696. super().__init__()
  697. self.config = config
  698. self.hidden_size = config.hidden_size
  699. self.intermediate_size = config.intermediate_size
  700. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  701. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  702. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  703. self.act_fn = ACT2FN[config.hidden_act]
  704. def forward(self, x):
  705. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  706. return down_proj
  707. # Adapted from transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock with Mistral->Jamba
  708. class JambaSparseMoeBlock(nn.Module):
  709. """
  710. This implementation is
  711. strictly equivalent to standard MoE with full capacity (no
  712. dropped tokens). It's faster since it formulates MoE operations
  713. in terms of block-sparse operations to accommodate imbalanced
  714. assignments of tokens to experts, whereas standard MoE either
  715. (1) drop tokens at the cost of reduced performance or (2) set
  716. capacity factor to number of experts and thus waste computation
  717. and memory on padding.
  718. """
  719. def __init__(self, config: JambaConfig):
  720. super().__init__()
  721. self.hidden_dim = config.hidden_size
  722. self.ffn_dim = config.intermediate_size
  723. self.num_experts = config.num_experts
  724. self.top_k = config.num_experts_per_tok
  725. self.router = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
  726. self.experts = nn.ModuleList([JambaMLP(config) for _ in range(self.num_experts)])
  727. def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  728. """ """
  729. batch_size, sequence_length, hidden_dim = hidden_states.shape
  730. hidden_states = hidden_states.view(-1, hidden_dim)
  731. # router_logits: (batch * sequence_length, n_experts)
  732. router_logits = self.router(hidden_states)
  733. routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
  734. routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
  735. # we cast back to the input dtype
  736. routing_weights = routing_weights.to(hidden_states.dtype)
  737. final_hidden_states = torch.zeros(
  738. (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
  739. )
  740. # One hot encode the selected experts to create an expert mask
  741. # this will be used to easily index which expert is going to be sollicitated
  742. expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
  743. # Loop over all available experts in the model and perform the computation on each expert
  744. for expert_idx in range(self.num_experts):
  745. expert_layer = self.experts[expert_idx]
  746. idx, top_x = torch.where(expert_mask[expert_idx])
  747. if top_x.shape[0] == 0:
  748. continue
  749. # Index the correct hidden states and compute the expert hidden state for
  750. # the current expert. We need to make sure to multiply the output hidden
  751. # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
  752. current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
  753. current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
  754. # However `index_add_` only support torch tensors for indexing so we'll use
  755. # the `top_x` tensor here.
  756. final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
  757. final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
  758. return final_hidden_states, router_logits
  759. class JambaAttentionDecoderLayer(GradientCheckpointingLayer):
  760. def __init__(self, config: JambaConfig, layer_idx: int):
  761. super().__init__()
  762. num_experts = config.layers_num_experts[layer_idx]
  763. self.self_attn = JAMBA_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
  764. ffn_layer_class = JambaSparseMoeBlock if num_experts > 1 else JambaMLP
  765. self.feed_forward = ffn_layer_class(config)
  766. self.input_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  767. self.pre_ff_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  768. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  769. def forward(
  770. self,
  771. hidden_states: torch.Tensor,
  772. attention_mask: Optional[torch.Tensor] = None,
  773. position_ids: Optional[torch.LongTensor] = None,
  774. past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
  775. output_attentions: Optional[bool] = False,
  776. output_router_logits: Optional[bool] = False,
  777. use_cache: Optional[bool] = False,
  778. cache_position: Optional[torch.LongTensor] = None,
  779. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  780. """
  781. Args:
  782. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  783. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
  784. `(batch, sequence_length)` where padding elements are indicated by 0.
  785. past_key_values (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states
  786. output_attentions (`bool`, *optional*):
  787. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  788. returned tensors for more detail.
  789. output_router_logits (`bool`, *optional*):
  790. Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
  791. should not be returned during inference.
  792. use_cache (`bool`, *optional*):
  793. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  794. (see `past_key_values`).
  795. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  796. Indices depicting the position of the input sequence tokens in the sequence.
  797. """
  798. residual = hidden_states
  799. hidden_states = self.input_layernorm(hidden_states)
  800. hidden_states, self_attn_weights, present_key_value = self.self_attn(
  801. hidden_states=hidden_states,
  802. attention_mask=attention_mask,
  803. position_ids=position_ids,
  804. past_key_values=past_key_values,
  805. output_attentions=output_attentions,
  806. use_cache=use_cache,
  807. cache_position=cache_position,
  808. )
  809. # residual connection after attention
  810. hidden_states = residual + hidden_states
  811. # feed-forward (experts/MLP)
  812. residual = hidden_states
  813. hidden_states = self.pre_ff_layernorm(hidden_states)
  814. ff_outputs = self.feed_forward(hidden_states)
  815. if isinstance(ff_outputs, tuple):
  816. hidden_states, router_logits = ff_outputs
  817. else:
  818. hidden_states, router_logits = ff_outputs, None
  819. hidden_states = residual + hidden_states
  820. outputs = (hidden_states,)
  821. if output_attentions:
  822. outputs += (self_attn_weights,)
  823. if use_cache:
  824. outputs += (present_key_value,)
  825. if output_router_logits:
  826. outputs += (router_logits,)
  827. return outputs
  828. class JambaMambaDecoderLayer(GradientCheckpointingLayer):
  829. def __init__(self, config: JambaConfig, layer_idx: int):
  830. super().__init__()
  831. num_experts = config.layers_num_experts[layer_idx]
  832. self.mamba = JambaMambaMixer(config=config, layer_idx=layer_idx)
  833. ffn_layer_class = JambaSparseMoeBlock if num_experts > 1 else JambaMLP
  834. self.feed_forward = ffn_layer_class(config)
  835. self.input_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  836. self.pre_ff_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  837. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  838. def forward(
  839. self,
  840. hidden_states: torch.Tensor,
  841. attention_mask: Optional[torch.Tensor] = None,
  842. position_ids: Optional[torch.LongTensor] = None,
  843. past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
  844. output_attentions: Optional[bool] = False,
  845. output_router_logits: Optional[bool] = False,
  846. use_cache: Optional[bool] = False,
  847. cache_position: Optional[torch.LongTensor] = None,
  848. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  849. """
  850. Args:
  851. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  852. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
  853. `(batch, sequence_length)` where padding elements are indicated by 0.
  854. past_key_values (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states
  855. output_attentions (`bool`, *optional*):
  856. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  857. returned tensors for more detail.
  858. output_router_logits (`bool`, *optional*):
  859. Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
  860. should not be returned during inference.
  861. use_cache (`bool`, *optional*):
  862. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  863. (see `past_key_values`).
  864. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  865. Indices depicting the position of the input sequence tokens in the sequence.
  866. """
  867. residual = hidden_states
  868. hidden_states = self.input_layernorm(hidden_states)
  869. hidden_states = self.mamba(
  870. hidden_states=hidden_states,
  871. cache_params=past_key_values,
  872. attention_mask=attention_mask,
  873. )
  874. self_attn_weights = None
  875. # residual connection after mamba
  876. hidden_states = residual + hidden_states
  877. # feed-forward (experts/MLP)
  878. residual = hidden_states
  879. hidden_states = self.pre_ff_layernorm(hidden_states)
  880. ff_outputs = self.feed_forward(hidden_states)
  881. if isinstance(ff_outputs, tuple):
  882. hidden_states, router_logits = ff_outputs
  883. else:
  884. hidden_states, router_logits = ff_outputs, None
  885. hidden_states = residual + hidden_states
  886. outputs = (hidden_states,)
  887. if output_attentions:
  888. outputs += (self_attn_weights,)
  889. if use_cache:
  890. outputs += (past_key_values,)
  891. if output_router_logits:
  892. outputs += (router_logits,)
  893. return outputs
  894. @auto_docstring
  895. class JambaPreTrainedModel(PreTrainedModel):
  896. config: JambaConfig
  897. base_model_prefix = "model"
  898. supports_gradient_checkpointing = True
  899. _no_split_modules = ["JambaAttentionDecoderLayer", "JambaMambaDecoderLayer"]
  900. _skip_keys_device_placement = "past_key_values"
  901. _supports_flash_attn = True
  902. _supports_sdpa = True
  903. # Note: only supports HybridMambaAttentionDynamicCache
  904. _is_stateful = True
  905. def _init_weights(self, module):
  906. std = self.config.initializer_range
  907. if isinstance(module, (nn.Linear, nn.Conv1d)):
  908. module.weight.data.normal_(mean=0.0, std=std)
  909. if module.bias is not None:
  910. module.bias.data.zero_()
  911. elif isinstance(module, nn.Embedding):
  912. module.weight.data.normal_(mean=0.0, std=std)
  913. if module.padding_idx is not None:
  914. module.weight.data[module.padding_idx].zero_()
  915. elif isinstance(module, JambaRMSNorm):
  916. module.weight.data.fill_(1.0)
  917. elif isinstance(module, JambaMambaMixer):
  918. A = torch.arange(1, module.ssm_state_size + 1)[None, :]
  919. A = A.expand(module.intermediate_size, -1).contiguous()
  920. module.A_log.data.copy_(torch.log(A))
  921. module.D.data.fill_(1.0)
  922. ALL_DECODER_LAYER_TYPES = {"attention": JambaAttentionDecoderLayer, "mamba": JambaMambaDecoderLayer}
  923. # Adapted from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->JAMBA, Mistral->Jamba
  924. @auto_docstring
  925. class JambaModel(JambaPreTrainedModel):
  926. """
  927. Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`JambaDecoderLayer`]
  928. Args:
  929. config: JambaConfig
  930. """
  931. def __init__(self, config: JambaConfig):
  932. super().__init__(config)
  933. self.padding_idx = config.pad_token_id
  934. self.vocab_size = config.vocab_size
  935. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  936. decoder_layers = []
  937. for i in range(config.num_hidden_layers):
  938. layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]]
  939. decoder_layers.append(layer_class(config, layer_idx=i))
  940. self.layers = nn.ModuleList(decoder_layers)
  941. self._attn_implementation = config._attn_implementation
  942. self.final_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  943. self.gradient_checkpointing = False
  944. # Initialize weights and apply final processing
  945. self.post_init()
  946. @can_return_tuple
  947. @auto_docstring
  948. def forward(
  949. self,
  950. input_ids: Optional[torch.LongTensor] = None,
  951. attention_mask: Optional[torch.Tensor] = None,
  952. position_ids: Optional[torch.LongTensor] = None,
  953. past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
  954. inputs_embeds: Optional[torch.FloatTensor] = None,
  955. use_cache: Optional[bool] = None,
  956. output_attentions: Optional[bool] = None,
  957. output_hidden_states: Optional[bool] = None,
  958. output_router_logits: Optional[bool] = None,
  959. cache_position: Optional[torch.LongTensor] = None,
  960. **kwargs: Unpack[TransformersKwargs],
  961. ) -> MoeModelOutputWithPast:
  962. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  963. output_router_logits = (
  964. output_router_logits if output_router_logits is not None else self.config.output_router_logits
  965. )
  966. output_hidden_states = (
  967. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  968. )
  969. use_cache = use_cache if use_cache is not None else self.config.use_cache
  970. if (input_ids is None) ^ (inputs_embeds is not None):
  971. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  972. if self.gradient_checkpointing and self.training and use_cache:
  973. logger.warning_once(
  974. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
  975. )
  976. use_cache = False
  977. if inputs_embeds is None:
  978. inputs_embeds = self.embed_tokens(input_ids)
  979. hidden_states = inputs_embeds
  980. if use_cache and past_key_values is None:
  981. logger.warning_once(
  982. "Jamba requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was "
  983. "provided, so no cache will be returned."
  984. )
  985. if cache_position is None:
  986. cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device)
  987. if position_ids is None:
  988. position_ids = cache_position.unsqueeze(0)
  989. causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
  990. mamba_mask = self._update_mamba_mask(attention_mask, cache_position)
  991. all_hidden_states = () if output_hidden_states else None
  992. all_self_attns = () if output_attentions else None
  993. all_router_logits = () if output_router_logits else None
  994. for decoder_layer in self.layers:
  995. # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention)
  996. layer_mask = mamba_mask if isinstance(decoder_layer, JambaMambaDecoderLayer) else causal_mask
  997. if output_hidden_states:
  998. all_hidden_states += (hidden_states,)
  999. layer_outputs = decoder_layer(
  1000. hidden_states,
  1001. attention_mask=layer_mask,
  1002. position_ids=position_ids,
  1003. past_key_values=past_key_values,
  1004. output_attentions=output_attentions,
  1005. output_router_logits=output_router_logits,
  1006. use_cache=use_cache,
  1007. cache_position=cache_position,
  1008. )
  1009. hidden_states = layer_outputs[0]
  1010. if output_attentions:
  1011. if layer_outputs[1] is not None:
  1012. # append attentions only of attention layers. Mamba layers return `None` as the attention weights
  1013. all_self_attns += (layer_outputs[1],)
  1014. if output_router_logits:
  1015. if layer_outputs[-1] is not None:
  1016. # append router logits only of expert layers. Regular MLP layers return `None` as the router logits
  1017. all_router_logits += (layer_outputs[-1],)
  1018. hidden_states = self.final_layernorm(hidden_states)
  1019. # add hidden states from the last decoder layer
  1020. if output_hidden_states:
  1021. all_hidden_states += (hidden_states,)
  1022. if past_key_values and not past_key_values.has_previous_state:
  1023. past_key_values.has_previous_state = True
  1024. next_cache = None if not use_cache else past_key_values
  1025. return MoeModelOutputWithPast(
  1026. last_hidden_state=hidden_states,
  1027. past_key_values=next_cache,
  1028. hidden_states=all_hidden_states,
  1029. attentions=all_self_attns,
  1030. router_logits=all_router_logits,
  1031. )
  1032. def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
  1033. if self.config._attn_implementation == "flash_attention_2":
  1034. if attention_mask is not None and 0.0 in attention_mask:
  1035. return attention_mask
  1036. return None
  1037. dtype, device = input_tensor.dtype, input_tensor.device
  1038. min_dtype = torch.finfo(dtype).min
  1039. sequence_length = input_tensor.shape[1]
  1040. target_length = cache_position[-1] + 1
  1041. causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
  1042. if sequence_length != 1:
  1043. causal_mask = torch.triu(causal_mask, diagonal=1)
  1044. causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
  1045. causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
  1046. if attention_mask is not None:
  1047. causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
  1048. if attention_mask.dim() == 2:
  1049. mask_length = attention_mask.shape[-1]
  1050. padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
  1051. causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
  1052. if (
  1053. self.config._attn_implementation == "sdpa"
  1054. and attention_mask is not None
  1055. and attention_mask.device.type in ["cuda", "xpu", "npu"]
  1056. ):
  1057. # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
  1058. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
  1059. # Details: https://github.com/pytorch/pytorch/issues/110213
  1060. causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
  1061. return causal_mask
  1062. def _update_mamba_mask(self, attention_mask, cache_position):
  1063. """
  1064. No need for zeroing states when
  1065. 1. Cached forward
  1066. 2. Attending to all inputs
  1067. """
  1068. mamba_mask = attention_mask
  1069. if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
  1070. mamba_mask = None
  1071. return mamba_mask
  1072. # Adapted from transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM with MIXTRAL->JAMBA, Mixtral->Jamba
  1073. class JambaForCausalLM(JambaPreTrainedModel, GenerationMixin):
  1074. _tied_weights_keys = ["lm_head.weight"]
  1075. def __init__(self, config: JambaConfig):
  1076. super().__init__(config)
  1077. self.model = JambaModel(config)
  1078. self.vocab_size = config.vocab_size
  1079. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  1080. self.router_aux_loss_coef = config.router_aux_loss_coef
  1081. self.num_experts = config.num_experts
  1082. self.num_experts_per_tok = config.num_experts_per_tok
  1083. # Initialize weights and apply final processing
  1084. self.post_init()
  1085. @can_return_tuple
  1086. @auto_docstring
  1087. def forward(
  1088. self,
  1089. input_ids: Optional[torch.LongTensor] = None,
  1090. attention_mask: Optional[torch.Tensor] = None,
  1091. position_ids: Optional[torch.LongTensor] = None,
  1092. past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
  1093. inputs_embeds: Optional[torch.FloatTensor] = None,
  1094. labels: Optional[torch.LongTensor] = None,
  1095. use_cache: Optional[bool] = None,
  1096. output_attentions: Optional[bool] = None,
  1097. output_hidden_states: Optional[bool] = None,
  1098. output_router_logits: Optional[bool] = None,
  1099. cache_position: Optional[torch.LongTensor] = None,
  1100. logits_to_keep: Union[int, torch.Tensor] = 0,
  1101. **kwargs: Unpack[TransformersKwargs],
  1102. ) -> MoeCausalLMOutputWithPast:
  1103. r"""
  1104. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1105. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1106. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  1107. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  1108. Example:
  1109. ```python
  1110. >>> from transformers import AutoTokenizer, JambaForCausalLM
  1111. >>> model = JambaForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
  1112. >>> tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
  1113. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  1114. >>> inputs = tokenizer(prompt, return_tensors="pt")
  1115. >>> # Generate
  1116. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  1117. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  1118. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  1119. ```"""
  1120. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1121. output_router_logits = (
  1122. output_router_logits if output_router_logits is not None else self.config.output_router_logits
  1123. )
  1124. output_hidden_states = (
  1125. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1126. )
  1127. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  1128. outputs: MoeModelOutputWithPast = self.model(
  1129. input_ids=input_ids,
  1130. attention_mask=attention_mask,
  1131. position_ids=position_ids,
  1132. past_key_values=past_key_values,
  1133. inputs_embeds=inputs_embeds,
  1134. use_cache=use_cache,
  1135. output_attentions=output_attentions,
  1136. output_hidden_states=output_hidden_states,
  1137. output_router_logits=output_router_logits,
  1138. cache_position=cache_position,
  1139. )
  1140. hidden_states = outputs.last_hidden_state
  1141. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  1142. logits = self.lm_head(hidden_states[:, slice_indices, :])
  1143. loss = None
  1144. if labels is not None:
  1145. loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
  1146. aux_loss = None
  1147. if output_router_logits:
  1148. aux_loss = load_balancing_loss_func(
  1149. outputs.router_logits,
  1150. self.num_experts,
  1151. self.num_experts_per_tok,
  1152. attention_mask,
  1153. )
  1154. if labels is not None:
  1155. loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
  1156. return MoeCausalLMOutputWithPast(
  1157. loss=loss,
  1158. aux_loss=aux_loss,
  1159. logits=logits,
  1160. past_key_values=outputs.past_key_values,
  1161. hidden_states=outputs.hidden_states,
  1162. attentions=outputs.attentions,
  1163. router_logits=outputs.router_logits,
  1164. )
  1165. def prepare_inputs_for_generation(
  1166. self,
  1167. input_ids,
  1168. past_key_values=None,
  1169. attention_mask=None,
  1170. inputs_embeds=None,
  1171. output_router_logits=False,
  1172. cache_position=None,
  1173. position_ids=None,
  1174. use_cache=True,
  1175. **kwargs,
  1176. ):
  1177. # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache`
  1178. empty_past_kv = past_key_values is None
  1179. # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
  1180. # Exception 1: when passing input_embeds, input_ids may be missing entries
  1181. # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
  1182. # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
  1183. # (we can't check exception 3 while compiling)
  1184. if not empty_past_kv:
  1185. if (
  1186. inputs_embeds is not None # Exception 1
  1187. or cache_position[-1] >= input_ids.shape[1] # Exception 3
  1188. ):
  1189. input_ids = input_ids[:, -cache_position.shape[0] :]
  1190. elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
  1191. input_ids = input_ids[:, cache_position]
  1192. else:
  1193. past_key_values = HybridMambaAttentionDynamicCache(
  1194. self.config, input_ids.shape[0], self.dtype, device=self.device
  1195. )
  1196. if attention_mask is not None and position_ids is None:
  1197. # create position_ids on the fly for batch generation
  1198. position_ids = attention_mask.long().cumsum(-1) - 1
  1199. position_ids.masked_fill_(attention_mask == 0, 1)
  1200. if not empty_past_kv:
  1201. position_ids = position_ids[:, -input_ids.shape[1] :]
  1202. # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
  1203. if inputs_embeds is not None and empty_past_kv:
  1204. model_inputs = {"inputs_embeds": inputs_embeds}
  1205. else:
  1206. model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
  1207. model_inputs.update(
  1208. {
  1209. "position_ids": position_ids,
  1210. "past_key_values": past_key_values,
  1211. "use_cache": use_cache,
  1212. "attention_mask": attention_mask,
  1213. "output_router_logits": output_router_logits,
  1214. "logits_to_keep": self.config.num_logits_to_keep,
  1215. "cache_position": cache_position,
  1216. }
  1217. )
  1218. # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
  1219. for key, value in kwargs.items():
  1220. if key not in model_inputs:
  1221. model_inputs[key] = value
  1222. return model_inputs
  1223. class JambaForSequenceClassification(GenericForSequenceClassification, JambaPreTrainedModel): ...
  1224. __all__ = ["JambaForCausalLM", "JambaForSequenceClassification", "JambaModel", "JambaPreTrainedModel"]