modeling_bamba.py 68 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/bamba/modular_bamba.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_bamba.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
  9. #
  10. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  11. # and OPT implementations in this library. It has been modified from its
  12. # original forms to accommodate minor architectural differences compared
  13. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  14. #
  15. # Licensed under the Apache License, Version 2.0 (the "License");
  16. # you may not use this file except in compliance with the License.
  17. # You may obtain a copy of the License at
  18. #
  19. # http://www.apache.org/licenses/LICENSE-2.0
  20. #
  21. # Unless required by applicable law or agreed to in writing, software
  22. # distributed under the License is distributed on an "AS IS" BASIS,
  23. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  24. # See the License for the specific language governing permissions and
  25. # limitations under the License.
  26. from typing import Any, Callable, Optional, TypedDict, Union
  27. import torch
  28. from torch import nn
  29. from transformers.activations import ACT2FN
  30. from ...cache_utils import Cache
  31. from ...generation import GenerationMixin
  32. from ...integrations import use_kernel_forward_from_hub
  33. from ...modeling_attn_mask_utils import AttentionMaskConverter
  34. from ...modeling_layers import GradientCheckpointingLayer
  35. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  36. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  37. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  38. from ...processing_utils import Unpack
  39. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  40. from ...utils.deprecation import deprecate_kwarg
  41. from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
  42. from .configuration_bamba import BambaConfig
  43. if is_mamba_2_ssm_available():
  44. from mamba_ssm.ops.triton.selective_state_update import selective_state_update
  45. from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
  46. else:
  47. selective_state_update = None
  48. if is_causal_conv1d_available():
  49. from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
  50. else:
  51. causal_conv1d_update, causal_conv1d_fn = None, None
  52. logger = logging.get_logger(__name__)
  53. class BambaFlashAttentionKwargs(TypedDict, total=False):
  54. """
  55. Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
  56. Use cases include padding-free training and fewer `torch.compile` graph breaks.
  57. Attributes:
  58. cu_seq_lens_q (`torch.LongTensor`)
  59. Gets cumulative sequence length for query state.
  60. cu_seq_lens_k (`torch.LongTensor`)
  61. Gets cumulative sequence length for key state.
  62. max_length_q (`int`):
  63. Maximum sequence length for query state.
  64. max_length_k (`int`):
  65. Maximum sequence length for key state.
  66. seq_idx (`torch.IntTensor):
  67. Index of each packed sequence.
  68. """
  69. cu_seq_lens_q: torch.LongTensor
  70. cu_seq_lens_k: torch.LongTensor
  71. max_length_q: int
  72. max_length_k: int
  73. seq_idx: torch.IntTensor
  74. class HybridMambaAttentionDynamicCache:
  75. """
  76. A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
  77. (which has a constant shape regardless of seq_len).
  78. This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
  79. and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor
  80. For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
  81. while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
  82. For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
  83. while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`,
  84. and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`.
  85. """
  86. is_compileable = False
  87. def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device=None):
  88. self.layers_block_type = config.layers_block_type
  89. self.has_previous_state = False # only used by mamba
  90. conv_kernel_size = config.mamba_d_conv
  91. ssm_state_size = config.mamba_d_state
  92. self.conv_states = []
  93. self.ssm_states = []
  94. self.transformer_layers = []
  95. for i in range(config.num_hidden_layers):
  96. if self.layers_block_type[i] == "mamba":
  97. self.conv_states += [
  98. torch.zeros(
  99. batch_size,
  100. (config.mamba_expand * config.hidden_size + 2 * config.mamba_n_groups * ssm_state_size),
  101. conv_kernel_size,
  102. device=device,
  103. dtype=dtype,
  104. )
  105. ]
  106. self.ssm_states += [
  107. torch.zeros(
  108. batch_size,
  109. config.mamba_n_heads,
  110. config.mamba_d_head,
  111. ssm_state_size,
  112. device=device,
  113. dtype=dtype,
  114. )
  115. ]
  116. else:
  117. self.conv_states += [torch.tensor([[]] * batch_size, device=device)]
  118. self.ssm_states += [torch.tensor([[]] * batch_size, device=device)]
  119. self.transformer_layers.append(i)
  120. self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
  121. self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
  122. def update(
  123. self,
  124. key_states: torch.Tensor,
  125. value_states: torch.Tensor,
  126. layer_idx: int,
  127. cache_kwargs: Optional[dict[str, Any]] = None,
  128. ) -> tuple[torch.Tensor, torch.Tensor]:
  129. # Update the cache
  130. if self.key_cache[layer_idx].shape[-1] == 0:
  131. self.key_cache[layer_idx] = key_states
  132. self.value_cache[layer_idx] = value_states
  133. else:
  134. self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2)
  135. self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2)
  136. return self.key_cache[layer_idx], self.value_cache[layer_idx]
  137. def reorder_cache(self, beam_idx: torch.LongTensor):
  138. """Reorders the cache for beam search, given the selected beam indices."""
  139. for layer_idx in range(len(self.key_cache)):
  140. device = self.key_cache[layer_idx].device
  141. self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
  142. device = self.value_cache[layer_idx].device
  143. self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
  144. device = self.conv_states[layer_idx].device
  145. self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device))
  146. device = self.ssm_states[layer_idx].device
  147. self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device))
  148. def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
  149. """Returns the sequence length of the cached states. A layer index can be optionally passed."""
  150. # take any layer that contains cache and not empty tensor
  151. layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx
  152. if len(self.key_cache) <= layer_idx:
  153. return 0
  154. return self.key_cache[layer_idx].shape[-2]
  155. class BambaRotaryEmbedding(nn.Module):
  156. inv_freq: torch.Tensor # fix linting for `register_buffer`
  157. def __init__(self, config: BambaConfig, device=None):
  158. super().__init__()
  159. # BC: "rope_type" was originally "type"
  160. if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
  161. self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
  162. else:
  163. self.rope_type = "default"
  164. self.max_seq_len_cached = config.max_position_embeddings
  165. self.original_max_seq_len = config.max_position_embeddings
  166. self.config = config
  167. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  168. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
  169. self.register_buffer("inv_freq", inv_freq, persistent=False)
  170. self.original_inv_freq = self.inv_freq
  171. @torch.no_grad()
  172. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  173. def forward(self, x, position_ids):
  174. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  175. position_ids_expanded = position_ids[:, None, :].float()
  176. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  177. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  178. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  179. emb = torch.cat((freqs, freqs), dim=-1)
  180. cos = emb.cos() * self.attention_scaling
  181. sin = emb.sin() * self.attention_scaling
  182. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  183. def rotate_half(x):
  184. """Rotates half the hidden dims of the input."""
  185. x1 = x[..., : x.shape[-1] // 2]
  186. x2 = x[..., x.shape[-1] // 2 :]
  187. return torch.cat((-x2, x1), dim=-1)
  188. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  189. """
  190. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  191. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  192. """
  193. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  194. if n_rep == 1:
  195. return hidden_states
  196. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  197. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  198. def eager_attention_forward(
  199. module: nn.Module,
  200. query: torch.Tensor,
  201. key: torch.Tensor,
  202. value: torch.Tensor,
  203. attention_mask: Optional[torch.Tensor],
  204. scaling: float,
  205. dropout: float = 0.0,
  206. **kwargs: Unpack[TransformersKwargs],
  207. ):
  208. key_states = repeat_kv(key, module.num_key_value_groups)
  209. value_states = repeat_kv(value, module.num_key_value_groups)
  210. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  211. if attention_mask is not None:
  212. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  213. attn_weights = attn_weights + causal_mask
  214. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  215. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  216. attn_output = torch.matmul(attn_weights, value_states)
  217. attn_output = attn_output.transpose(1, 2).contiguous()
  218. return attn_output, attn_weights
  219. # Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb
  220. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  221. """Applies Rotary Position Embedding to the query and key tensors.
  222. Removes the interleaving of cos and sin from GLM
  223. Args:
  224. q (`torch.Tensor`): The query tensor.
  225. k (`torch.Tensor`): The key tensor.
  226. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  227. sin (`torch.Tensor`): The sine part of the rotary embedding.
  228. position_ids (`torch.Tensor`, *optional*):
  229. Deprecated and unused.
  230. unsqueeze_dim (`int`, *optional*, defaults to 1):
  231. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  232. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  233. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  234. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  235. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  236. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  237. Returns:
  238. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  239. """
  240. cos = cos.unsqueeze(unsqueeze_dim)
  241. sin = sin.unsqueeze(unsqueeze_dim)
  242. # Keep half or full tensor for later concatenation
  243. rotary_dim = cos.shape[-1]
  244. q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
  245. k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
  246. # Apply rotary embeddings on the first half or full tensor
  247. q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
  248. k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
  249. # Concatenate back to full shape
  250. q_embed = torch.cat([q_embed, q_pass], dim=-1)
  251. k_embed = torch.cat([k_embed, k_pass], dim=-1)
  252. return q_embed, k_embed
  253. class BambaAttention(nn.Module):
  254. """Multi-headed attention from 'Attention Is All You Need' paper"""
  255. def __init__(self, config: BambaConfig, layer_idx: int):
  256. super().__init__()
  257. self.config = config
  258. self.layer_idx = layer_idx
  259. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  260. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  261. self.scaling = self.head_dim**-0.5
  262. self.attention_dropout = config.attention_dropout
  263. self.is_causal = True
  264. self.q_proj = nn.Linear(
  265. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  266. )
  267. self.k_proj = nn.Linear(
  268. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  269. )
  270. self.v_proj = nn.Linear(
  271. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  272. )
  273. self.o_proj = nn.Linear(
  274. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  275. )
  276. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  277. def forward(
  278. self,
  279. hidden_states: torch.Tensor,
  280. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  281. attention_mask: Optional[torch.Tensor],
  282. past_key_values: Optional[Cache] = None,
  283. cache_position: Optional[torch.LongTensor] = None,
  284. **kwargs: Unpack[TransformersKwargs],
  285. ) -> tuple[torch.Tensor, torch.Tensor]:
  286. input_shape = hidden_states.shape[:-1]
  287. hidden_shape = (*input_shape, -1, self.head_dim)
  288. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  289. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  290. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  291. cos, sin = position_embeddings
  292. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  293. if past_key_values is not None:
  294. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  295. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  296. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  297. attention_interface: Callable = eager_attention_forward
  298. if self.config._attn_implementation != "eager":
  299. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  300. attn_output, attn_weights = attention_interface(
  301. self,
  302. query_states,
  303. key_states,
  304. value_states,
  305. attention_mask,
  306. dropout=0.0 if not self.training else self.attention_dropout,
  307. scaling=self.scaling,
  308. **kwargs,
  309. )
  310. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  311. attn_output = self.o_proj(attn_output)
  312. return attn_output, attn_weights
  313. class BambaRMSNormGated(torch.nn.Module):
  314. def __init__(self, hidden_size, eps=1e-6):
  315. super().__init__()
  316. self.weight = nn.Parameter(torch.ones(hidden_size))
  317. self.variance_epsilon = eps
  318. def forward(self, hidden_states, gate=None):
  319. input_dtype = hidden_states.dtype
  320. hidden_states = hidden_states.to(torch.float32)
  321. if gate is not None:
  322. hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
  323. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  324. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  325. return self.weight * hidden_states.to(input_dtype)
  326. # Helper methods for segment sum computation
  327. def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int):
  328. """
  329. Padding x tensor with `pad_size` on the seq_len dim (dim=1)
  330. Assumes that we only have tensors of either size 4 or 3
  331. """
  332. pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0)
  333. return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0)
  334. def reshape_into_chunks(input_tensor, pad_size, chunk_size):
  335. """
  336. Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and
  337. simultaneously splitting it into chunk sequences.
  338. Assumes that we only have tensors of either size 4 or 3
  339. """
  340. # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...]
  341. input_tensor = pad_tensor_by_size(input_tensor, pad_size)
  342. if len(input_tensor.shape) == 3:
  343. # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads]
  344. return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2])
  345. else:
  346. # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size]
  347. return input_tensor.reshape(
  348. input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3]
  349. )
  350. def segment_sum(input_tensor):
  351. """
  352. More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions.
  353. """
  354. chunk_size = input_tensor.size(-1)
  355. # 1. expand input tensor to have an additional dimension and repeat along that dimension
  356. # [..., chunk_size] -> [..., chunk_size, chunk_size]
  357. input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size)
  358. # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag
  359. mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1)
  360. input_tensor = input_tensor.masked_fill(~mask, 0)
  361. # 3. compute actual cumsum
  362. tensor_segsum = torch.cumsum(input_tensor, dim=-2)
  363. # 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time)
  364. mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0)
  365. tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf)
  366. return tensor_segsum
  367. is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
  368. def apply_mask_to_padding_states(hidden_states, attention_mask):
  369. """
  370. Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
  371. """
  372. if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
  373. dtype = hidden_states.dtype
  374. hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
  375. return hidden_states
  376. # Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer
  377. class BambaMixer(nn.Module):
  378. """
  379. Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
  380. A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
  381. ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
  382. and is why Mamba is called **selective** state spaces)
  383. The are a few differences between this and Mamba2Mixer:
  384. - The variable use_precomputed_states is slightly different due to the hybrid cache structure
  385. - There's a few non-obvious bugs fixed with batching in the slow path that exist in main
  386. - Some extra variables that our layer doesn't need have been removed
  387. - We ported most of the refactors in https://github.com/huggingface/transformers/pull/35154, which is (as of Dec 18, 2024) unmerged
  388. """
    def __init__(self, config: BambaConfig, layer_idx: int):
        """Build the mixer's projections, depthwise conv, SSM parameters and gated norm from `config`.

        Args:
            config (`BambaConfig`): model configuration; the `mamba_*` fields size the mixer.
            layer_idx (`int`): index of this layer; used to address this layer's entries in the hybrid cache
                (`cache_params.conv_states[self.layer_idx]` / `ssm_states[self.layer_idx]`).
        """
        super().__init__()
        self.num_heads = config.mamba_n_heads
        self.hidden_size = config.hidden_size
        self.ssm_state_size = config.mamba_d_state
        self.conv_kernel_size = config.mamba_d_conv
        # Width of the expanded inner stream (also the width normalised by `self.norm`).
        self.intermediate_size = int(config.mamba_expand * self.hidden_size)
        self.layer_idx = layer_idx
        self.use_conv_bias = config.mamba_conv_bias
        self.activation = config.hidden_act
        self.act = ACT2FN[config.hidden_act]
        self.use_bias = config.mamba_proj_bias
        self.layer_norm_epsilon = config.rms_norm_eps
        self.n_groups = config.mamba_n_groups
        self.head_dim = config.mamba_d_head
        self.chunk_size = config.mamba_chunk_size

        # FIXME:
        # NOTE(review): the fast path only forwards `dt_limit` to the kernel when this is not (0.0, inf), so this
        # default presumably means "no clamping". time_step_min/max are not read in the visible code — confirm use.
        self.time_step_limit = (0.0, float("inf"))
        self.time_step_min = 0.001
        self.time_step_max = 0.1

        # Channels fed through the conv: x (intermediate_size) plus B and C (n_groups * ssm_state_size each).
        self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
        self.conv1d = nn.Conv1d(
            in_channels=self.conv_dim,
            out_channels=self.conv_dim,
            bias=config.mamba_conv_bias,
            kernel_size=self.conv_kernel_size,
            groups=self.conv_dim,  # depthwise: one filter per channel
            # kernel_size - 1 padding on both sides; presumably the slow path trims the extra outputs — TODO confirm
            padding=self.conv_kernel_size - 1,
        )

        # projection of the input hidden states
        # The output is split as [gate (intermediate_size), xBC (conv_dim), dt (num_heads)] in the forward pass.
        projection_size = self.intermediate_size + self.conv_dim + self.num_heads
        self.in_proj = nn.Linear(
            self.hidden_size,
            projection_size,
            bias=self.use_bias,
        )
        # selective projection used to make dt, B and C input dependent

        # time step projection (discretization)
        # instantiate once and copy inv_dt in init_weights of PretrainedModel
        self.dt_bias = nn.Parameter(torch.ones(self.num_heads))

        # S4D real initialization. These are not discretized!
        # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
        # The forward uses A = -exp(A_log), i.e. A = -[1, 2, ..., num_heads] at initialisation (one scalar per head).
        A = torch.arange(1, self.num_heads + 1)
        self.A_log = nn.Parameter(torch.log(A))
        # Gated RMSNorm over the inner stream; the forward passes the `gate` split as its gate.
        self.norm = BambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon)
        self.D = nn.Parameter(torch.ones(self.num_heads))

        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)

        if not is_fast_path_available:
            logger.warning_once(
                "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
                " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and"
                " https://github.com/Dao-AILab/causal-conv1d"
            )
        else:
            logger.warning_once("The fast path for Bamba will be used when running the model on a GPU")
  444. def cuda_kernels_forward(
  445. self,
  446. hidden_states: torch.Tensor,
  447. cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
  448. cache_position: Optional[torch.LongTensor] = None,
  449. attention_mask: Optional[torch.Tensor] = None,
  450. seq_idx: Optional[torch.IntTensor] = None,
  451. ):
  452. # 1. Gated MLP's linear projection
  453. hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
  454. projected_states = self.in_proj(hidden_states)
  455. # Set up dimensions for reshapes later
  456. batch_size, seq_len, _ = hidden_states.shape
  457. groups_time_state_size = self.n_groups * self.ssm_state_size
  458. use_precomputed_states = (
  459. cache_params is not None
  460. and cache_params.has_previous_state
  461. and seq_len == 1
  462. and cache_params.conv_states[self.layer_idx].shape[0]
  463. == cache_params.ssm_states[self.layer_idx].shape[0]
  464. == batch_size
  465. and cache_position is not None
  466. and cache_position[0] > 0
  467. )
  468. # getting projected states from cache if it exists
  469. if use_precomputed_states:
  470. gate, hidden_states_B_C, dt = projected_states.squeeze(1).split(
  471. [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
  472. )
  473. # 2. Convolution sequence transformation
  474. hidden_states_B_C = causal_conv1d_update(
  475. hidden_states_B_C,
  476. cache_params.conv_states[self.layer_idx],
  477. self.conv1d.weight.squeeze(1),
  478. self.conv1d.bias,
  479. self.activation,
  480. )
  481. hidden_states, B, C = torch.split(
  482. hidden_states_B_C,
  483. [self.intermediate_size, groups_time_state_size, groups_time_state_size],
  484. dim=-1,
  485. )
  486. # 3. SSM transformation
  487. A = -torch.exp(self.A_log.float()) # (nheads,)
  488. A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
  489. dt = dt[:, :, None].expand(-1, -1, self.head_dim)
  490. dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
  491. D = self.D[:, None, ...].expand(-1, self.head_dim)
  492. B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
  493. C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
  494. hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim)
  495. hidden_states = selective_state_update(
  496. cache_params.ssm_states[self.layer_idx],
  497. hidden_states_reshaped,
  498. dt,
  499. A,
  500. B,
  501. C,
  502. D,
  503. z=None,
  504. dt_bias=dt_bias,
  505. dt_softplus=True,
  506. )
  507. hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim)
  508. hidden_states = self.norm(hidden_states, gate)
  509. # 4. Final linear projection
  510. out = self.out_proj(hidden_states)[:, None, ...]
  511. # Fused calculations or step by step if no initialized cache is found
  512. else:
  513. A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size)
  514. dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit}
  515. # 2-4. Fused kernel for conv1d, SSM, and the final projection
  516. if self.training and cache_params is None:
  517. out = mamba_split_conv1d_scan_combined(
  518. projected_states,
  519. self.conv1d.weight.squeeze(1),
  520. self.conv1d.bias,
  521. self.dt_bias,
  522. A,
  523. D=self.D,
  524. chunk_size=self.chunk_size,
  525. seq_idx=seq_idx,
  526. activation=self.activation,
  527. rmsnorm_weight=self.norm.weight,
  528. rmsnorm_eps=self.norm.variance_epsilon,
  529. outproj_weight=self.out_proj.weight,
  530. outproj_bias=self.out_proj.bias,
  531. headdim=self.head_dim,
  532. ngroups=self.n_groups,
  533. norm_before_gate=False,
  534. return_final_states=False,
  535. **dt_limit_kwargs,
  536. )
  537. else:
  538. gate, hidden_states_B_C, dt = projected_states.split(
  539. [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
  540. )
  541. # 2. Convolution sequence transformation
  542. # Init cache
  543. if cache_params is not None:
  544. # storing the states
  545. # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
  546. # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
  547. hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
  548. conv_states = nn.functional.pad(
  549. hidden_states_B_C_transposed,
  550. (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0),
  551. )
  552. cache_params.conv_states[self.layer_idx].copy_(conv_states)
  553. if self.activation not in ["silu", "swish"]:
  554. hidden_states_B_C = self.act(
  555. self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)
  556. )
  557. else:
  558. hidden_states_B_C = causal_conv1d_fn(
  559. x=hidden_states_B_C.transpose(1, 2),
  560. weight=self.conv1d.weight.squeeze(1),
  561. bias=self.conv1d.bias,
  562. activation=self.activation,
  563. seq_idx=seq_idx,
  564. ).transpose(1, 2)
  565. hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
  566. hidden_states, B, C = torch.split(
  567. hidden_states_B_C,
  568. [self.intermediate_size, groups_time_state_size, groups_time_state_size],
  569. dim=-1,
  570. )
  571. # 3. SSM transformation
  572. scan_output, ssm_state = mamba_chunk_scan_combined(
  573. hidden_states.view(batch_size, seq_len, -1, self.head_dim),
  574. dt,
  575. A,
  576. B.view(batch_size, seq_len, self.n_groups, -1),
  577. C.view(batch_size, seq_len, self.n_groups, -1),
  578. chunk_size=self.chunk_size,
  579. D=self.D,
  580. z=None,
  581. seq_idx=seq_idx,
  582. return_final_states=True,
  583. dt_bias=self.dt_bias,
  584. dt_softplus=True,
  585. **dt_limit_kwargs,
  586. )
  587. # Init cache
  588. if ssm_state is not None and cache_params is not None:
  589. cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
  590. scan_output = scan_output.view(batch_size, seq_len, -1)
  591. # Multiply "gate" branch and apply extra normalization layer
  592. scan_output = self.norm(scan_output, gate)
  593. # 4. Final linear projection
  594. out = self.out_proj(scan_output)
  595. return out
  596. # fmt: off
  597. def torch_forward(
  598. self,
  599. input_states,
  600. cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
  601. cache_position: Optional[torch.LongTensor] = None,
  602. attention_mask: Optional[torch.Tensor] = None,
  603. ):
  604. batch_size, seq_len, _ = input_states.shape
  605. dtype = input_states.dtype
  606. # 1. Gated MLP's linear projection
  607. input_states = apply_mask_to_padding_states(input_states, attention_mask)
  608. projected_states = self.in_proj(input_states)
  609. gate, hidden_states_B_C, dt = projected_states.split(
  610. [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
  611. )
  612. use_precomputed_states = (
  613. cache_params is not None
  614. and cache_params.has_previous_state
  615. and seq_len == 1
  616. and cache_params.conv_states[self.layer_idx].shape[0]
  617. == cache_params.ssm_states[self.layer_idx].shape[0]
  618. == batch_size
  619. and cache_position is not None
  620. and cache_position[0] > 0
  621. )
  622. # 2. Convolution sequence transformation
  623. if use_precomputed_states:
  624. cache_params.conv_states[self.layer_idx] = cache_params.conv_states[self.layer_idx].roll(shifts=-1, dims=-1)
  625. cache_params.conv_states[self.layer_idx][:, :, -1] = hidden_states_B_C[:, 0, :].to(cache_params.conv_states[self.layer_idx].device)
  626. # We need to guarantee that anything regarding the cache is on the same device
  627. conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device)
  628. hidden_states_B_C = torch.sum(
  629. conv_states * self.conv1d.weight.squeeze(1), dim=-1
  630. )
  631. if self.use_conv_bias:
  632. hidden_states_B_C = hidden_states_B_C + self.conv1d.bias
  633. hidden_states_B_C = self.act(hidden_states_B_C)
  634. else:
  635. # Init cache
  636. if cache_params is not None:
  637. hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
  638. conv_states = nn.functional.pad(
  639. hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0)
  640. )
  641. cache_params.conv_states[self.layer_idx].copy_(conv_states)
  642. hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2))
  643. hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
  644. hidden_states, B, C = torch.split(
  645. hidden_states_B_C,
  646. [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size],
  647. dim=-1
  648. )
  649. # 3. SSM transformation
  650. A = -torch.exp(self.A_log.float()) # [num_heads]
  651. if use_precomputed_states:
  652. # We need to guarantee that anything regarding the cache is on the same device
  653. cache_device = cache_params.ssm_states[self.layer_idx].device
  654. # Note: there is no need to pad parameter matrices here, as there is just one new token
  655. # for batched generation
  656. dt = dt[:, 0, :][:, None, ...]
  657. dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim)
  658. # [num_heads] -> [num_heads, head_dim]
  659. dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim)
  660. dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype))
  661. dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
  662. A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
  663. # [bsz, num_heads, head_dim, state_size]
  664. dA = (torch.exp(dt[..., None] * A)).to(device=cache_device)
  665. # Discretize B
  666. # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] ->
  667. # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size]
  668. B = B.reshape(batch_size, self.n_groups, -1)[..., None, :]
  669. B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous()
  670. B = B.reshape(batch_size, -1, B.shape[-1])
  671. # [bsz, num_heads, head_dim, state_size]
  672. dB = dt[..., None] * B[..., None, :]
  673. # Discretize x into dB
  674. # [bsz, intermediate_size] -> [bsz, num_heads, head_dim]
  675. hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim)
  676. dBx = (dB * hidden_states[..., None]).to(device=cache_device)
  677. # State calculation
  678. cache_params.ssm_states[self.layer_idx].copy_(
  679. cache_params.ssm_states[self.layer_idx] * dA + dBx
  680. )
  681. # Subsequent output
  682. # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size]
  683. C = C.reshape(batch_size, self.n_groups, -1)[..., None, :]
  684. C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous()
  685. C = C.reshape(batch_size, -1, C.shape[-1])
  686. # [bsz, num_heads, head_dim]
  687. ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n]
  688. # Reshape ssm_states to merge the first two dimensions
  689. ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n]
  690. C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1]
  691. y = torch.bmm(ssm_states_reshaped, C_reshaped)
  692. y = y.view(batch_size, self.num_heads, self.head_dim)
  693. # D skip connection
  694. # [num_heads] -> [num_heads, head_dim]
  695. D = self.D[..., None].expand(self.D.shape[0], self.head_dim)
  696. y = (y + hidden_states * D).to(y.dtype)
  697. # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size]
  698. y = y.reshape(batch_size, -1)[:, None, ...]
  699. else:
  700. # begin ssd naive implementation without einsums
  701. dt = nn.functional.softplus(dt + self.dt_bias)
  702. dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
  703. hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
  704. B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
  705. C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
  706. B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
  707. C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
  708. pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size
  709. D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)
  710. # Discretize x and A
  711. hidden_states = hidden_states * dt[..., None]
  712. A = A.to(hidden_states.dtype) * dt
  713. # Rearrange into blocks/chunks
  714. hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)]
  715. # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size]
  716. A = A.permute(0, 3, 1, 2)
  717. A_cumsum = torch.cumsum(A, dim=-1)
  718. # 1. Compute the output for each intra-chunk (diagonal blocks)
  719. # This is the analog of a causal mask
  720. L = torch.exp(segment_sum(A))
  721. # Contraction of C and B to get G (attention-weights like)
  722. G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] # shape: (b, c, l, s, h, n)
  723. G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h)
  724. # Compute M, equivalent to applying attention mask to weights
  725. M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None]
  726. M = M_intermediate.sum(dim=-1)
  727. # Compute Y_diag (apply to values)
  728. Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3)
  729. # 2. Compute the state for each intra-chunk
  730. # (right term of low-rank factorization of off-diagonal blocks; B terms)
  731. decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
  732. B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None]
  733. states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2)
  734. # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
  735. # (middle term of factorization of off-diag blocks; A terms)
  736. if use_precomputed_states:
  737. previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device)
  738. else:
  739. previous_states = torch.zeros_like(states[:, :1])
  740. states = torch.cat([previous_states, states], dim=1)
  741. decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))
  742. decay_chunk = decay_chunk.transpose(1, 3)
  743. new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1)
  744. states, ssm_state = new_states[:, :-1], new_states[:, -1]
  745. # 4. Compute state -> output conversion per chunk
  746. # (left term of low-rank factorization of off-diagonal blocks; C terms)
  747. state_decay_out = torch.exp(A_cumsum)
  748. C_times_states = (C[..., None, :] * states[:, :, None, ...])
  749. state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1)
  750. Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None])
  751. # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
  752. y = Y_diag + Y_off
  753. # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim]
  754. y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)
  755. y = y + D_residual
  756. # Cutting off padded chunks
  757. if pad_size > 0:
  758. y = y[:, :seq_len, :, :]
  759. y = y.reshape(batch_size, seq_len, -1)
  760. # Init cache
  761. if ssm_state is not None and cache_params is not None:
  762. cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
  763. scan_output = self.norm(y, gate)
  764. # end ssd naive
  765. # 4. Final linear projection
  766. contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size]
  767. return contextualized_states
  768. # fmt: on
  769. def forward(
  770. self,
  771. hidden_states,
  772. cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
  773. cache_position: Optional[torch.LongTensor] = None,
  774. attention_mask: Optional[torch.Tensor] = None,
  775. seq_idx: Optional[torch.IntTensor] = None,
  776. **kwargs,
  777. ):
  778. if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
  779. return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask, seq_idx)
  780. if seq_idx is not None:
  781. raise NotImplementedError(
  782. "`seq_idx` support requires fast path support. Please install `mamba_ssm` and `causal_conv1d`"
  783. )
  784. dtype = hidden_states.dtype
  785. if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
  786. # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
  787. hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
  788. return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)
  789. class BambaMLP(nn.Module):
  790. def __init__(self, config):
  791. super().__init__()
  792. self.config = config
  793. self.hidden_size = config.hidden_size
  794. self.intermediate_size = config.intermediate_size
  795. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  796. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  797. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
  798. self.act_fn = ACT2FN[config.hidden_act]
  799. def forward(self, x):
  800. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  801. return down_proj
  802. @use_kernel_forward_from_hub("RMSNorm")
  803. class BambaRMSNorm(nn.Module):
  804. def __init__(self, hidden_size, eps=1e-6):
  805. """
  806. BambaRMSNorm is equivalent to T5LayerNorm
  807. """
  808. super().__init__()
  809. self.weight = nn.Parameter(torch.ones(hidden_size))
  810. self.variance_epsilon = eps
  811. def forward(self, hidden_states):
  812. input_dtype = hidden_states.dtype
  813. hidden_states = hidden_states.to(torch.float32)
  814. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  815. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  816. return self.weight * hidden_states.to(input_dtype)
  817. def extra_repr(self):
  818. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  819. class BambaDecoderLayer(GradientCheckpointingLayer):
  820. def __init__(self, config: BambaConfig, layer_idx: int, layer_type: str = "mamba"):
  821. super().__init__()
  822. num_experts = 1
  823. ffn_layer_class = BambaMLP if num_experts == 1 else None
  824. self.feed_forward = ffn_layer_class(config)
  825. self.input_layernorm = BambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  826. self.pre_ff_layernorm = BambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  827. self.layer_type = layer_type
  828. if layer_type == "mamba":
  829. self.mamba = BambaMixer(config=config, layer_idx=layer_idx)
  830. elif layer_type == "attention":
  831. self.self_attn = BambaAttention(config, layer_idx)
  832. else:
  833. raise ValueError("Invalid layer_type")
  834. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  835. def forward(
  836. self,
  837. hidden_states: torch.Tensor,
  838. attention_mask: Optional[torch.Tensor] = None,
  839. position_ids: Optional[torch.LongTensor] = None,
  840. past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
  841. output_attentions: Optional[bool] = False,
  842. use_cache: Optional[bool] = False,
  843. cache_position: Optional[torch.LongTensor] = None,
  844. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
  845. **kwargs: Unpack[BambaFlashAttentionKwargs],
  846. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  847. """
  848. Args:
  849. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  850. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
  851. `(batch, sequence_length)` where padding elements are indicated by 0.
  852. past_key_values (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states
  853. output_attentions (`bool`, *optional*):
  854. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  855. returned tensors for more detail.
  856. use_cache (`bool`, *optional*):
  857. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  858. (see `past_key_values`).
  859. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  860. Indices depicting the position of the input sequence tokens in the sequence.
  861. position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
  862. Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
  863. with `head_dim` being the embedding dimension of each attention head.
  864. kwargs (`dict`, *optional*):
  865. Arbitrary kwargs. Can be used to provide `BambaFlashAttentionKwargs` for
  866. padding-free training and/or improve torch.compile performance.
  867. """
  868. residual = hidden_states
  869. hidden_states = self.input_layernorm(hidden_states)
  870. # this is a hybrid decoder layer
  871. if self.layer_type == "mamba":
  872. hidden_states = self.mamba(
  873. hidden_states=hidden_states,
  874. cache_params=past_key_values,
  875. cache_position=cache_position,
  876. attention_mask=attention_mask,
  877. **kwargs,
  878. )
  879. self_attn_weights = None
  880. elif self.layer_type == "attention":
  881. hidden_states, self_attn_weights = self.self_attn(
  882. hidden_states=hidden_states,
  883. attention_mask=attention_mask,
  884. position_ids=position_ids,
  885. past_key_values=past_key_values,
  886. output_attentions=output_attentions,
  887. use_cache=use_cache,
  888. cache_position=cache_position,
  889. position_embeddings=position_embeddings,
  890. **kwargs,
  891. )
  892. # residual connection after attention
  893. hidden_states = residual + hidden_states
  894. # feed-forward
  895. residual = hidden_states
  896. hidden_states = self.pre_ff_layernorm(hidden_states)
  897. hidden_states = self.feed_forward(hidden_states)
  898. hidden_states = residual + hidden_states
  899. outputs = (hidden_states,)
  900. if output_attentions:
  901. outputs += (self_attn_weights,)
  902. return outputs
  903. @auto_docstring
  904. class BambaPreTrainedModel(PreTrainedModel):
  905. config: BambaConfig
  906. base_model_prefix = "model"
  907. supports_gradient_checkpointing = True
  908. _no_split_modules = ["BambaDecoderLayer"]
  909. _skip_keys_device_placement = "past_key_values"
  910. _supports_flash_attn = True
  911. _supports_sdpa = True
  912. # Note: only supports HybridMambaAttentionDynamicCache
  913. _is_stateful = True
  914. def _init_weights(self, module):
  915. super()._init_weights(module)
  916. if isinstance(module, BambaMixer):
  917. module.dt_bias.data.fill_(1.0)
  918. module.A_log.data = torch.log(torch.arange(1, module.num_heads + 1))
  919. module.D.data.fill_(1.0)
  920. @auto_docstring
  921. class BambaModel(BambaPreTrainedModel):
  922. def __init__(self, config: BambaConfig):
  923. super().__init__(config)
  924. self.padding_idx = config.pad_token_id
  925. self.vocab_size = config.vocab_size
  926. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  927. decoder_layers = []
  928. for i in range(config.num_hidden_layers):
  929. decoder_layers.append(BambaDecoderLayer(config, layer_idx=i, layer_type=config.layers_block_type[i]))
  930. self.layers = nn.ModuleList(decoder_layers)
  931. self._attn_implementation = config._attn_implementation
  932. self.final_layernorm = BambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  933. self.rotary_emb = BambaRotaryEmbedding(config=config)
  934. self.gradient_checkpointing = False
  935. # Initialize weights and apply final processing
  936. self.post_init()
  937. @can_return_tuple
  938. @auto_docstring
  939. def forward(
  940. self,
  941. input_ids: Optional[torch.LongTensor] = None,
  942. attention_mask: Optional[torch.Tensor] = None,
  943. position_ids: Optional[torch.LongTensor] = None,
  944. past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
  945. inputs_embeds: Optional[torch.FloatTensor] = None,
  946. use_cache: Optional[bool] = None,
  947. output_attentions: Optional[bool] = None,
  948. output_hidden_states: Optional[bool] = None,
  949. cache_position: Optional[torch.LongTensor] = None,
  950. **kwargs: Unpack[BambaFlashAttentionKwargs],
  951. ) -> BaseModelOutputWithPast:
  952. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  953. output_hidden_states = (
  954. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  955. )
  956. use_cache = use_cache if use_cache is not None else self.config.use_cache
  957. if (input_ids is None) ^ (inputs_embeds is not None):
  958. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  959. if self.gradient_checkpointing and self.training and use_cache:
  960. logger.warning_once(
  961. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
  962. )
  963. use_cache = False
  964. if inputs_embeds is None:
  965. inputs_embeds = self.embed_tokens(input_ids)
  966. hidden_states = inputs_embeds
  967. if use_cache and past_key_values is None:
  968. logger.warning_once(
  969. "Bamba requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was "
  970. "provided, so no cache will be returned."
  971. )
  972. if cache_position is None:
  973. cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device)
  974. if position_ids is None:
  975. position_ids = cache_position.unsqueeze(0)
  976. causal_mask = self._update_causal_mask(
  977. attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
  978. )
  979. mamba_mask = self._update_mamba_mask(attention_mask, cache_position)
  980. # create position embeddings to be shared across the decoder layers
  981. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  982. all_hidden_states = () if output_hidden_states else None
  983. all_self_attns = () if output_attentions else None
  984. for decoder_layer in self.layers:
  985. # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention)
  986. layer_mask = mamba_mask if decoder_layer.layer_type == "mamba" else causal_mask
  987. if output_hidden_states:
  988. all_hidden_states += (hidden_states,)
  989. layer_outputs = decoder_layer(
  990. hidden_states,
  991. attention_mask=layer_mask,
  992. position_ids=position_ids,
  993. past_key_values=past_key_values,
  994. output_attentions=output_attentions,
  995. use_cache=use_cache,
  996. cache_position=cache_position,
  997. position_embeddings=position_embeddings,
  998. **kwargs,
  999. )
  1000. hidden_states = layer_outputs[0]
  1001. if output_attentions:
  1002. if layer_outputs[1] is not None:
  1003. # append attentions only of attention layers. Mamba layers return `None` as the attention weights
  1004. all_self_attns += (layer_outputs[1],)
  1005. hidden_states = self.final_layernorm(hidden_states)
  1006. # add hidden states from the last decoder layer
  1007. if output_hidden_states:
  1008. all_hidden_states += (hidden_states,)
  1009. if past_key_values and not past_key_values.has_previous_state:
  1010. past_key_values.has_previous_state = True
  1011. next_cache = None if not use_cache else past_key_values
  1012. return BaseModelOutputWithPast(
  1013. last_hidden_state=hidden_states,
  1014. past_key_values=next_cache,
  1015. hidden_states=all_hidden_states,
  1016. attentions=all_self_attns,
  1017. )
  1018. def _update_causal_mask(
  1019. self,
  1020. attention_mask: torch.Tensor,
  1021. input_tensor: torch.Tensor,
  1022. cache_position: torch.Tensor,
  1023. past_key_values: HybridMambaAttentionDynamicCache,
  1024. output_attentions: bool,
  1025. ):
  1026. if self.config._attn_implementation == "flash_attention_2":
  1027. if attention_mask is not None and 0.0 in attention_mask:
  1028. return attention_mask
  1029. return None
  1030. # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
  1031. # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
  1032. # to infer the attention mask.
  1033. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  1034. # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
  1035. if self.config._attn_implementation == "sdpa" and not output_attentions:
  1036. if AttentionMaskConverter._ignore_causal_mask_sdpa(
  1037. attention_mask,
  1038. inputs_embeds=input_tensor,
  1039. past_key_values_length=past_seen_tokens,
  1040. is_training=self.training,
  1041. ):
  1042. return None
  1043. dtype = input_tensor.dtype
  1044. sequence_length = input_tensor.shape[1]
  1045. target_length = (
  1046. attention_mask.shape[-1]
  1047. if isinstance(attention_mask, torch.Tensor)
  1048. else past_seen_tokens + sequence_length + 1
  1049. )
  1050. # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
  1051. causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
  1052. attention_mask,
  1053. sequence_length=sequence_length,
  1054. target_length=target_length,
  1055. dtype=dtype,
  1056. cache_position=cache_position,
  1057. batch_size=input_tensor.shape[0],
  1058. )
  1059. if (
  1060. self.config._attn_implementation == "sdpa"
  1061. and attention_mask is not None
  1062. and attention_mask.device.type in ["cuda", "xpu", "npu"]
  1063. and not output_attentions
  1064. ):
  1065. # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
  1066. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
  1067. # Details: https://github.com/pytorch/pytorch/issues/110213
  1068. min_dtype = torch.finfo(dtype).min
  1069. causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
  1070. return causal_mask
  1071. @staticmethod
  1072. def _prepare_4d_causal_attention_mask_with_cache_position(
  1073. attention_mask: torch.Tensor,
  1074. sequence_length: int,
  1075. target_length: int,
  1076. dtype: torch.dtype,
  1077. cache_position: torch.Tensor,
  1078. batch_size: int,
  1079. **kwargs,
  1080. ):
  1081. """
  1082. Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
  1083. `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
  1084. Args:
  1085. attention_mask (`torch.Tensor`):
  1086. A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
  1087. `(batch_size, 1, query_length, key_value_length)`.
  1088. sequence_length (`int`):
  1089. The sequence length being processed.
  1090. target_length (`int`):
  1091. The target length: when generating with static cache, the mask should be as long as the static cache,
  1092. to account for the 0 padding, the part of the cache that is not filled yet.
  1093. dtype (`torch.dtype`):
  1094. The dtype to use for the 4D attention mask.
  1095. cache_position (`torch.Tensor`):
  1096. Indices depicting the position of the input sequence tokens in the sequence.
  1097. batch_size (`torch.Tensor`):
  1098. Batch size.
  1099. """
  1100. if attention_mask is not None and attention_mask.dim() == 4:
  1101. # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
  1102. causal_mask = attention_mask
  1103. else:
  1104. min_dtype = torch.finfo(dtype).min
  1105. causal_mask = torch.full(
  1106. (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
  1107. )
  1108. if sequence_length != 1:
  1109. causal_mask = torch.triu(causal_mask, diagonal=1)
  1110. causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
  1111. causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
  1112. if attention_mask is not None:
  1113. causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
  1114. mask_length = attention_mask.shape[-1]
  1115. padding_attention_mask = (attention_mask[:, None, None, :] == attention_mask[:, None, :, None])[
  1116. :, :, -sequence_length:, :
  1117. ].to(dtype)
  1118. padding_mask = causal_mask[:, :, :, :mask_length] + padding_attention_mask
  1119. padding_mask = padding_mask == 0
  1120. causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
  1121. padding_mask, min_dtype
  1122. )
  1123. return causal_mask
  1124. def _update_mamba_mask(self, attention_mask, cache_position):
  1125. """
  1126. No need for zeroing states when
  1127. 1. Cached forward
  1128. 2. Attending to all inputs
  1129. """
  1130. mamba_mask = attention_mask
  1131. if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
  1132. mamba_mask = None
  1133. return mamba_mask
# Causal language-modeling head on top of `BambaModel`: the backbone's final hidden states are
# projected to vocabulary logits by `lm_head`, and when labels are given an optional z-loss term
# (weighted by `config.z_loss_coefficient`) is added to the language-modeling loss.
@auto_docstring
class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin):
    # Parameter names declared as tied weights (presumably tied to the input embeddings by the
    # pretrained-model base class -- TODO confirm in `BambaPreTrainedModel`).
    _tied_weights_keys = ["lm_head.weight"]
    # Tensor-parallel plan for the output projection.
    _tp_plan = {"lm_head": "colwise_rep"}
    # Pipeline-parallel plan: `lm_head` takes `hidden_states` as input and produces `logits`.
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        """
        Build the Bamba backbone and the bias-free vocabulary projection.

        Args:
            config: Bamba model configuration. This constructor reads `vocab_size`,
                `hidden_size` and `z_loss_coefficient` from it (plus whatever the
                base class and `BambaModel` read).
        """
        super().__init__(config)
        self.model = BambaModel(config)
        self.vocab_size = config.vocab_size
        # Projects hidden states (hidden_size) to vocabulary logits (vocab_size), no bias.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Weight of the auxiliary z-loss added in `forward`; a value <= 0 disables it.
        self.z_loss_coefficient = config.z_loss_coefficient
        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, BambaForCausalLM

        >>> model = BambaForCausalLM.from_pretrained("...")
        >>> tokenizer = AutoTokenizer.from_pretrained("...")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # Fall back to the config defaults when the caller did not specify these flags.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # Run the backbone. The result is a `BaseModelOutputWithPast`: last hidden state, cache,
        # and (when requested) all hidden states / attention weights.
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        # An int `logits_to_keep` keeps the last N positions (0 keeps all, since slice(-0, None) is
        # slice(0, None)); a tensor is used directly as explicit sequence indices.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
            if self.z_loss_coefficient > 0:
                # z-loss: mean over every kept (batch, position) of the squared logsumexp of the logits,
                # scaled by `z_loss_coefficient` and added to the main loss.
                # NOTE(review): unlike the main loss, this also covers positions whose label is -100
                # (e.g. padding) -- confirm that is intended.
                # Type-match loss, but avoid upcasting large logits tensor until after it's been reduced on dim -1
                z_loss = logits.logsumexp(dim=-1).to(dtype=loss.dtype).pow(2).mean()
                loss = loss + self.z_loss_coefficient * z_loss

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        **kwargs,
    ):
        """
        Assemble the keyword arguments for `forward` for one generation step.

        - On the first step (`past_key_values is None`), a new `HybridMambaAttentionDynamicCache` is
          allocated for the batch.
        - On later steps, `input_ids` is trimmed with `cache_position` so that only tokens not yet
          processed are fed to the model.
        - When not supplied, `position_ids` are derived from `attention_mask`.
        - `inputs_embeds` are used only on the first step.
        - Remaining `kwargs` are passed through unless their key was already set here.

        Returns:
            dict: model inputs for `forward`.
        """
        # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache`

        empty_past_kv = past_key_values is None

        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
        #              (we can't check exception 3 while compiling)
        if not empty_past_kv:
            if (
                inputs_embeds is not None  # Exception 1
                or cache_position[-1] >= input_ids.shape[1]  # Exception 3
            ):
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]
        else:
            # First step: no cache yet, so allocate one sized for this batch in the model's dtype/device.
            past_key_values = HybridMambaAttentionDynamicCache(
                self.config, input_ids.shape[0], self.dtype, device=self.device
            )

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            # Each token's index among the attended tokens (cumsum of the mask minus 1);
            # masked-out slots get the dummy value 1.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if not empty_past_kv:
                # Keep only the positions of the tokens actually fed this step.
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and empty_past_kv:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases

        # NOTE: `logits_to_keep` always comes from `config.num_logits_to_keep`; because the loop below
        # never overwrites keys set here, a `logits_to_keep` passed in `kwargs` is ignored.
        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "logits_to_keep": self.config.num_logits_to_keep,
                "cache_position": cache_position,
            }
        )

        # Forward every remaining kwarg that was not already set above; keys in `model_inputs` win.
        for key, value in kwargs.items():
            if key not in model_inputs:
                model_inputs[key] = value

        return model_inputs
  1272. __all__ = ["BambaModel", "BambaForCausalLM", "BambaPreTrainedModel"]