modeling_mamba.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877
  1. # coding=utf-8
  2. # Copyright 2024 state-spaces/mamba org and HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch MAMBA model."""
  16. import math
  17. from dataclasses import dataclass
  18. from typing import Any, Optional, Union
  19. import torch
  20. from torch import nn
  21. from torch.nn import CrossEntropyLoss
  22. from ...activations import ACT2FN
  23. from ...configuration_utils import PretrainedConfig
  24. from ...generation import GenerationMixin
  25. from ...modeling_layers import GradientCheckpointingLayer
  26. from ...modeling_utils import PreTrainedModel
  27. from ...utils import (
  28. ModelOutput,
  29. auto_docstring,
  30. logging,
  31. )
  32. from ...utils.import_utils import (
  33. is_causal_conv1d_available,
  34. is_kernels_available,
  35. is_mamba_ssm_available,
  36. is_mambapy_available,
  37. )
  38. from .configuration_mamba import MambaConfig
  39. logger = logging.get_logger(__name__)
  40. if is_mambapy_available():
  41. from mambapy.pscan import pscan
  42. else:
  43. pscan = None
  44. if is_mamba_ssm_available():
  45. from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
  46. from mamba_ssm.ops.triton.selective_state_update import selective_state_update
  47. else:
  48. selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
  49. _causal_conv1d_cache = None
  50. def _lazy_load_causal_conv1d():
  51. global _causal_conv1d_cache
  52. if _causal_conv1d_cache is not None:
  53. return _causal_conv1d_cache
  54. if is_kernels_available():
  55. from kernels import get_kernel
  56. _causal_conv1d_kernel = get_kernel("kernels-community/causal-conv1d")
  57. _causal_conv1d_cache = (_causal_conv1d_kernel.causal_conv1d_update, _causal_conv1d_kernel.causal_conv1d_fn)
  58. elif is_causal_conv1d_available():
  59. from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
  60. _causal_conv1d_cache = (causal_conv1d_update, causal_conv1d_fn)
  61. else:
  62. _causal_conv1d_cache = (None, None)
  63. return _causal_conv1d_cache
  64. class MambaCache:
  65. """
  66. Cache for mamba model which does not have attention mechanism and key value states.
  67. Arguments:
  68. config (`PretrainedConfig):
  69. The configuration file defining the shape-related attributes required to initialize the static cache.
  70. max_batch_size (`int`):
  71. The maximum batch size with which the model will be used. Note that a new instance must be instantiated if
  72. a smaller batch size is used.
  73. dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
  74. The default `dtype` to use when initializing the layer.
  75. device (`torch.device` or `str`, *optional*):
  76. The device on which the cache should be initialized. Should be the same as the layer.
  77. Example:
  78. ```python
  79. >>> import torch
  80. >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache
  81. >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
  82. >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
  83. >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt")
  84. >>> # Prepare a cache class and pass it to model's forward
  85. >>> cache_params = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
  86. >>> cache_position = torch.arange(len(inputs["input_ids"][0]), device=model.device) # sequence length
  87. >>> outputs = model(**inputs, cache_params=cache_params, cache_position=cache_position, use_cache=True)
  88. >>> outputs.cache_params
  89. ```
  90. """
  91. is_compileable = True
  92. # TODO (joao): add layer_device_map arg and update code in `generate` accordingly
  93. def __init__(
  94. self,
  95. config: PretrainedConfig,
  96. max_batch_size: int,
  97. dtype: torch.dtype = torch.float16,
  98. device: Union[torch.device, str, None] = None,
  99. ):
  100. self.max_batch_size = max_batch_size
  101. self._dtype = dtype
  102. self.intermediate_size = config.intermediate_size
  103. self.ssm_state_size = config.state_size
  104. self.conv_kernel_size = config.conv_kernel
  105. self.conv_states: list[torch.Tensor] = []
  106. self.ssm_states: list[torch.Tensor] = []
  107. device = torch.device(device) if device is not None else None
  108. for _ in range(config.num_hidden_layers):
  109. conv_state: torch.Tensor = torch.zeros(
  110. self.max_batch_size,
  111. self.intermediate_size,
  112. self.conv_kernel_size,
  113. device=device,
  114. dtype=self._dtype,
  115. )
  116. ssm_state: torch.Tensor = torch.zeros(
  117. self.max_batch_size,
  118. self.intermediate_size,
  119. self.ssm_state_size,
  120. device=device,
  121. dtype=self._dtype,
  122. )
  123. torch._dynamo.mark_static_address(conv_state)
  124. torch._dynamo.mark_static_address(ssm_state)
  125. self.conv_states.append(conv_state)
  126. self.ssm_states.append(ssm_state)
  127. def update_conv_state(
  128. self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
  129. ) -> torch.Tensor:
  130. # This `if` blocks is only reached in multigpu and if `layer_device_map` is not passed. It is used
  131. # when the cache is initialized in the forward pass (e.g. Mamba)
  132. if self.conv_states[layer_idx].device != new_conv_state.device:
  133. self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device)
  134. conv_state = self.conv_states[layer_idx]
  135. cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
  136. conv_state = conv_state.roll(shifts=-1, dims=-1)
  137. conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype)
  138. self.conv_states[layer_idx].zero_()
  139. self.conv_states[layer_idx] += conv_state
  140. return self.conv_states[layer_idx]
  141. def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
  142. self.ssm_states[layer_idx].zero_()
  143. self.ssm_states[layer_idx] += new_ssm_state.to(self.ssm_states[layer_idx].device)
  144. return self.ssm_states[layer_idx]
  145. def reset(self):
  146. for layer_idx in range(len(self.conv_states)):
  147. # In-place ops prevent breaking the static address
  148. self.conv_states[layer_idx].zero_()
  149. self.ssm_states[layer_idx].zero_()
  150. class MambaMixer(nn.Module):
  151. """
  152. Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
  153. A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
  154. ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
  155. and is why Mamba is called **selective** state spaces)
  156. """
  157. def __init__(self, config: MambaConfig, layer_idx: int):
  158. super().__init__()
  159. self.config = config
  160. self.hidden_size = config.hidden_size
  161. self.ssm_state_size = config.state_size
  162. self.conv_kernel_size = config.conv_kernel
  163. self.intermediate_size = config.intermediate_size
  164. self.time_step_rank = int(config.time_step_rank)
  165. self.layer_idx = layer_idx
  166. self.use_conv_bias = config.use_conv_bias
  167. self.conv1d = nn.Conv1d(
  168. in_channels=self.intermediate_size,
  169. out_channels=self.intermediate_size,
  170. bias=config.use_conv_bias,
  171. kernel_size=config.conv_kernel,
  172. groups=self.intermediate_size,
  173. padding=config.conv_kernel - 1,
  174. )
  175. self.activation = config.hidden_act
  176. self.act = ACT2FN[config.hidden_act]
  177. self.use_mambapy = config.use_mambapy
  178. # projection of the input hidden states
  179. self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
  180. # selective projection used to make dt, B and C input dependent
  181. self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
  182. # time step projection (discretization)
  183. self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
  184. # S4D real initialization. These are not discretized!
  185. # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
  186. A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]
  187. A = A.expand(self.intermediate_size, -1).contiguous()
  188. self.A_log = nn.Parameter(torch.log(A))
  189. self.D = nn.Parameter(torch.ones(self.intermediate_size))
  190. self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
  191. self.use_bias = config.use_bias
  192. self.warn_slow_implementation()
  193. def warn_slow_implementation(self):
  194. causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
  195. is_fast_path_available = all(
  196. (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
  197. )
  198. if not is_fast_path_available:
  199. if self.use_mambapy:
  200. if is_mambapy_available():
  201. logger.warning_once(
  202. "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
  203. " is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation for mamba-ssm and"
  204. " install the kernels library using `pip install kernels` or https://github.com/Dao-AILab/causal-conv1d for causal-conv1d"
  205. )
  206. else:
  207. raise ImportError(
  208. "use_mambapy is set to True but the mambapy package is not installed. To install it follow https://github.com/alxndrTL/mamba.py."
  209. )
  210. else:
  211. logger.warning_once(
  212. "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
  213. " is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation for mamba-ssm and"
  214. " install the kernels library using `pip install kernels` or https://github.com/Dao-AILab/causal-conv1d for causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py."
  215. )
  216. def cuda_kernels_forward(
  217. self,
  218. hidden_states: torch.Tensor,
  219. cache_params: Optional[MambaCache] = None,
  220. cache_position: Optional[torch.LongTensor] = None,
  221. attention_mask: Optional[torch.LongTensor] = None,
  222. ):
  223. # 1. Gated MLP's linear projection
  224. projected_states = self.in_proj(hidden_states).transpose(1, 2)
  225. if self.training and cache_params is None: # Doesn't support outputting the states -> used for training
  226. contextualized_states = mamba_inner_fn(
  227. projected_states,
  228. self.conv1d.weight,
  229. self.conv1d.bias if self.use_conv_bias else None,
  230. self.x_proj.weight,
  231. self.dt_proj.weight,
  232. self.out_proj.weight,
  233. self.out_proj.bias.float() if self.use_bias else None,
  234. -torch.exp(self.A_log.float()),
  235. None, # input-dependent B
  236. None, # input-dependent C
  237. self.D.float(),
  238. delta_bias=self.dt_proj.bias.float(),
  239. delta_softplus=True,
  240. )
  241. else:
  242. causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
  243. hidden_states, gate = projected_states.chunk(2, dim=1)
  244. if attention_mask is not None:
  245. hidden_states = hidden_states * attention_mask.unsqueeze(1)
  246. # 2. Convolution sequence transformation
  247. conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
  248. if cache_params is not None and cache_position[0] > 0:
  249. hidden_states = causal_conv1d_update(
  250. hidden_states.squeeze(-1),
  251. cache_params.conv_states[self.layer_idx],
  252. conv_weights,
  253. self.conv1d.bias,
  254. self.activation,
  255. )
  256. hidden_states = hidden_states.unsqueeze(-1)
  257. else:
  258. if cache_params is not None:
  259. conv_states = nn.functional.pad(
  260. hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)
  261. )
  262. cache_params.update_conv_state(self.layer_idx, conv_states, cache_position)
  263. hidden_states = causal_conv1d_fn(
  264. hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
  265. )
  266. if attention_mask is not None:
  267. hidden_states = hidden_states * attention_mask.unsqueeze(1)
  268. # 3. State Space Model sequence transformation
  269. # 3.a. input varying initialization of time_step, B and C
  270. ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
  271. time_step, B, C = torch.split(
  272. ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
  273. )
  274. discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)
  275. A = -torch.exp(self.A_log.float())
  276. # 3.c perform the recurrence y ← SSM(A, B, C)(x)
  277. time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None
  278. if cache_params is not None and cache_position[0] > 0:
  279. scan_outputs = selective_state_update(
  280. cache_params.ssm_states[self.layer_idx],
  281. hidden_states[..., 0],
  282. discrete_time_step[..., 0],
  283. A,
  284. B[:, 0],
  285. C[:, 0],
  286. self.D,
  287. gate[..., 0],
  288. time_proj_bias,
  289. dt_softplus=True,
  290. ).unsqueeze(-1)
  291. else:
  292. scan_outputs, ssm_state = selective_scan_fn(
  293. hidden_states,
  294. discrete_time_step,
  295. A,
  296. B.transpose(1, 2),
  297. C.transpose(1, 2),
  298. self.D.float(),
  299. gate,
  300. time_proj_bias,
  301. delta_softplus=True,
  302. return_last_state=True,
  303. )
  304. if ssm_state is not None and cache_params is not None:
  305. cache_params.update_ssm_state(self.layer_idx, ssm_state)
  306. # 4. Final linear projection
  307. contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
  308. return contextualized_states
  309. # fmt: off
  310. def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.LongTensor] = None):
  311. batch_size, seq_len, _ = input_states.shape
  312. dtype = input_states.dtype
  313. # 1. Gated MLP's linear projection
  314. projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
  315. hidden_states, gate = projected_states.chunk(2, dim=1)
  316. if attention_mask is not None:
  317. hidden_states = hidden_states * attention_mask.unsqueeze(1)
  318. # 2. Convolution sequence transformation
  319. if cache_params is not None:
  320. ssm_state = cache_params.ssm_states[self.layer_idx].clone()
  321. ssm_state = ssm_state.to(hidden_states.device)
  322. # use `cache_position.shape[0]` to check whether we are in prefill
  323. # stage, it's equivalent to check `cache_position[0] == 0`, which
  324. # breaks dynamo fullgraph constraints
  325. if cache_position.shape[0] == self.conv_kernel_size:
  326. conv_state = nn.functional.pad(
  327. hidden_states,
  328. (self.conv_kernel_size - hidden_states.shape[-1], 0)
  329. )
  330. cache_params.update_conv_state(self.layer_idx, conv_state, cache_position)
  331. hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
  332. else:
  333. conv_state = cache_params.update_conv_state(self.layer_idx, hidden_states, cache_position)
  334. conv_state = conv_state.to(self.conv1d.weight.device)
  335. hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
  336. if self.use_conv_bias:
  337. hidden_states += self.conv1d.bias
  338. hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1) # [batch, intermediate_size, 1] : decoding
  339. else:
  340. ssm_state = torch.zeros(
  341. (batch_size, self.intermediate_size, self.ssm_state_size),
  342. device=hidden_states.device, dtype=dtype
  343. )
  344. hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
  345. if attention_mask is not None:
  346. hidden_states = hidden_states * attention_mask.unsqueeze(1)
  347. # 3. State Space Model sequence transformation
  348. # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
  349. ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
  350. time_step, B, C = torch.split(
  351. ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
  352. )
  353. discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size]
  354. discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) # [batch, intermediate_size, seq_len]
  355. # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
  356. A = -torch.exp(self.A_log.float()) # [intermediate_size, ssm_state_size]
  357. discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None]) # [batch, intermediate_size, seq_len, ssm_state_size]
  358. discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float() # [batch, intermediate_size, seq_len, ssm_state_size]
  359. deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
  360. # 3.c perform the recurrence y ← SSM(A, B, C)(x)
  361. if self.use_mambapy and self.training and cache_params is None:
  362. hs = pscan(discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2)) # [batch, seq_len, intermediate_size, ssm_state_size]
  363. scan_output = (hs @ C.unsqueeze(-1)).squeeze(3).transpose(1, 2) # [batch, intermediate_size, seq_len]
  364. scan_output = scan_output + hidden_states * self.D[None, :, None]
  365. scan_output = scan_output * self.act(gate)
  366. else:
  367. scan_outputs = []
  368. for i in range(seq_len):
  369. ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediate_size, ssm_state]
  370. scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediate_size, 1]
  371. scan_outputs.append(scan_output[:, :, 0])
  372. scan_output = torch.stack(scan_outputs, dim=-1) # [batch, intermediate_size, seq_len]
  373. scan_output = scan_output + (hidden_states * self.D[None, :, None])
  374. scan_output = (scan_output * self.act(gate))
  375. if cache_params is not None:
  376. cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
  377. # 4. Final linear projection
  378. contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size]
  379. return contextualized_states
  380. # fmt: on
  381. def forward(
  382. self,
  383. hidden_states,
  384. cache_params: Optional[MambaCache] = None,
  385. cache_position: Optional[torch.LongTensor] = None,
  386. attention_mask: Optional[torch.LongTensor] = None,
  387. ):
  388. causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
  389. is_fast_path_available = all(
  390. (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
  391. )
  392. if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling():
  393. return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
  394. return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)
  395. class MambaRMSNorm(nn.Module):
  396. def __init__(self, hidden_size, eps=1e-6):
  397. """
  398. MambaRMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm
  399. """
  400. super().__init__()
  401. self.weight = nn.Parameter(torch.ones(hidden_size))
  402. self.variance_epsilon = eps
  403. def forward(self, hidden_states):
  404. input_dtype = hidden_states.dtype
  405. hidden_states = hidden_states.to(torch.float32)
  406. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  407. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  408. return self.weight * hidden_states.to(input_dtype)
  409. def extra_repr(self):
  410. return f"{self.weight.shape[0]}, eps={self.variance_epsilon}"
  411. class MambaBlock(GradientCheckpointingLayer):
  412. def __init__(self, config, layer_idx):
  413. super().__init__()
  414. self.config = config
  415. self.layer_idx = layer_idx
  416. self.residual_in_fp32 = config.residual_in_fp32
  417. self.norm = MambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
  418. self.mixer = MambaMixer(config, layer_idx=layer_idx)
  419. def forward(
  420. self,
  421. hidden_states,
  422. cache_params: Optional[MambaCache] = None,
  423. cache_position: Optional[torch.LongTensor] = None,
  424. attention_mask: Optional[torch.LongTensor] = None,
  425. ):
  426. residual = hidden_states
  427. hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
  428. if self.residual_in_fp32:
  429. residual = residual.to(torch.float32)
  430. hidden_states = self.mixer(
  431. hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask
  432. )
  433. hidden_states = residual + hidden_states
  434. return hidden_states
  435. @auto_docstring
  436. class MambaPreTrainedModel(PreTrainedModel):
  437. config: MambaConfig
  438. base_model_prefix = "backbone"
  439. _no_split_modules = ["MambaBlock", "MambaMixer"]
  440. supports_gradient_checkpointing = True
  441. _is_stateful = True
  442. def _init_weights(self, module):
  443. """Initialize the weights."""
  444. std = self.config.initializer_range
  445. if isinstance(module, MambaMixer):
  446. # S4D real initialization. These are not discretized!
  447. # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
  448. A = torch.arange(1, module.ssm_state_size + 1, dtype=torch.float32)[None, :]
  449. A = A.expand(module.intermediate_size, -1).contiguous()
  450. module.A_log.copy_(torch.log(A))
  451. module.D.data.fill_(1.0)
  452. dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale
  453. if self.config.time_step_init_scheme == "constant":
  454. nn.init.constant_(module.dt_proj.weight, dt_init_std)
  455. elif self.config.time_step_init_scheme == "random":
  456. nn.init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std)
  457. dt = torch.exp(
  458. torch.rand(self.config.intermediate_size)
  459. * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
  460. + math.log(self.config.time_step_min)
  461. ).clamp(min=self.config.time_step_floor)
  462. # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
  463. inv_dt = dt + torch.log(-torch.expm1(-dt))
  464. module.dt_proj.bias.copy_(inv_dt)
  465. module.dt_proj.bias._no_reinit = True
  466. nn.init.kaiming_uniform_(module.conv1d.weight, a=math.sqrt(5))
  467. if module.conv1d.bias is not None:
  468. if not getattr(module.conv1d.bias, "_no_reinit", False):
  469. nn.init.zeros_(module.conv1d.bias)
  470. nn.init.kaiming_uniform_(module.out_proj.weight, a=math.sqrt(5))
  471. if self.config.rescale_prenorm_residual:
  472. # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
  473. # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
  474. # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
  475. # > -- GPT-2 :: https://openai.com/blog/better-language-models/
  476. #
  477. # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
  478. # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
  479. # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
  480. # We need to reinit p since this code could be called multiple times
  481. # Having just p *= scale would repeatedly scale it down
  482. p = module.out_proj.weight
  483. p /= math.sqrt(self.config.num_hidden_layers)
  484. if isinstance(module, nn.Linear):
  485. if not getattr(module.weight, "_no_reinit", False):
  486. nn.init.normal_(module.weight, std=std)
  487. if module.bias is not None:
  488. if not getattr(module.bias, "_no_reinit", False):
  489. nn.init.zeros_(module.bias)
  490. elif isinstance(module, MambaRMSNorm):
  491. module.weight.data.fill_(1.0)
  492. elif isinstance(module, nn.Embedding):
  493. nn.init.normal_(module.weight, std=std)
@dataclass
@auto_docstring(
    custom_intro="""
    Class for the MAMBA model outputs.
    """
)
class MambaOutput(ModelOutput):
    r"""
    cache_params (`MambaCache`):
        The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
        avoid providing the old `input_ids`.

        Includes both the State space model state matrices after the selective scan, and the Convolutional states.
    """

    # Final hidden states of the backbone.
    last_hidden_state: Optional[torch.FloatTensor] = None
    # Recurrent cache (conv + SSM states); None when caching is disabled.
    cache_params: Optional[MambaCache] = None
    # Optional tuple of intermediate hidden states, when requested.
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for causal language model (or autoregressive) outputs.
    """
)
class MambaCausalLMOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    cache_params (`MambaCache`):
        The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
        avoid providing the old `input_ids`.

        Includes both the State space model state matrices after the selective scan, and the Convolutional states.
    """

    # Language-modeling loss; None when no labels are given.
    loss: Optional[torch.FloatTensor] = None
    # Pre-softmax vocabulary scores.
    logits: Optional[torch.FloatTensor] = None
    # Recurrent cache (conv + SSM states) for incremental decoding.
    cache_params: Optional[MambaCache] = None
    # Optional tuple of intermediate hidden states, when requested.
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
  531. @auto_docstring
  532. class MambaModel(MambaPreTrainedModel):
  533. def __init__(self, config):
  534. super().__init__(config)
  535. self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
  536. self.layers = nn.ModuleList([MambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])
  537. self.gradient_checkpointing = False
  538. self.norm_f = MambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
  539. # Initialize weights and apply final processing
  540. self._register_load_state_dict_pre_hook(self.load_hook)
  541. self.post_init()
  542. def load_hook(self, state_dict, prefix, *args):
  543. for k in state_dict:
  544. if "embedding." in k:
  545. state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k)
  546. break
  547. def get_input_embeddings(self):
  548. return self.embeddings
  549. def set_input_embeddings(self, new_embeddings):
  550. self.embeddings = new_embeddings
  @auto_docstring
  def forward(
      self,
      input_ids: Optional[torch.LongTensor] = None,
      inputs_embeds: Optional[torch.FloatTensor] = None,
      cache_params: Optional[MambaCache] = None,
      use_cache: Optional[bool] = None,
      output_hidden_states: Optional[bool] = None,
      return_dict: Optional[bool] = None,
      cache_position: Optional[torch.LongTensor] = None,
      attention_mask: Optional[torch.LongTensor] = None,
  ) -> Union[tuple, MambaOutput]:
      r"""
      cache_params (`MambaCache`, *optional*):
          If passed along, the model uses the previous state in all the blocks (which will give the output for the
          `input_ids` provided as if the model had `state_input_ids + input_ids` as context).
      use_cache (`bool`, *optional*):
          If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
      """
      # Resolve unset flags from the config; caching defaults to off while the module is in training mode.
      output_hidden_states = (
          output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
      )
      use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
      return_dict = return_dict if return_dict is not None else self.config.use_return_dict

      if (input_ids is None) ^ (inputs_embeds is not None):  # ^ is python for xor
          raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

      if inputs_embeds is None:
          inputs_embeds = self.embeddings(input_ids)

      # Caching is switched off whenever gradient checkpointing is active during training.
      if self.gradient_checkpointing and self.training and use_cache:
          use_cache = False

      if use_cache:
          if cache_params is None:
              # Prefill: allocate a fresh cache for this batch and make `cache_position` span
              # `config.conv_kernel` positions.
              cache_params = MambaCache(
                  self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
              )
              cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device)
          elif cache_position is None:
              # cases when we do manual forward instead of using `model.generate` which will initiate
              # `cache_position` and makes sure it is not None, throw error here instead of doing some
              # hack to conjecture the current cache position
              raise ValueError(
                  "You have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, "
                  "you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will "
                  "be initialized for you automatically"
              )
      else:
          # Any caller-supplied cache is ignored when caching is disabled.
          cache_params = None

      hidden_states = inputs_embeds
      all_hidden_states = () if output_hidden_states else None
      for mixer_block in self.layers:
          hidden_states = mixer_block(
              hidden_states,
              cache_params=cache_params,
              cache_position=cache_position,
              attention_mask=attention_mask,
          )
          if output_hidden_states:
              all_hidden_states = all_hidden_states + (hidden_states,)

      hidden_states = self.norm_f(hidden_states)
      if output_hidden_states:
          # The normalized final output is appended after the per-layer outputs.
          all_hidden_states = all_hidden_states + (hidden_states,)

      if not return_dict:
          return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)

      # `cache_params` is already None whenever `use_cache` is False (see the `else` branch above).
      return MambaOutput(
          last_hidden_state=hidden_states,
          cache_params=cache_params,
          hidden_states=all_hidden_states,
      )
  @auto_docstring(
      custom_intro="""
      The MAMBA Model transformer with a language modeling head on top (linear layer with weights tied to the input
      embeddings).
      """
  )
  class MambaForCausalLM(MambaPreTrainedModel, GenerationMixin):
      # The LM head weight is declared as tied (to the input embeddings, per the class intro above).
      _tied_weights_keys = ["lm_head.weight"]

      def __init__(self, config):
          """Build the Mamba backbone and a bias-free linear head mapping `hidden_size` to `vocab_size`."""
          super().__init__(config)
          self.backbone = MambaModel(config)
          self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self):
          """Return the token embedding module of the wrapped backbone."""
          return self.backbone.get_input_embeddings()

      def set_input_embeddings(self, new_embeddings):
          """Replace the token embedding module of the wrapped backbone."""
          return self.backbone.set_input_embeddings(new_embeddings)

      def _update_model_kwargs_for_generation(
          self, outputs: ModelOutput, model_kwargs: dict[str, Any], num_new_tokens: int = 1, **kwargs
      ) -> dict[str, Any]:
          """Carry the recurrent cache into the next generation step and advance the step bookkeeping.

          Args:
              outputs: Output of the previous forward pass; its `cache_params` (if any) is reused.
              model_kwargs: Generation keyword arguments; updated in place.
              num_new_tokens: Number of tokens produced by the previous step; added to the last cache position.

          Returns:
              The updated `model_kwargs`.
          """
          model_kwargs["cache_params"] = outputs.get("cache_params", None)
          if (
              model_kwargs.get("use_cache", True)
              and "cache_position" in model_kwargs
              and model_kwargs["cache_position"] is not None
          ):
              # Keep only the last position and move it forward by the number of newly generated tokens.
              model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens

          # An explicitly-None mask has nothing to extend (reading `.shape` on it would raise).
          attention_mask = model_kwargs.get("attention_mask")
          if attention_mask is not None:
              # Append a column of ones for the token generated in this step.
              model_kwargs["attention_mask"] = torch.cat(
                  [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
              )
          return model_kwargs

      def prepare_inputs_for_generation(
          self,
          input_ids,
          inputs_embeds=None,
          use_cache=None,
          cache_params: Optional[MambaCache] = None,
          cache_position: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.LongTensor] = None,
          **kwargs,
      ):
          """Assemble the keyword arguments for one forward pass during generation.

          When `use_cache` is set and no cache exists yet (prefill), a new `MambaCache` is allocated and
          `cache_position` is set to `arange(config.conv_kernel)`. On later cached steps (`cache_position[0] > 0`)
          only the last token is fed and `attention_mask` is dropped. Extra `kwargs` not already present in the
          result are forwarded unchanged.
          """
          # Overwritten -- uses `cache_params` as opposed to `past_key_values`
          model_inputs = {"input_ids": input_ids.contiguous()}
          if use_cache and cache_params is None:
              # we initialize the `cache_position` to full size of `conv_states` at prefill stage
              # considering padding will be applied when input length is shorter, and truncation
              # will be applied when it is longer, so it will be equivalent to always have it match
              # the length of `cache_params.conv_states`, which is `config.conv_kernel`
              cache_position = torch.arange(0, self.backbone.config.conv_kernel, device=input_ids.device)
              if inputs_embeds is not None:
                  model_inputs = {"inputs_embeds": inputs_embeds}
                  max_batch_size = inputs_embeds.size(0)
              else:
                  max_batch_size = input_ids.size(0)
              cache_params = MambaCache(self.backbone.config, max_batch_size, device=self.device, dtype=self.dtype)

          # Decoding step: the cache already holds the prefix, so only the newest token is needed.
          # A missing `cache_position` is let through so `MambaModel.forward` raises its descriptive
          # ValueError instead of failing here on `None[0]`.
          if use_cache and cache_position is not None and cache_position[0] > 0:
              model_inputs["input_ids"] = input_ids[:, -1].unsqueeze(-1).contiguous()
              attention_mask = None

          if not use_cache and inputs_embeds is not None:
              model_inputs = {"inputs_embeds": inputs_embeds}

          model_inputs.update(
              {
                  "cache_params": cache_params,
                  "use_cache": use_cache,
                  "cache_position": cache_position,
                  "attention_mask": attention_mask,
              }
          )

          # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
          for key, value in kwargs.items():
              if key not in model_inputs:
                  model_inputs[key] = value
          return model_inputs

      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.LongTensor] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          cache_params: Optional[MambaCache] = None,
          labels: Optional[torch.LongTensor] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
          use_cache: Optional[bool] = None,
          cache_position: Optional[torch.Tensor] = None,
          **kwargs,  # for now we need this for generation
      ) -> Union[tuple, MambaCausalLMOutput]:
          r"""
          cache_params (`MambaCache`, *optional*):
              If passed along, the model uses the previous state in all the blocks (which will give the output for the
              `input_ids` provided as if the model had `state_input_ids + input_ids` as context).
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
              `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
              are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`.
          use_cache (`bool`, *optional*):
              If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
          """
          # Only `return_dict` is resolved here; the backbone resolves the remaining flags from the config.
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          mamba_outputs = self.backbone(
              input_ids,
              cache_params=cache_params,
              inputs_embeds=inputs_embeds,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
              use_cache=use_cache,
              cache_position=cache_position,
              attention_mask=attention_mask,
          )
          hidden_states = mamba_outputs[0]

          # Project in the head's dtype, then upcast the logits to float32.
          logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float()

          loss = None
          if labels is not None:
              # move labels to correct device to enable model parallelism
              labels = labels.to(logits.device)
              # Shift so that tokens < n predict n
              shift_logits = logits[..., :-1, :].contiguous()
              shift_labels = labels[..., 1:].contiguous()
              # Flatten the tokens; CrossEntropyLoss ignores targets equal to -100 by default.
              loss_fct = CrossEntropyLoss()
              loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

          if not return_dict:
              output = (logits,) + mamba_outputs[1:]
              return ((loss,) + output) if loss is not None else output

          return MambaCausalLMOutput(
              loss=loss,
              logits=logits,
              cache_params=mamba_outputs.cache_params,
              hidden_states=mamba_outputs.hidden_states,
          )
  # Public names exported by `from <module> import *`.
  __all__ = [
      "MambaForCausalLM",
      "MambaModel",
      "MambaPreTrainedModel",
      "MambaCache",
  ]