modeling_mamba2.py 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059
  1. # coding=utf-8
  2. # Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch MAMBA2 model."""
  16. import math
  17. from dataclasses import dataclass
  18. from typing import Optional, Union
  19. import torch
  20. from torch import nn
  21. from ...activations import ACT2FN
  22. from ...generation import GenerationMixin
  23. from ...modeling_layers import GradientCheckpointingLayer
  24. from ...modeling_utils import PreTrainedModel
  25. from ...utils import (
  26. ModelOutput,
  27. auto_docstring,
  28. logging,
  29. )
  30. from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
  31. from .configuration_mamba2 import Mamba2Config
  32. logger = logging.get_logger(__name__)
  33. if is_mamba_2_ssm_available():
  34. from mamba_ssm.ops.triton.selective_state_update import selective_state_update
  35. from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
  36. else:
  37. mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined, selective_state_update = None, None, None
  38. if is_causal_conv1d_available():
  39. from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
  40. else:
  41. causal_conv1d_update, causal_conv1d_fn = None, None
  42. is_fast_path_available = all(
  43. (
  44. selective_state_update,
  45. mamba_chunk_scan_combined,
  46. mamba_split_conv1d_scan_combined,
  47. causal_conv1d_fn,
  48. causal_conv1d_update,
  49. )
  50. )
  51. # Helper methods for segment sum computation
  52. def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int):
  53. """
  54. Padding x tensor with `pad_size` on the seq_len dim (dim=1)
  55. Assumes that we only have tensors of either size 4 or 3
  56. """
  57. pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0)
  58. return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0)
  59. def reshape_into_chunks(input_tensor, pad_size, chunk_size):
  60. """
  61. Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and
  62. simultaneously splitting it into chunk sequences.
  63. Assumes that we only have tensors of either size 4 or 3
  64. """
  65. # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...]
  66. input_tensor = pad_tensor_by_size(input_tensor, pad_size)
  67. if len(input_tensor.shape) == 3:
  68. # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads]
  69. return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2])
  70. else:
  71. # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size]
  72. return input_tensor.reshape(
  73. input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3]
  74. )
  75. def segment_sum(input_tensor):
  76. """
  77. More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions.
  78. """
  79. chunk_size = input_tensor.size(-1)
  80. # 1. expand input tensor to have an additional dimension and repeat along that dimension
  81. # [..., chunk_size] -> [..., chunk_size, chunk_size]
  82. input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size)
  83. # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag
  84. mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1)
  85. input_tensor = input_tensor.masked_fill(~mask, 0)
  86. # 3. compute actual cumsum
  87. tensor_segsum = torch.cumsum(input_tensor, dim=-2)
  88. # 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time)
  89. mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0)
  90. tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf)
  91. return tensor_segsum
  92. def apply_mask_to_padding_states(hidden_states, attention_mask):
  93. """
  94. Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
  95. """
  96. if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
  97. dtype = hidden_states.dtype
  98. hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
  99. return hidden_states
  100. class Mamba2Cache:
  101. """
  102. Arguments:
  103. config: Mamba2Config
  104. batch_size: int
  105. dtype: torch.dtype
  106. device: torch.device
  107. Attributes:
  108. dtype: (`torch.dtype`):
  109. The default `dtype` used to initializing the cache.
  110. conv_kernel_size: (`int`):
  111. Model's convolution kernel size taken from config.
  112. n_groups: (`int`):
  113. Model's number of groups taken from the config - similar to tensor parallel in Transformer.
  114. state_size: (`int`):
  115. Model's SSM state size taken from config.
  116. num_heads: (`int`):
  117. The number of heads used in the linear attention / SSM.
  118. head_dim: (`int`):
  119. The respective dimension of the heads used in the linear attention / SSM.
  120. intermediate_size: (`int`):
  121. Model's intermediate_size based on (expand * hidden_dim) from config.
  122. conv_states: (`torch.Tensor`):
  123. A tensor of shape `[num_layers, batch_size, conv_kernel_size, intermediate_size + 2 * n_groups * state_size]` that holds convolutional states.
  124. ssm_states: (`torch.Tensor`):
  125. A tensor of shape `[num_layers, batch_size, num_heads, head_dim, state_size]` that holds ssm states.
  126. """
  127. def __init__(
  128. self, config: Mamba2Config, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None
  129. ):
  130. self.dtype = dtype
  131. self.conv_kernel_size = config.conv_kernel
  132. self.n_groups = config.n_groups
  133. self.state_size = config.state_size
  134. self.num_heads = config.num_heads
  135. self.head_dim = config.head_dim
  136. self.intermediate_size = int(config.expand * config.hidden_size)
  137. self.conv_states = torch.zeros(
  138. config.num_hidden_layers,
  139. batch_size,
  140. self.intermediate_size + 2 * self.n_groups * self.state_size,
  141. self.conv_kernel_size,
  142. device=device,
  143. dtype=dtype,
  144. )
  145. self.ssm_states = torch.zeros(
  146. config.num_hidden_layers,
  147. batch_size,
  148. self.num_heads,
  149. self.head_dim,
  150. self.state_size,
  151. device=device,
  152. dtype=dtype,
  153. )
  154. def update_conv_state(
  155. self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False
  156. ) -> torch.Tensor:
  157. if cache_init:
  158. self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device)
  159. else:
  160. self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1)
  161. self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device)
  162. return self.conv_states[layer_idx]
  163. def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
  164. self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device)
  165. return self.ssm_states[layer_idx]
  166. def reset(self):
  167. self.conv_states.zero_()
  168. self.ssm_states.zero_()
  169. class MambaRMSNormGated(torch.nn.Module):
  170. def __init__(self, hidden_size, eps=1e-6):
  171. super().__init__()
  172. self.weight = nn.Parameter(torch.ones(hidden_size))
  173. self.variance_epsilon = eps
  174. def forward(self, hidden_states, gate=None):
  175. input_dtype = hidden_states.dtype
  176. hidden_states = hidden_states.to(torch.float32)
  177. if gate is not None:
  178. hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
  179. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  180. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  181. return self.weight * hidden_states.to(input_dtype)
  182. class Mamba2Mixer(nn.Module):
  183. """
  184. Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
  185. A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
  186. ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
  187. and is why Mamba is called **selective** state spaces)
  188. """
  189. def __init__(self, config: Mamba2Config, layer_idx: int):
  190. super().__init__()
  191. self.num_heads = config.num_heads
  192. self.hidden_size = config.hidden_size
  193. self.ssm_state_size = config.state_size
  194. self.conv_kernel_size = config.conv_kernel
  195. self.intermediate_size = int(config.expand * self.hidden_size)
  196. self.time_step_rank = int(config.time_step_rank)
  197. self.layer_idx = layer_idx
  198. self.use_conv_bias = config.use_conv_bias
  199. self.activation = config.hidden_act
  200. self.act = ACT2FN[config.hidden_act]
  201. self.layer_norm_epsilon = config.layer_norm_epsilon
  202. self.rms_norm = config.rms_norm
  203. self.n_groups = config.n_groups
  204. self.head_dim = config.head_dim
  205. self.chunk_size = config.chunk_size
  206. self.time_step_limit = config.time_step_limit
  207. self.time_step_min = config.time_step_min
  208. self.time_step_max = config.time_step_max
  209. self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
  210. self.conv1d = nn.Conv1d(
  211. in_channels=self.conv_dim,
  212. out_channels=self.conv_dim,
  213. bias=config.use_conv_bias,
  214. kernel_size=config.conv_kernel,
  215. groups=self.conv_dim,
  216. padding=config.conv_kernel - 1,
  217. )
  218. # projection of the input hidden states
  219. projection_size = self.intermediate_size + self.conv_dim + self.num_heads
  220. self.in_proj = nn.Linear(
  221. self.hidden_size,
  222. projection_size,
  223. bias=config.use_bias,
  224. )
  225. # selective projection used to make dt, B and C input dependent
  226. # time step projection (discretization)
  227. # instantiate once and copy inv_dt in init_weights of PretrainedModel
  228. self.dt_bias = nn.Parameter(torch.ones(self.num_heads))
  229. # S4D real initialization. These are not discretized!
  230. # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
  231. A = torch.arange(1, self.num_heads + 1)
  232. self.A_log = nn.Parameter(torch.log(A))
  233. self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon)
  234. self.D = nn.Parameter(torch.ones(self.num_heads))
  235. self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
  236. self.use_bias = config.use_bias
  237. if not is_fast_path_available:
  238. logger.warning_once(
  239. "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
  240. " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and"
  241. " https://github.com/Dao-AILab/causal-conv1d"
  242. )
  243. def cuda_kernels_forward(
  244. self,
  245. hidden_states: torch.Tensor,
  246. cache_params: Optional[Mamba2Cache] = None,
  247. cache_position: Optional[torch.LongTensor] = None,
  248. attention_mask: Optional[torch.Tensor] = None,
  249. ):
  250. # 1. Gated MLP's linear projection
  251. hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
  252. projected_states = self.in_proj(hidden_states)
  253. # Set up dimensions for reshapes later
  254. batch_size, seq_len, _ = hidden_states.shape
  255. groups_time_state_size = self.n_groups * self.ssm_state_size
  256. d_mlp = (
  257. projected_states.shape[-1]
  258. - 2 * self.intermediate_size
  259. - 2 * self.n_groups * self.ssm_state_size
  260. - self.num_heads
  261. ) // 2
  262. # Single step calculations via cache
  263. if cache_params is not None and cache_position is not None and cache_position[0] > 0:
  264. _, _, gate, hidden_states_B_C, dt = projected_states.squeeze(1).split(
  265. [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
  266. )
  267. # 2. Convolution sequence transformation
  268. hidden_states_B_C = causal_conv1d_update(
  269. hidden_states_B_C,
  270. cache_params.conv_states[self.layer_idx],
  271. self.conv1d.weight.squeeze(1),
  272. self.conv1d.bias,
  273. self.activation,
  274. )
  275. hidden_states, B, C = torch.split(
  276. hidden_states_B_C,
  277. [self.intermediate_size, groups_time_state_size, groups_time_state_size],
  278. dim=-1,
  279. )
  280. # 3. SSM transformation
  281. A = -torch.exp(self.A_log.float()) # (nheads,)
  282. A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
  283. dt = dt[:, :, None].expand(-1, -1, self.head_dim)
  284. dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
  285. D = self.D[:, None, ...].expand(-1, self.head_dim)
  286. B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
  287. C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
  288. hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim)
  289. hidden_states = selective_state_update(
  290. cache_params.ssm_states[self.layer_idx],
  291. hidden_states_reshaped,
  292. dt,
  293. A,
  294. B,
  295. C,
  296. D,
  297. z=None,
  298. dt_bias=dt_bias,
  299. dt_softplus=True,
  300. )
  301. hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim)
  302. hidden_states = self.norm(hidden_states, gate)
  303. # 4. Final linear projection
  304. out = self.out_proj(hidden_states)[:, None, ...]
  305. # Fused calculations or step by step if no initialized cache is found
  306. else:
  307. A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size)
  308. dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit}
  309. # 2-4. Fused kernel for conv1d, SSM, and the final projection
  310. if self.training and cache_params is None:
  311. out = mamba_split_conv1d_scan_combined(
  312. projected_states,
  313. self.conv1d.weight.squeeze(1),
  314. self.conv1d.bias,
  315. self.dt_bias,
  316. A,
  317. D=self.D,
  318. chunk_size=self.chunk_size,
  319. seq_idx=None, # was seq_idx
  320. activation=self.activation,
  321. rmsnorm_weight=self.norm.weight,
  322. rmsnorm_eps=self.norm.variance_epsilon,
  323. outproj_weight=self.out_proj.weight,
  324. outproj_bias=self.out_proj.bias,
  325. headdim=self.head_dim,
  326. ngroups=self.n_groups,
  327. norm_before_gate=False,
  328. return_final_states=False,
  329. **dt_limit_kwargs,
  330. )
  331. else:
  332. _, _, gate, hidden_states_B_C, dt = projected_states.split(
  333. [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
  334. )
  335. # 2. Convolution sequence transformation
  336. # Init cache
  337. if cache_params is not None:
  338. hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
  339. conv_states = nn.functional.pad(
  340. hidden_states_B_C_transposed,
  341. (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0),
  342. )
  343. cache_params.update_conv_state(
  344. layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True
  345. )
  346. if self.activation not in ["silu", "swish"]:
  347. hidden_states_B_C = self.act(
  348. self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)
  349. )
  350. else:
  351. hidden_states_B_C = causal_conv1d_fn(
  352. x=hidden_states_B_C.transpose(1, 2),
  353. weight=self.conv1d.weight.squeeze(1),
  354. bias=self.conv1d.bias,
  355. activation=self.activation,
  356. ).transpose(1, 2)
  357. hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
  358. hidden_states, B, C = torch.split(
  359. hidden_states_B_C,
  360. [self.intermediate_size, groups_time_state_size, groups_time_state_size],
  361. dim=-1,
  362. )
  363. # 3. SSM transformation
  364. scan_output, ssm_state = mamba_chunk_scan_combined(
  365. hidden_states.view(batch_size, seq_len, -1, self.head_dim),
  366. dt,
  367. A,
  368. B.view(batch_size, seq_len, self.n_groups, -1),
  369. C.view(batch_size, seq_len, self.n_groups, -1),
  370. chunk_size=self.chunk_size,
  371. D=self.D,
  372. z=None,
  373. seq_idx=None,
  374. return_final_states=True,
  375. dt_bias=self.dt_bias,
  376. dt_softplus=True,
  377. **dt_limit_kwargs,
  378. )
  379. # Init cache
  380. if ssm_state is not None and cache_params is not None:
  381. cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state)
  382. scan_output = scan_output.view(batch_size, seq_len, -1)
  383. # Multiply "gate" branch and apply extra normalization layer
  384. scan_output = self.norm(scan_output, gate)
  385. # 4. Final linear projection
  386. out = self.out_proj(scan_output)
  387. return out
  388. # fmt: off
  389. def torch_forward(
  390. self,
  391. hidden_states: torch.Tensor,
  392. cache_params: Optional[Mamba2Cache]=None,
  393. cache_position:Optional[torch.LongTensor]=None,
  394. attention_mask: Optional[torch.Tensor]=None
  395. ):
  396. batch_size, seq_len, _ = hidden_states.shape
  397. dtype = hidden_states.dtype
  398. # 1. Gated MLP's linear projection
  399. hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
  400. projected_states = self.in_proj(hidden_states)
  401. d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size-self.num_heads) // 2
  402. _, _, gate, hidden_states_B_C, dt = projected_states.split(
  403. [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
  404. )
  405. # 2. Convolution sequence transformation
  406. if cache_params is not None and cache_position is not None and cache_position[0] > 0:
  407. cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=hidden_states_B_C, cache_init=False)
  408. # We need to guarantee that anything regarding the cache is on the same device
  409. conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device)
  410. hidden_states_B_C = torch.sum(
  411. conv_states * self.conv1d.weight.squeeze(1), dim=-1
  412. )
  413. if self.use_conv_bias:
  414. hidden_states_B_C = hidden_states_B_C + self.conv1d.bias
  415. hidden_states_B_C = self.act(hidden_states_B_C)
  416. else:
  417. # Init cache
  418. if cache_params is not None:
  419. hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
  420. conv_states = nn.functional.pad(
  421. hidden_states_B_C_transposed, (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0)
  422. )
  423. cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True)
  424. hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2))
  425. hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
  426. hidden_states, B, C = torch.split(
  427. hidden_states_B_C,
  428. [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size],
  429. dim=-1
  430. )
  431. # 3. SSM transformation
  432. A = -torch.exp(self.A_log.float()) # [num_heads]
  433. if cache_params is not None and cache_position is not None and cache_position[0] > 0:
  434. # We need to guarantee that anything regarding the cache is on the same device
  435. cache_device = cache_params.ssm_states.device
  436. # Note: there is no need to pad parameter matrices here, as there is just one new token
  437. # for batched generation
  438. dt = dt[:, 0, :][:, None, ...]
  439. dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim)
  440. # [num_heads] -> [num_heads, head_dim]
  441. dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim)
  442. dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype))
  443. dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
  444. A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
  445. # [bsz, num_heads, head_dim, state_size]
  446. dA = (torch.exp(dt[..., None] * A)).to(device=cache_device)
  447. # Discretize B
  448. # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] ->
  449. # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size]
  450. B = B.reshape(batch_size, self.n_groups, -1)[..., None, :]
  451. B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous()
  452. B = B.reshape(batch_size, -1, B.shape[-1])
  453. # [bsz, num_heads, head_dim, state_size]
  454. dB = dt[..., None] * B[..., None, :]
  455. # Discretize x into dB
  456. # [bsz, intermediate_size] -> [bsz, num_heads, head_dim]
  457. hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim)
  458. dBx = (dB * hidden_states[..., None]).to(device=cache_device)
  459. # State calculation
  460. cache_params.update_ssm_state(
  461. layer_idx=self.layer_idx,
  462. new_ssm_state=cache_params.ssm_states[self.layer_idx] * dA + dBx
  463. )
  464. # Subsequent output
  465. # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size]
  466. C = C.reshape(batch_size, self.n_groups, -1)[..., None, :]
  467. C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous()
  468. C = C.reshape(batch_size, -1, C.shape[-1])
  469. # [bsz, num_heads, head_dim]
  470. ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n]
  471. # Reshape ssm_states to merge the first two dimensions
  472. ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n]
  473. C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1]
  474. y = torch.bmm(ssm_states_reshaped, C_reshaped)
  475. y = y.view(batch_size, self.num_heads, self.head_dim)
  476. # D skip connection
  477. # [num_heads] -> [num_heads, head_dim]
  478. D = self.D[..., None].expand(self.D.shape[0], self.head_dim)
  479. y = (y + hidden_states * D).to(y.dtype)
  480. # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size]
  481. y = y.reshape(batch_size, -1)[:, None, ...]
  482. else:
  483. # begin ssd naive implementation without einsums
  484. dt = nn.functional.softplus(dt + self.dt_bias)
  485. dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
  486. hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
  487. B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
  488. C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
  489. B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
  490. C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
  491. pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size
  492. D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)
  493. # Discretize x and A
  494. hidden_states = hidden_states * dt[..., None]
  495. A = A.to(hidden_states.dtype) * dt
  496. # Rearrange into blocks/chunks
  497. hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)]
  498. # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size]
  499. A = A.permute(0, 3, 1, 2)
  500. A_cumsum = torch.cumsum(A, dim=-1)
  501. # 1. Compute the output for each intra-chunk (diagonal blocks)
  502. # This is the analog of a causal mask
  503. L = torch.exp(segment_sum(A))
  504. # Contraction of C and B to get G (attention-weights like)
  505. G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] # shape: (b, c, l, s, h, n)
  506. G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h)
  507. # Compute M, equivalent to applying attention mask to weights
  508. M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None]
  509. M = M_intermediate.sum(dim=-1)
  510. # Compute Y_diag (apply to values)
  511. Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3)
  512. # 2. Compute the state for each intra-chunk
  513. # (right term of low-rank factorization of off-diagonal blocks; B terms)
  514. decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
  515. B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None]
  516. states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2)
  517. # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
  518. # (middle term of factorization of off-diag blocks; A terms)
  519. if cache_params is not None and cache_position is not None and cache_position[0] > 0:
  520. previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device)
  521. else:
  522. previous_states = torch.zeros_like(states[:, :1])
  523. states = torch.cat([previous_states, states], dim=1)
  524. decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))
  525. decay_chunk = decay_chunk.transpose(1, 3)
  526. new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1)
  527. states, ssm_state = new_states[:, :-1], new_states[:, -1]
  528. # 4. Compute state -> output conversion per chunk
  529. # (left term of low-rank factorization of off-diagonal blocks; C terms)
  530. state_decay_out = torch.exp(A_cumsum)
  531. C_times_states = (C[..., None, :] * states[:, :, None, ...])
  532. state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1)
  533. Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None])
  534. # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
  535. y = Y_diag + Y_off
  536. # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim]
  537. y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)
  538. y = y + D_residual
  539. # Cutting off padded chunks
  540. if pad_size > 0:
  541. y = y[:, :seq_len, :, :]
  542. y = y.reshape(batch_size, seq_len, -1)
  543. # Init cache
  544. if ssm_state is not None and cache_params is not None:
  545. cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state)
  546. scan_output = self.norm(y, gate)
  547. # end ssd naive
  548. # 4. Final linear projection
  549. contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size]
  550. return contextualized_states
  551. # fmt: on
  552. def forward(
  553. self,
  554. hidden_states,
  555. cache_params: Optional[Mamba2Cache] = None,
  556. cache_position: Optional[torch.LongTensor] = None,
  557. attention_mask: Optional[torch.Tensor] = None,
  558. ):
  559. if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
  560. return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
  561. return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)
  562. class Mamba2RMSNorm(nn.Module):
  563. def __init__(self, hidden_size, eps=1e-6):
  564. """
  565. Mamba2RMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm
  566. """
  567. super().__init__()
  568. self.weight = nn.Parameter(torch.ones(hidden_size))
  569. self.variance_epsilon = eps
  570. def forward(self, hidden_states):
  571. input_dtype = hidden_states.dtype
  572. hidden_states = hidden_states.to(torch.float32)
  573. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  574. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  575. return self.weight * hidden_states.to(input_dtype)
  576. class Mamba2Block(GradientCheckpointingLayer):
  577. def __init__(self, config, layer_idx):
  578. super().__init__()
  579. self.config = config
  580. self.layer_idx = layer_idx
  581. self.residual_in_fp32 = config.residual_in_fp32
  582. self.norm = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
  583. self.mixer = Mamba2Mixer(config, layer_idx=layer_idx)
  584. def forward(
  585. self,
  586. hidden_states,
  587. cache_params: Optional[Mamba2Cache] = None,
  588. cache_position: Optional[torch.LongTensor] = None,
  589. attention_mask: Optional[torch.Tensor] = None,
  590. ):
  591. residual = hidden_states
  592. hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
  593. if self.residual_in_fp32:
  594. residual = residual.to(torch.float32)
  595. hidden_states = self.mixer(
  596. hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask
  597. )
  598. hidden_states = residual + hidden_states
  599. return hidden_states
  600. @auto_docstring
  601. class Mamba2PreTrainedModel(PreTrainedModel):
  602. config: Mamba2Config
  603. base_model_prefix = "backbone"
  604. _no_split_modules = ["Mamba2Block"]
  605. supports_gradient_checkpointing = True
  606. _is_stateful = True
  607. def _init_weights(self, module):
  608. """Initialize the weights."""
  609. std = self.config.initializer_range
  610. if isinstance(module, Mamba2Mixer):
  611. # S4D real initialization. These are not discretized!
  612. # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
  613. A = torch.arange(1, self.config.num_heads + 1)
  614. module.A_log.copy_(torch.log(A))
  615. module.D.data.fill_(1.0)
  616. dt = torch.exp(
  617. torch.rand(self.config.num_heads)
  618. * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
  619. + math.log(self.config.time_step_min)
  620. ).clamp(min=self.config.time_step_floor)
  621. # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
  622. inv_dt = dt + torch.log(-torch.expm1(-dt))
  623. module.dt_bias.copy_(inv_dt)
  624. module.dt_bias._no_reinit = True
  625. nn.init.kaiming_uniform_(module.conv1d.weight, a=math.sqrt(5))
  626. if module.conv1d.bias is not None:
  627. if not getattr(module.conv1d.bias, "_no_reinit", False):
  628. nn.init.zeros_(module.conv1d.bias)
  629. nn.init.kaiming_uniform_(module.out_proj.weight, a=math.sqrt(5))
  630. if self.config.rescale_prenorm_residual:
  631. # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
  632. # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
  633. # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
  634. # > -- GPT-2 :: https://openai.com/blog/better-language-models/
  635. #
  636. # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
  637. # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
  638. # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
  639. # We need to reinit p since this code could be called multiple times
  640. # Having just p *= scale would repeatedly scale it down
  641. p = module.out_proj.weight
  642. p /= math.sqrt(self.config.num_hidden_layers)
  643. if isinstance(module, nn.Linear):
  644. if not getattr(module.weight, "_no_reinit", False):
  645. nn.init.normal_(module.weight, std=std)
  646. if module.bias is not None:
  647. if not getattr(module.bias, "_no_reinit", False):
  648. nn.init.zeros_(module.bias)
  649. elif isinstance(module, (Mamba2RMSNorm, MambaRMSNormGated)):
  650. module.weight.data.fill_(1.0)
  651. elif isinstance(module, nn.Embedding):
  652. nn.init.normal_(module.weight, std=std)
@dataclass
@auto_docstring(
    custom_intro="""
    Class for the MAMBA2 model outputs.
    """
)
# Copied from transformers.models.mamba.modeling_mamba.MambaOutput with MAMBA->MAMBA2,Mamba->Mamba2
class Mamba2Output(ModelOutput):
    r"""
    cache_params (`Mamba2Cache`):
        The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
        avoid providing the old `input_ids`.

        Includes both the State space model state matrices after the selective scan, and the Convolutional states
    """

    # NOTE(review): this class is marked "Copied from" MambaOutput; documentation edits here presumably need to be
    # mirrored in the source class to keep the repository's copy-consistency check passing.
    # Result of the final `norm_f` RMSNorm, as produced by Mamba2Model.forward.
    last_hidden_state: Optional[torch.FloatTensor] = None
    # Populated only when the model ran with `use_cache`; otherwise None.
    cache_params: Optional[Mamba2Cache] = None
    # Populated only when `output_hidden_states` is enabled: the output of every block in order, followed by the
    # `norm_f`-normalized final output. The embedding output is not included.
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for causal language model (or autoregressive) outputs.
    """
)
# Copied from transformers.models.mamba.modeling_mamba.MambaCausalLMOutput with Mamba->Mamba2
class Mamba2CausalLMOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    cache_params (`Mamba2Cache`):
        The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
        avoid providing the old `input_ids`.

        Includes both the State space model state matrices after the selective scan, and the Convolutional states
    """

    # NOTE(review): this class is marked "Copied from" MambaCausalLMOutput; documentation edits here presumably
    # need to be mirrored in the source class to keep the repository's copy-consistency check passing.
    # Set by Mamba2ForCausalLM.forward only when `labels` are passed.
    loss: Optional[torch.FloatTensor] = None
    # Output of `lm_head`, upcast with `.float()` in Mamba2ForCausalLM.forward.
    logits: Optional[torch.FloatTensor] = None
    # Forwarded unchanged from the backbone's Mamba2Output.
    cache_params: Optional[Mamba2Cache] = None
    # Forwarded unchanged from the backbone's Mamba2Output (None unless `output_hidden_states` is enabled).
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
  692. @auto_docstring
  693. class Mamba2Model(Mamba2PreTrainedModel):
  694. def __init__(self, config):
  695. super().__init__(config)
  696. self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
  697. self.layers = nn.ModuleList([Mamba2Block(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])
  698. self.gradient_checkpointing = False
  699. self.norm_f = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
  700. # Initialize weights and apply final processing
  701. self._register_load_state_dict_pre_hook(self.load_hook)
  702. self.post_init()
  703. def load_hook(self, state_dict, prefix, *args):
  704. for k in state_dict:
  705. if "embedding." in k:
  706. state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k)
  707. break
  708. def get_input_embeddings(self):
  709. return self.embeddings
  710. def set_input_embeddings(self, new_embeddings):
  711. self.embeddings = new_embeddings
  712. @auto_docstring
  713. def forward(
  714. self,
  715. input_ids: Optional[torch.LongTensor] = None,
  716. inputs_embeds: Optional[torch.LongTensor] = None,
  717. cache_params: Optional[Mamba2Cache] = None,
  718. use_cache: Optional[bool] = None,
  719. output_hidden_states: Optional[bool] = None,
  720. return_dict: Optional[bool] = None,
  721. cache_position: Optional[torch.LongTensor] = None,
  722. attention_mask: Optional[torch.Tensor] = None,
  723. **kwargs,
  724. ) -> Union[tuple, Mamba2Output]:
  725. r"""
  726. cache_params (`Mamba2Cache`, *optional*):
  727. If passed along, the model uses the previous state in all the blocks (which will give the output for the
  728. `input_ids` provided as if the model add `state_input_ids + input_ids` as context).
  729. use_cache (`bool`, *optional*):
  730. If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
  731. cache_position (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  732. The position of the current input in the cache. This is used to ensure that the cache is correctly updated.
  733. If `cache_params` is passed, `cache_position` should also be passed.
  734. """
  735. output_hidden_states = (
  736. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  737. )
  738. use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
  739. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  740. if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor
  741. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  742. if inputs_embeds is None:
  743. inputs_embeds = self.embeddings(input_ids)
  744. if self.gradient_checkpointing and self.training and use_cache:
  745. use_cache = False
  746. if use_cache:
  747. if cache_params is None:
  748. cache_params = Mamba2Cache(
  749. self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
  750. )
  751. cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device)
  752. elif cache_position is None:
  753. # cases when we do manual forward instead of using `model.generate` which will initiate
  754. # `cache_position` and makes sure it is not None, throw error here instead of doing some
  755. # hack to conjecture the current cache position
  756. raise ValueError(
  757. "You have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, "
  758. "you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will "
  759. "be initialized for you automatically"
  760. )
  761. else:
  762. cache_params = None
  763. hidden_states = inputs_embeds
  764. all_hidden_states = () if output_hidden_states else None
  765. for mixer_block in self.layers:
  766. hidden_states = mixer_block(
  767. hidden_states,
  768. cache_params=cache_params,
  769. cache_position=cache_position,
  770. attention_mask=attention_mask,
  771. )
  772. if output_hidden_states:
  773. all_hidden_states = all_hidden_states + (hidden_states,)
  774. hidden_states = self.norm_f(hidden_states)
  775. if output_hidden_states:
  776. all_hidden_states = all_hidden_states + (hidden_states,)
  777. if not return_dict:
  778. return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)
  779. return Mamba2Output(
  780. last_hidden_state=hidden_states,
  781. cache_params=cache_params if use_cache else None,
  782. hidden_states=all_hidden_states,
  783. )
  784. @auto_docstring(
  785. custom_intro="""
  786. The MAMBA2 Model transformer with a language modeling head on top (linear layer with weights not tied to the input
  787. embeddings).
  788. """
  789. )
  790. class Mamba2ForCausalLM(Mamba2PreTrainedModel, GenerationMixin):
  791. _tied_weights_keys = []
  792. def __init__(self, config):
  793. super().__init__(config)
  794. self.backbone = Mamba2Model(config)
  795. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  796. # Initialize weights and apply final processing
  797. self.post_init()
  798. def get_input_embeddings(self):
  799. return self.backbone.get_input_embeddings()
  800. def set_input_embeddings(self, new_embeddings):
  801. return self.backbone.set_input_embeddings(new_embeddings)
  802. def prepare_inputs_for_generation(
  803. self,
  804. input_ids,
  805. inputs_embeds=None,
  806. use_cache=None,
  807. cache_params: Optional[Mamba2Cache] = None,
  808. cache_position: Optional[torch.LongTensor] = None,
  809. attention_mask: Optional[torch.Tensor] = None,
  810. **kwargs,
  811. ):
  812. # Overwritten -- uses `cache_params` as opposed to `past_key_values`
  813. model_inputs = {"input_ids": input_ids.contiguous()}
  814. if use_cache and cache_params is None:
  815. # we initialize the `cache_position` to full size of `conv_states` at prefill stage
  816. # considering padding will be applied when input length is shorter, and truncation
  817. # will be applied when it is longer, so it will be equivalent to always have it match
  818. # the length of `cache_params.conv_states`, which is `config.conv_kernel`
  819. cache_position = torch.arange(0, self.backbone.config.conv_kernel, device=input_ids.device)
  820. if inputs_embeds is not None:
  821. model_inputs = {"inputs_embeds": inputs_embeds}
  822. max_batch_size = inputs_embeds.size(0)
  823. else:
  824. max_batch_size = input_ids.size(0)
  825. cache_params = Mamba2Cache(self.backbone.config, max_batch_size, device=self.device, dtype=self.dtype)
  826. if use_cache and cache_position[0] > 0:
  827. model_inputs["input_ids"] = input_ids[:, -1].unsqueeze(-1).contiguous()
  828. attention_mask = None
  829. if not use_cache and inputs_embeds is not None:
  830. model_inputs = {"inputs_embeds": inputs_embeds}
  831. model_inputs.update(
  832. {
  833. "cache_params": cache_params,
  834. "use_cache": use_cache,
  835. "cache_position": cache_position,
  836. "attention_mask": attention_mask,
  837. }
  838. )
  839. # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
  840. for key, value in kwargs.items():
  841. if key not in model_inputs:
  842. model_inputs[key] = value
  843. return model_inputs
  844. @auto_docstring
  845. def forward(
  846. self,
  847. input_ids: Optional[torch.LongTensor] = None,
  848. inputs_embeds: Optional[torch.FloatTensor] = None,
  849. cache_params: Optional[Mamba2Cache] = None,
  850. labels: Optional[torch.LongTensor] = None,
  851. output_hidden_states: Optional[bool] = None,
  852. return_dict: Optional[bool] = None,
  853. use_cache: Optional[bool] = None,
  854. cache_position: Optional[torch.Tensor] = None,
  855. attention_mask: Optional[torch.Tensor] = None,
  856. **kwargs, # for now we need this for generation and loss_function
  857. ) -> Union[tuple, Mamba2CausalLMOutput]:
  858. r"""
  859. cache_params (`Mamba2Cache`, *optional*):
  860. If passed along, the model uses the previous state in all the blocks (which will give the output for the
  861. `input_ids` provided as if the model add `state_input_ids + input_ids` as context).
  862. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  863. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  864. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  865. are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  866. use_cache (`bool`, *optional*):
  867. If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
  868. cache_position (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  869. The position of the current input in the cache. This is used to ensure that the cache is correctly updated.
  870. If `cache_params` is passed, `cache_position` should also be passed.
  871. """
  872. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  873. mamba2_outputs = self.backbone(
  874. input_ids,
  875. cache_params=cache_params,
  876. inputs_embeds=inputs_embeds,
  877. output_hidden_states=output_hidden_states,
  878. return_dict=return_dict,
  879. use_cache=use_cache,
  880. cache_position=cache_position,
  881. attention_mask=attention_mask,
  882. )
  883. hidden_states = mamba2_outputs[0]
  884. logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float()
  885. loss = None
  886. if labels is not None:
  887. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  888. if not return_dict:
  889. output = (logits,) + mamba2_outputs[1:]
  890. return ((loss,) + output) if loss is not None else output
  891. return Mamba2CausalLMOutput(
  892. loss=loss,
  893. logits=logits,
  894. cache_params=mamba2_outputs.cache_params,
  895. hidden_states=mamba2_outputs.hidden_states,
  896. )
  897. __all__ = ["Mamba2ForCausalLM", "Mamba2Model", "Mamba2PreTrainedModel"]