modular_bamba.py 54 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222
  1. # coding=utf-8
  2. # Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  5. # and OPT implementations in this library. It has been modified from its
  6. # original forms to accommodate minor architectural differences compared
  7. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. """PyTorch Bamba model."""
  21. from typing import Optional, TypedDict, Union
  22. import torch
  23. from torch import nn
  24. from transformers.activations import ACT2FN
  25. from transformers.models.jamba.modeling_jamba import HybridMambaAttentionDynamicCache, JambaAttentionDecoderLayer
  26. from transformers.models.llama.modeling_llama import (
  27. LlamaAttention,
  28. LlamaForCausalLM,
  29. LlamaMLP,
  30. LlamaRMSNorm,
  31. LlamaRotaryEmbedding,
  32. rotate_half,
  33. )
  34. from transformers.models.mamba2.modeling_mamba2 import (
  35. MambaRMSNormGated,
  36. pad_tensor_by_size,
  37. reshape_into_chunks,
  38. segment_sum,
  39. )
  40. from ...modeling_attn_mask_utils import AttentionMaskConverter
  41. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  42. from ...modeling_utils import PreTrainedModel
  43. from ...processing_utils import Unpack
  44. from ...utils import (
  45. auto_docstring,
  46. can_return_tuple,
  47. logging,
  48. )
  49. from ...utils.deprecation import deprecate_kwarg
  50. from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
  51. from .configuration_bamba import BambaConfig
  52. if is_mamba_2_ssm_available():
  53. from mamba_ssm.ops.triton.selective_state_update import selective_state_update
  54. from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
  55. else:
  56. selective_state_update = None
  57. if is_causal_conv1d_available():
  58. from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
  59. else:
  60. causal_conv1d_update, causal_conv1d_fn = None, None
  61. is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
  62. logger = logging.get_logger(__name__)
  63. class BambaFlashAttentionKwargs(TypedDict, total=False):
  64. """
  65. Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
  66. Use cases include padding-free training and fewer `torch.compile` graph breaks.
  67. Attributes:
  68. cu_seq_lens_q (`torch.LongTensor`)
  69. Gets cumulative sequence length for query state.
  70. cu_seq_lens_k (`torch.LongTensor`)
  71. Gets cumulative sequence length for key state.
  72. max_length_q (`int`):
  73. Maximum sequence length for query state.
  74. max_length_k (`int`):
  75. Maximum sequence length for key state.
  76. seq_idx (`torch.IntTensor):
  77. Index of each packed sequence.
  78. """
  79. cu_seq_lens_q: torch.LongTensor
  80. cu_seq_lens_k: torch.LongTensor
  81. max_length_q: int
  82. max_length_k: int
  83. seq_idx: torch.IntTensor
  84. # Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer
  85. class HybridMambaAttentionDynamicCache(HybridMambaAttentionDynamicCache):
  86. """
  87. A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
  88. (which has a constant shape regardless of seq_len).
  89. This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
  90. and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor
  91. For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
  92. while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
  93. For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
  94. while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`,
  95. and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`.
  96. """
  97. def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device=None):
  98. self.layers_block_type = config.layers_block_type
  99. self.has_previous_state = False # only used by mamba
  100. conv_kernel_size = config.mamba_d_conv
  101. ssm_state_size = config.mamba_d_state
  102. self.conv_states = []
  103. self.ssm_states = []
  104. self.transformer_layers = []
  105. for i in range(config.num_hidden_layers):
  106. if self.layers_block_type[i] == "mamba":
  107. self.conv_states += [
  108. torch.zeros(
  109. batch_size,
  110. (config.mamba_expand * config.hidden_size + 2 * config.mamba_n_groups * ssm_state_size),
  111. conv_kernel_size,
  112. device=device,
  113. dtype=dtype,
  114. )
  115. ]
  116. self.ssm_states += [
  117. torch.zeros(
  118. batch_size,
  119. config.mamba_n_heads,
  120. config.mamba_d_head,
  121. ssm_state_size,
  122. device=device,
  123. dtype=dtype,
  124. )
  125. ]
  126. else:
  127. self.conv_states += [torch.tensor([[]] * batch_size, device=device)]
  128. self.ssm_states += [torch.tensor([[]] * batch_size, device=device)]
  129. self.transformer_layers.append(i)
  130. self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
  131. self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
  132. class BambaRotaryEmbedding(LlamaRotaryEmbedding):
  133. pass
  134. # Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb
  135. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  136. """Applies Rotary Position Embedding to the query and key tensors.
  137. Removes the interleaving of cos and sin from GLM
  138. Args:
  139. q (`torch.Tensor`): The query tensor.
  140. k (`torch.Tensor`): The key tensor.
  141. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  142. sin (`torch.Tensor`): The sine part of the rotary embedding.
  143. position_ids (`torch.Tensor`, *optional*):
  144. Deprecated and unused.
  145. unsqueeze_dim (`int`, *optional*, defaults to 1):
  146. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  147. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  148. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  149. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  150. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  151. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  152. Returns:
  153. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  154. """
  155. cos = cos.unsqueeze(unsqueeze_dim)
  156. sin = sin.unsqueeze(unsqueeze_dim)
  157. # Keep half or full tensor for later concatenation
  158. rotary_dim = cos.shape[-1]
  159. q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
  160. k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
  161. # Apply rotary embeddings on the first half or full tensor
  162. q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
  163. k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
  164. # Concatenate back to full shape
  165. q_embed = torch.cat([q_embed, q_pass], dim=-1)
  166. k_embed = torch.cat([k_embed, k_pass], dim=-1)
  167. return q_embed, k_embed
  168. class BambaAttention(LlamaAttention):
  169. pass
  170. class BambaRMSNormGated(MambaRMSNormGated):
  171. pass
  172. def apply_mask_to_padding_states(hidden_states, attention_mask):
  173. """
  174. Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
  175. """
  176. if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
  177. dtype = hidden_states.dtype
  178. hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
  179. return hidden_states
  180. # Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer
  181. class BambaMixer(nn.Module):
  182. """
  183. Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
  184. A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
  185. ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
  186. and is why Mamba is called **selective** state spaces)
  187. The are a few differences between this and Mamba2Mixer:
  188. - The variable use_precomputed_states is slightly different due to the hybrid cache structure
  189. - There's a few non-obvious bugs fixed with batching in the slow path that exist in main
  190. - Some extra variables that our layer doesn't need have been removed
  191. - We ported most of the refactors in https://github.com/huggingface/transformers/pull/35154, which is (as of Dec 18, 2024) unmerged
  192. """
  193. def __init__(self, config: BambaConfig, layer_idx: int):
  194. super().__init__()
  195. self.num_heads = config.mamba_n_heads
  196. self.hidden_size = config.hidden_size
  197. self.ssm_state_size = config.mamba_d_state
  198. self.conv_kernel_size = config.mamba_d_conv
  199. self.intermediate_size = int(config.mamba_expand * self.hidden_size)
  200. self.layer_idx = layer_idx
  201. self.use_conv_bias = config.mamba_conv_bias
  202. self.activation = config.hidden_act
  203. self.act = ACT2FN[config.hidden_act]
  204. self.use_bias = config.mamba_proj_bias
  205. self.layer_norm_epsilon = config.rms_norm_eps
  206. self.n_groups = config.mamba_n_groups
  207. self.head_dim = config.mamba_d_head
  208. self.chunk_size = config.mamba_chunk_size
  209. # FIXME:
  210. self.time_step_limit = (0.0, float("inf"))
  211. self.time_step_min = 0.001
  212. self.time_step_max = 0.1
  213. self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
  214. self.conv1d = nn.Conv1d(
  215. in_channels=self.conv_dim,
  216. out_channels=self.conv_dim,
  217. bias=config.mamba_conv_bias,
  218. kernel_size=self.conv_kernel_size,
  219. groups=self.conv_dim,
  220. padding=self.conv_kernel_size - 1,
  221. )
  222. # projection of the input hidden states
  223. projection_size = self.intermediate_size + self.conv_dim + self.num_heads
  224. self.in_proj = nn.Linear(
  225. self.hidden_size,
  226. projection_size,
  227. bias=self.use_bias,
  228. )
  229. # selective projection used to make dt, B and C input dependent
  230. # time step projection (discretization)
  231. # instantiate once and copy inv_dt in init_weights of PretrainedModel
  232. self.dt_bias = nn.Parameter(torch.ones(self.num_heads))
  233. # S4D real initialization. These are not discretized!
  234. # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
  235. A = torch.arange(1, self.num_heads + 1)
  236. self.A_log = nn.Parameter(torch.log(A))
  237. self.norm = BambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon)
  238. self.D = nn.Parameter(torch.ones(self.num_heads))
  239. self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)
  240. if not is_fast_path_available:
  241. logger.warning_once(
  242. "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
  243. " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and"
  244. " https://github.com/Dao-AILab/causal-conv1d"
  245. )
  246. else:
  247. logger.warning_once("The fast path for Bamba will be used when running the model on a GPU")
  248. def cuda_kernels_forward(
  249. self,
  250. hidden_states: torch.Tensor,
  251. cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
  252. cache_position: Optional[torch.LongTensor] = None,
  253. attention_mask: Optional[torch.Tensor] = None,
  254. seq_idx: Optional[torch.IntTensor] = None,
  255. ):
  256. # 1. Gated MLP's linear projection
  257. hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
  258. projected_states = self.in_proj(hidden_states)
  259. # Set up dimensions for reshapes later
  260. batch_size, seq_len, _ = hidden_states.shape
  261. groups_time_state_size = self.n_groups * self.ssm_state_size
  262. use_precomputed_states = (
  263. cache_params is not None
  264. and cache_params.has_previous_state
  265. and seq_len == 1
  266. and cache_params.conv_states[self.layer_idx].shape[0]
  267. == cache_params.ssm_states[self.layer_idx].shape[0]
  268. == batch_size
  269. and cache_position is not None
  270. and cache_position[0] > 0
  271. )
  272. # getting projected states from cache if it exists
  273. if use_precomputed_states:
  274. gate, hidden_states_B_C, dt = projected_states.squeeze(1).split(
  275. [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
  276. )
  277. # 2. Convolution sequence transformation
  278. hidden_states_B_C = causal_conv1d_update(
  279. hidden_states_B_C,
  280. cache_params.conv_states[self.layer_idx],
  281. self.conv1d.weight.squeeze(1),
  282. self.conv1d.bias,
  283. self.activation,
  284. )
  285. hidden_states, B, C = torch.split(
  286. hidden_states_B_C,
  287. [self.intermediate_size, groups_time_state_size, groups_time_state_size],
  288. dim=-1,
  289. )
  290. # 3. SSM transformation
  291. A = -torch.exp(self.A_log.float()) # (nheads,)
  292. A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
  293. dt = dt[:, :, None].expand(-1, -1, self.head_dim)
  294. dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
  295. D = self.D[:, None, ...].expand(-1, self.head_dim)
  296. B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
  297. C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
  298. hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim)
  299. hidden_states = selective_state_update(
  300. cache_params.ssm_states[self.layer_idx],
  301. hidden_states_reshaped,
  302. dt,
  303. A,
  304. B,
  305. C,
  306. D,
  307. z=None,
  308. dt_bias=dt_bias,
  309. dt_softplus=True,
  310. )
  311. hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim)
  312. hidden_states = self.norm(hidden_states, gate)
  313. # 4. Final linear projection
  314. out = self.out_proj(hidden_states)[:, None, ...]
  315. # Fused calculations or step by step if no initialized cache is found
  316. else:
  317. A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size)
  318. dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit}
  319. # 2-4. Fused kernel for conv1d, SSM, and the final projection
  320. if self.training and cache_params is None:
  321. out = mamba_split_conv1d_scan_combined(
  322. projected_states,
  323. self.conv1d.weight.squeeze(1),
  324. self.conv1d.bias,
  325. self.dt_bias,
  326. A,
  327. D=self.D,
  328. chunk_size=self.chunk_size,
  329. seq_idx=seq_idx,
  330. activation=self.activation,
  331. rmsnorm_weight=self.norm.weight,
  332. rmsnorm_eps=self.norm.variance_epsilon,
  333. outproj_weight=self.out_proj.weight,
  334. outproj_bias=self.out_proj.bias,
  335. headdim=self.head_dim,
  336. ngroups=self.n_groups,
  337. norm_before_gate=False,
  338. return_final_states=False,
  339. **dt_limit_kwargs,
  340. )
  341. else:
  342. gate, hidden_states_B_C, dt = projected_states.split(
  343. [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
  344. )
  345. # 2. Convolution sequence transformation
  346. # Init cache
  347. if cache_params is not None:
  348. # storing the states
  349. # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
  350. # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
  351. hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
  352. conv_states = nn.functional.pad(
  353. hidden_states_B_C_transposed,
  354. (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0),
  355. )
  356. cache_params.conv_states[self.layer_idx].copy_(conv_states)
  357. if self.activation not in ["silu", "swish"]:
  358. hidden_states_B_C = self.act(
  359. self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)
  360. )
  361. else:
  362. hidden_states_B_C = causal_conv1d_fn(
  363. x=hidden_states_B_C.transpose(1, 2),
  364. weight=self.conv1d.weight.squeeze(1),
  365. bias=self.conv1d.bias,
  366. activation=self.activation,
  367. seq_idx=seq_idx,
  368. ).transpose(1, 2)
  369. hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
  370. hidden_states, B, C = torch.split(
  371. hidden_states_B_C,
  372. [self.intermediate_size, groups_time_state_size, groups_time_state_size],
  373. dim=-1,
  374. )
  375. # 3. SSM transformation
  376. scan_output, ssm_state = mamba_chunk_scan_combined(
  377. hidden_states.view(batch_size, seq_len, -1, self.head_dim),
  378. dt,
  379. A,
  380. B.view(batch_size, seq_len, self.n_groups, -1),
  381. C.view(batch_size, seq_len, self.n_groups, -1),
  382. chunk_size=self.chunk_size,
  383. D=self.D,
  384. z=None,
  385. seq_idx=seq_idx,
  386. return_final_states=True,
  387. dt_bias=self.dt_bias,
  388. dt_softplus=True,
  389. **dt_limit_kwargs,
  390. )
  391. # Init cache
  392. if ssm_state is not None and cache_params is not None:
  393. cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
  394. scan_output = scan_output.view(batch_size, seq_len, -1)
  395. # Multiply "gate" branch and apply extra normalization layer
  396. scan_output = self.norm(scan_output, gate)
  397. # 4. Final linear projection
  398. out = self.out_proj(scan_output)
  399. return out
  400. # fmt: off
  401. def torch_forward(
  402. self,
  403. input_states,
  404. cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
  405. cache_position: Optional[torch.LongTensor] = None,
  406. attention_mask: Optional[torch.Tensor] = None,
  407. ):
  408. batch_size, seq_len, _ = input_states.shape
  409. dtype = input_states.dtype
  410. # 1. Gated MLP's linear projection
  411. input_states = apply_mask_to_padding_states(input_states, attention_mask)
  412. projected_states = self.in_proj(input_states)
  413. gate, hidden_states_B_C, dt = projected_states.split(
  414. [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
  415. )
  416. use_precomputed_states = (
  417. cache_params is not None
  418. and cache_params.has_previous_state
  419. and seq_len == 1
  420. and cache_params.conv_states[self.layer_idx].shape[0]
  421. == cache_params.ssm_states[self.layer_idx].shape[0]
  422. == batch_size
  423. and cache_position is not None
  424. and cache_position[0] > 0
  425. )
  426. # 2. Convolution sequence transformation
  427. if use_precomputed_states:
  428. cache_params.conv_states[self.layer_idx] = cache_params.conv_states[self.layer_idx].roll(shifts=-1, dims=-1)
  429. cache_params.conv_states[self.layer_idx][:, :, -1] = hidden_states_B_C[:, 0, :].to(cache_params.conv_states[self.layer_idx].device)
  430. # We need to guarantee that anything regarding the cache is on the same device
  431. conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device)
  432. hidden_states_B_C = torch.sum(
  433. conv_states * self.conv1d.weight.squeeze(1), dim=-1
  434. )
  435. if self.use_conv_bias:
  436. hidden_states_B_C = hidden_states_B_C + self.conv1d.bias
  437. hidden_states_B_C = self.act(hidden_states_B_C)
  438. else:
  439. # Init cache
  440. if cache_params is not None:
  441. hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
  442. conv_states = nn.functional.pad(
  443. hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0)
  444. )
  445. cache_params.conv_states[self.layer_idx].copy_(conv_states)
  446. hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2))
  447. hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
  448. hidden_states, B, C = torch.split(
  449. hidden_states_B_C,
  450. [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size],
  451. dim=-1
  452. )
  453. # 3. SSM transformation
  454. A = -torch.exp(self.A_log.float()) # [num_heads]
  455. if use_precomputed_states:
  456. # We need to guarantee that anything regarding the cache is on the same device
  457. cache_device = cache_params.ssm_states[self.layer_idx].device
  458. # Note: there is no need to pad parameter matrices here, as there is just one new token
  459. # for batched generation
  460. dt = dt[:, 0, :][:, None, ...]
  461. dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim)
  462. # [num_heads] -> [num_heads, head_dim]
  463. dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim)
  464. dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype))
  465. dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
  466. A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
  467. # [bsz, num_heads, head_dim, state_size]
  468. dA = (torch.exp(dt[..., None] * A)).to(device=cache_device)
  469. # Discretize B
  470. # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] ->
  471. # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size]
  472. B = B.reshape(batch_size, self.n_groups, -1)[..., None, :]
  473. B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous()
  474. B = B.reshape(batch_size, -1, B.shape[-1])
  475. # [bsz, num_heads, head_dim, state_size]
  476. dB = dt[..., None] * B[..., None, :]
  477. # Discretize x into dB
  478. # [bsz, intermediate_size] -> [bsz, num_heads, head_dim]
  479. hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim)
  480. dBx = (dB * hidden_states[..., None]).to(device=cache_device)
  481. # State calculation
  482. cache_params.ssm_states[self.layer_idx].copy_(
  483. cache_params.ssm_states[self.layer_idx] * dA + dBx
  484. )
  485. # Subsequent output
  486. # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size]
  487. C = C.reshape(batch_size, self.n_groups, -1)[..., None, :]
  488. C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous()
  489. C = C.reshape(batch_size, -1, C.shape[-1])
  490. # [bsz, num_heads, head_dim]
  491. ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n]
  492. # Reshape ssm_states to merge the first two dimensions
  493. ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n]
  494. C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1]
  495. y = torch.bmm(ssm_states_reshaped, C_reshaped)
  496. y = y.view(batch_size, self.num_heads, self.head_dim)
  497. # D skip connection
  498. # [num_heads] -> [num_heads, head_dim]
  499. D = self.D[..., None].expand(self.D.shape[0], self.head_dim)
  500. y = (y + hidden_states * D).to(y.dtype)
  501. # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size]
  502. y = y.reshape(batch_size, -1)[:, None, ...]
  503. else:
  504. # begin ssd naive implementation without einsums
  505. dt = nn.functional.softplus(dt + self.dt_bias)
  506. dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
  507. hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
  508. B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
  509. C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
  510. B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
  511. C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
  512. pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size
  513. D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)
  514. # Discretize x and A
  515. hidden_states = hidden_states * dt[..., None]
  516. A = A.to(hidden_states.dtype) * dt
  517. # Rearrange into blocks/chunks
  518. hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)]
  519. # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size]
  520. A = A.permute(0, 3, 1, 2)
  521. A_cumsum = torch.cumsum(A, dim=-1)
  522. # 1. Compute the output for each intra-chunk (diagonal blocks)
  523. # This is the analog of a causal mask
  524. L = torch.exp(segment_sum(A))
  525. # Contraction of C and B to get G (attention-weights like)
  526. G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] # shape: (b, c, l, s, h, n)
  527. G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h)
  528. # Compute M, equivalent to applying attention mask to weights
  529. M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None]
  530. M = M_intermediate.sum(dim=-1)
  531. # Compute Y_diag (apply to values)
  532. Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3)
  533. # 2. Compute the state for each intra-chunk
  534. # (right term of low-rank factorization of off-diagonal blocks; B terms)
  535. decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
  536. B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None]
  537. states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2)
  538. # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
  539. # (middle term of factorization of off-diag blocks; A terms)
  540. if use_precomputed_states:
  541. previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device)
  542. else:
  543. previous_states = torch.zeros_like(states[:, :1])
  544. states = torch.cat([previous_states, states], dim=1)
  545. decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))
  546. decay_chunk = decay_chunk.transpose(1, 3)
  547. new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1)
  548. states, ssm_state = new_states[:, :-1], new_states[:, -1]
  549. # 4. Compute state -> output conversion per chunk
  550. # (left term of low-rank factorization of off-diagonal blocks; C terms)
  551. state_decay_out = torch.exp(A_cumsum)
  552. C_times_states = (C[..., None, :] * states[:, :, None, ...])
  553. state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1)
  554. Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None])
  555. # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
  556. y = Y_diag + Y_off
  557. # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim]
  558. y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)
  559. y = y + D_residual
  560. # Cutting off padded chunks
  561. if pad_size > 0:
  562. y = y[:, :seq_len, :, :]
  563. y = y.reshape(batch_size, seq_len, -1)
  564. # Init cache
  565. if ssm_state is not None and cache_params is not None:
  566. cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
  567. scan_output = self.norm(y, gate)
  568. # end ssd naive
  569. # 4. Final linear projection
  570. contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size]
  571. return contextualized_states
  572. # fmt: on
  573. def forward(
  574. self,
  575. hidden_states,
  576. cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
  577. cache_position: Optional[torch.LongTensor] = None,
  578. attention_mask: Optional[torch.Tensor] = None,
  579. seq_idx: Optional[torch.IntTensor] = None,
  580. **kwargs,
  581. ):
  582. if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
  583. return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask, seq_idx)
  584. if seq_idx is not None:
  585. raise NotImplementedError(
  586. "`seq_idx` support requires fast path support. Please install `mamba_ssm` and `causal_conv1d`"
  587. )
  588. dtype = hidden_states.dtype
  589. if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
  590. # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
  591. hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
  592. return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)
  593. class BambaMLP(LlamaMLP):
  594. pass
  595. class BambaRMSNorm(LlamaRMSNorm):
  596. pass
  597. class BambaDecoderLayer(JambaAttentionDecoderLayer):
  598. def __init__(self, config: BambaConfig, layer_idx: int, layer_type: str = "mamba"):
  599. super().__init__(config, layer_idx)
  600. del self.self_attn
  601. num_experts = 1
  602. ffn_layer_class = BambaMLP if num_experts == 1 else None
  603. self.feed_forward = ffn_layer_class(config)
  604. self.layer_type = layer_type
  605. if layer_type == "mamba":
  606. self.mamba = BambaMixer(config=config, layer_idx=layer_idx)
  607. elif layer_type == "attention":
  608. self.self_attn = BambaAttention(config, layer_idx)
  609. else:
  610. raise ValueError("Invalid layer_type")
  611. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  612. def forward(
  613. self,
  614. hidden_states: torch.Tensor,
  615. attention_mask: Optional[torch.Tensor] = None,
  616. position_ids: Optional[torch.LongTensor] = None,
  617. past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
  618. output_attentions: Optional[bool] = False,
  619. use_cache: Optional[bool] = False,
  620. cache_position: Optional[torch.LongTensor] = None,
  621. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
  622. **kwargs: Unpack[BambaFlashAttentionKwargs],
  623. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  624. """
  625. Args:
  626. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  627. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
  628. `(batch, sequence_length)` where padding elements are indicated by 0.
  629. past_key_values (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states
  630. output_attentions (`bool`, *optional*):
  631. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  632. returned tensors for more detail.
  633. use_cache (`bool`, *optional*):
  634. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  635. (see `past_key_values`).
  636. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  637. Indices depicting the position of the input sequence tokens in the sequence.
  638. position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
  639. Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
  640. with `head_dim` being the embedding dimension of each attention head.
  641. kwargs (`dict`, *optional*):
  642. Arbitrary kwargs. Can be used to provide `BambaFlashAttentionKwargs` for
  643. padding-free training and/or improve torch.compile performance.
  644. """
  645. residual = hidden_states
  646. hidden_states = self.input_layernorm(hidden_states)
  647. # this is a hybrid decoder layer
  648. if self.layer_type == "mamba":
  649. hidden_states = self.mamba(
  650. hidden_states=hidden_states,
  651. cache_params=past_key_values,
  652. cache_position=cache_position,
  653. attention_mask=attention_mask,
  654. **kwargs,
  655. )
  656. self_attn_weights = None
  657. elif self.layer_type == "attention":
  658. hidden_states, self_attn_weights = self.self_attn(
  659. hidden_states=hidden_states,
  660. attention_mask=attention_mask,
  661. position_ids=position_ids,
  662. past_key_values=past_key_values,
  663. output_attentions=output_attentions,
  664. use_cache=use_cache,
  665. cache_position=cache_position,
  666. position_embeddings=position_embeddings,
  667. **kwargs,
  668. )
  669. # residual connection after attention
  670. hidden_states = residual + hidden_states
  671. # feed-forward
  672. residual = hidden_states
  673. hidden_states = self.pre_ff_layernorm(hidden_states)
  674. hidden_states = self.feed_forward(hidden_states)
  675. hidden_states = residual + hidden_states
  676. outputs = (hidden_states,)
  677. if output_attentions:
  678. outputs += (self_attn_weights,)
  679. return outputs
  680. @auto_docstring
  681. class BambaPreTrainedModel(PreTrainedModel):
  682. config: BambaConfig
  683. base_model_prefix = "model"
  684. supports_gradient_checkpointing = True
  685. _no_split_modules = ["BambaDecoderLayer"]
  686. _skip_keys_device_placement = "past_key_values"
  687. _supports_flash_attn = True
  688. _supports_sdpa = True
  689. # Note: only supports HybridMambaAttentionDynamicCache
  690. _is_stateful = True
  691. def _init_weights(self, module):
  692. super()._init_weights(module)
  693. if isinstance(module, BambaMixer):
  694. module.dt_bias.data.fill_(1.0)
  695. module.A_log.data = torch.log(torch.arange(1, module.num_heads + 1))
  696. module.D.data.fill_(1.0)
  697. @auto_docstring
  698. class BambaModel(BambaPreTrainedModel):
  699. def __init__(self, config: BambaConfig):
  700. super().__init__(config)
  701. self.padding_idx = config.pad_token_id
  702. self.vocab_size = config.vocab_size
  703. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  704. decoder_layers = []
  705. for i in range(config.num_hidden_layers):
  706. decoder_layers.append(BambaDecoderLayer(config, layer_idx=i, layer_type=config.layers_block_type[i]))
  707. self.layers = nn.ModuleList(decoder_layers)
  708. self._attn_implementation = config._attn_implementation
  709. self.final_layernorm = BambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  710. self.rotary_emb = BambaRotaryEmbedding(config=config)
  711. self.gradient_checkpointing = False
  712. # Initialize weights and apply final processing
  713. self.post_init()
  714. @can_return_tuple
  715. @auto_docstring
  716. def forward(
  717. self,
  718. input_ids: Optional[torch.LongTensor] = None,
  719. attention_mask: Optional[torch.Tensor] = None,
  720. position_ids: Optional[torch.LongTensor] = None,
  721. past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
  722. inputs_embeds: Optional[torch.FloatTensor] = None,
  723. use_cache: Optional[bool] = None,
  724. output_attentions: Optional[bool] = None,
  725. output_hidden_states: Optional[bool] = None,
  726. cache_position: Optional[torch.LongTensor] = None,
  727. **kwargs: Unpack[BambaFlashAttentionKwargs],
  728. ) -> BaseModelOutputWithPast:
  729. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  730. output_hidden_states = (
  731. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  732. )
  733. use_cache = use_cache if use_cache is not None else self.config.use_cache
  734. if (input_ids is None) ^ (inputs_embeds is not None):
  735. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  736. if self.gradient_checkpointing and self.training and use_cache:
  737. logger.warning_once(
  738. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
  739. )
  740. use_cache = False
  741. if inputs_embeds is None:
  742. inputs_embeds = self.embed_tokens(input_ids)
  743. hidden_states = inputs_embeds
  744. if use_cache and past_key_values is None:
  745. logger.warning_once(
  746. "Bamba requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was "
  747. "provided, so no cache will be returned."
  748. )
  749. if cache_position is None:
  750. cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device)
  751. if position_ids is None:
  752. position_ids = cache_position.unsqueeze(0)
  753. causal_mask = self._update_causal_mask(
  754. attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
  755. )
  756. mamba_mask = self._update_mamba_mask(attention_mask, cache_position)
  757. # create position embeddings to be shared across the decoder layers
  758. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  759. all_hidden_states = () if output_hidden_states else None
  760. all_self_attns = () if output_attentions else None
  761. for decoder_layer in self.layers:
  762. # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention)
  763. layer_mask = mamba_mask if decoder_layer.layer_type == "mamba" else causal_mask
  764. if output_hidden_states:
  765. all_hidden_states += (hidden_states,)
  766. layer_outputs = decoder_layer(
  767. hidden_states,
  768. attention_mask=layer_mask,
  769. position_ids=position_ids,
  770. past_key_values=past_key_values,
  771. output_attentions=output_attentions,
  772. use_cache=use_cache,
  773. cache_position=cache_position,
  774. position_embeddings=position_embeddings,
  775. **kwargs,
  776. )
  777. hidden_states = layer_outputs[0]
  778. if output_attentions:
  779. if layer_outputs[1] is not None:
  780. # append attentions only of attention layers. Mamba layers return `None` as the attention weights
  781. all_self_attns += (layer_outputs[1],)
  782. hidden_states = self.final_layernorm(hidden_states)
  783. # add hidden states from the last decoder layer
  784. if output_hidden_states:
  785. all_hidden_states += (hidden_states,)
  786. if past_key_values and not past_key_values.has_previous_state:
  787. past_key_values.has_previous_state = True
  788. next_cache = None if not use_cache else past_key_values
  789. return BaseModelOutputWithPast(
  790. last_hidden_state=hidden_states,
  791. past_key_values=next_cache,
  792. hidden_states=all_hidden_states,
  793. attentions=all_self_attns,
  794. )
  795. def _update_causal_mask(
  796. self,
  797. attention_mask: torch.Tensor,
  798. input_tensor: torch.Tensor,
  799. cache_position: torch.Tensor,
  800. past_key_values: HybridMambaAttentionDynamicCache,
  801. output_attentions: bool,
  802. ):
  803. if self.config._attn_implementation == "flash_attention_2":
  804. if attention_mask is not None and 0.0 in attention_mask:
  805. return attention_mask
  806. return None
  807. # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
  808. # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
  809. # to infer the attention mask.
  810. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  811. # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
  812. if self.config._attn_implementation == "sdpa" and not output_attentions:
  813. if AttentionMaskConverter._ignore_causal_mask_sdpa(
  814. attention_mask,
  815. inputs_embeds=input_tensor,
  816. past_key_values_length=past_seen_tokens,
  817. is_training=self.training,
  818. ):
  819. return None
  820. dtype = input_tensor.dtype
  821. sequence_length = input_tensor.shape[1]
  822. target_length = (
  823. attention_mask.shape[-1]
  824. if isinstance(attention_mask, torch.Tensor)
  825. else past_seen_tokens + sequence_length + 1
  826. )
  827. # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
  828. causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
  829. attention_mask,
  830. sequence_length=sequence_length,
  831. target_length=target_length,
  832. dtype=dtype,
  833. cache_position=cache_position,
  834. batch_size=input_tensor.shape[0],
  835. )
  836. if (
  837. self.config._attn_implementation == "sdpa"
  838. and attention_mask is not None
  839. and attention_mask.device.type in ["cuda", "xpu", "npu"]
  840. and not output_attentions
  841. ):
  842. # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
  843. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
  844. # Details: https://github.com/pytorch/pytorch/issues/110213
  845. min_dtype = torch.finfo(dtype).min
  846. causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
  847. return causal_mask
  848. @staticmethod
  849. def _prepare_4d_causal_attention_mask_with_cache_position(
  850. attention_mask: torch.Tensor,
  851. sequence_length: int,
  852. target_length: int,
  853. dtype: torch.dtype,
  854. cache_position: torch.Tensor,
  855. batch_size: int,
  856. **kwargs,
  857. ):
  858. """
  859. Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
  860. `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
  861. Args:
  862. attention_mask (`torch.Tensor`):
  863. A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
  864. `(batch_size, 1, query_length, key_value_length)`.
  865. sequence_length (`int`):
  866. The sequence length being processed.
  867. target_length (`int`):
  868. The target length: when generating with static cache, the mask should be as long as the static cache,
  869. to account for the 0 padding, the part of the cache that is not filled yet.
  870. dtype (`torch.dtype`):
  871. The dtype to use for the 4D attention mask.
  872. cache_position (`torch.Tensor`):
  873. Indices depicting the position of the input sequence tokens in the sequence.
  874. batch_size (`torch.Tensor`):
  875. Batch size.
  876. """
  877. if attention_mask is not None and attention_mask.dim() == 4:
  878. # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
  879. causal_mask = attention_mask
  880. else:
  881. min_dtype = torch.finfo(dtype).min
  882. causal_mask = torch.full(
  883. (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
  884. )
  885. if sequence_length != 1:
  886. causal_mask = torch.triu(causal_mask, diagonal=1)
  887. causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
  888. causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
  889. if attention_mask is not None:
  890. causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
  891. mask_length = attention_mask.shape[-1]
  892. padding_attention_mask = (attention_mask[:, None, None, :] == attention_mask[:, None, :, None])[
  893. :, :, -sequence_length:, :
  894. ].to(dtype)
  895. padding_mask = causal_mask[:, :, :, :mask_length] + padding_attention_mask
  896. padding_mask = padding_mask == 0
  897. causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
  898. padding_mask, min_dtype
  899. )
  900. return causal_mask
  901. def _update_mamba_mask(self, attention_mask, cache_position):
  902. """
  903. No need for zeroing states when
  904. 1. Cached forward
  905. 2. Attending to all inputs
  906. """
  907. mamba_mask = attention_mask
  908. if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
  909. mamba_mask = None
  910. return mamba_mask
  911. class BambaForCausalLM(LlamaForCausalLM):
  912. def __init__(self, config):
  913. super().__init__(config)
  914. self.z_loss_coefficient = config.z_loss_coefficient
  915. # Initialize weights and apply final processing
  916. self.post_init()
  917. def forward(
  918. self,
  919. input_ids: Optional[torch.LongTensor] = None,
  920. attention_mask: Optional[torch.Tensor] = None,
  921. position_ids: Optional[torch.LongTensor] = None,
  922. past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
  923. inputs_embeds: Optional[torch.FloatTensor] = None,
  924. labels: Optional[torch.LongTensor] = None,
  925. use_cache: Optional[bool] = None,
  926. output_attentions: Optional[bool] = None,
  927. output_hidden_states: Optional[bool] = None,
  928. cache_position: Optional[torch.LongTensor] = None,
  929. logits_to_keep: Union[int, torch.Tensor] = 0,
  930. **kwargs,
  931. ) -> CausalLMOutputWithPast:
  932. r"""
  933. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  934. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  935. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  936. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  937. Example:
  938. ```python
  939. >>> from transformers import AutoTokenizer, BambaForCausalLM
  940. >>> model = BambaForCausalLM.from_pretrained("...")
  941. >>> tokenizer = AutoTokenizer.from_pretrained("...")
  942. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  943. >>> inputs = tokenizer(prompt, return_tensors="pt")
  944. >>> # Generate
  945. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  946. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  947. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  948. ```"""
  949. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  950. output_hidden_states = (
  951. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  952. )
  953. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  954. outputs: BaseModelOutputWithPast = self.model(
  955. input_ids=input_ids,
  956. attention_mask=attention_mask,
  957. position_ids=position_ids,
  958. past_key_values=past_key_values,
  959. inputs_embeds=inputs_embeds,
  960. use_cache=use_cache,
  961. output_attentions=output_attentions,
  962. output_hidden_states=output_hidden_states,
  963. cache_position=cache_position,
  964. **kwargs,
  965. )
  966. hidden_states = outputs.last_hidden_state
  967. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  968. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  969. logits = self.lm_head(hidden_states[:, slice_indices, :])
  970. loss = None
  971. if labels is not None:
  972. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  973. if self.z_loss_coefficient > 0:
  974. # Type-match loss, but avoid upcasting large logits tensor until after it's been reduced on dim -1
  975. z_loss = logits.logsumexp(dim=-1).to(dtype=loss.dtype).pow(2).mean()
  976. loss = loss + self.z_loss_coefficient * z_loss
  977. return CausalLMOutputWithPast(
  978. loss=loss,
  979. logits=logits,
  980. past_key_values=outputs.past_key_values,
  981. hidden_states=outputs.hidden_states,
  982. attentions=outputs.attentions,
  983. )
  984. def prepare_inputs_for_generation(
  985. self,
  986. input_ids,
  987. past_key_values=None,
  988. attention_mask=None,
  989. inputs_embeds=None,
  990. cache_position=None,
  991. position_ids=None,
  992. use_cache=True,
  993. **kwargs,
  994. ):
  995. # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache`
  996. empty_past_kv = past_key_values is None
  997. # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
  998. # Exception 1: when passing input_embeds, input_ids may be missing entries
  999. # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
  1000. # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
  1001. # (we can't check exception 3 while compiling)
  1002. if not empty_past_kv:
  1003. if (
  1004. inputs_embeds is not None # Exception 1
  1005. or cache_position[-1] >= input_ids.shape[1] # Exception 3
  1006. ):
  1007. input_ids = input_ids[:, -cache_position.shape[0] :]
  1008. elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
  1009. input_ids = input_ids[:, cache_position]
  1010. else:
  1011. past_key_values = HybridMambaAttentionDynamicCache(
  1012. self.config, input_ids.shape[0], self.dtype, device=self.device
  1013. )
  1014. if attention_mask is not None and position_ids is None:
  1015. # create position_ids on the fly for batch generation
  1016. position_ids = attention_mask.long().cumsum(-1) - 1
  1017. position_ids.masked_fill_(attention_mask == 0, 1)
  1018. if not empty_past_kv:
  1019. position_ids = position_ids[:, -input_ids.shape[1] :]
  1020. # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
  1021. if inputs_embeds is not None and empty_past_kv:
  1022. model_inputs = {"inputs_embeds": inputs_embeds}
  1023. else:
  1024. model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
  1025. model_inputs.update(
  1026. {
  1027. "position_ids": position_ids,
  1028. "past_key_values": past_key_values,
  1029. "use_cache": use_cache,
  1030. "attention_mask": attention_mask,
  1031. "logits_to_keep": self.config.num_logits_to_keep,
  1032. "cache_position": cache_position,
  1033. }
  1034. )
  1035. # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
  1036. for key, value in kwargs.items():
  1037. if key not in model_inputs:
  1038. model_inputs[key] = value
  1039. return model_inputs
  1040. __all__ = ["BambaModel", "BambaForCausalLM", "BambaPreTrainedModel"]