modular_zamba2.py 55 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152
  1. # coding=utf-8
  2. # Copyright 2024 Zyphra Technologies and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import math
  17. import re
  18. from itertools import cycle
  19. from typing import Callable, Optional, Union
  20. import torch
  21. from torch import nn
  22. from ...activations import ACT2FN
  23. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  24. from ...modeling_outputs import BaseModelOutputWithPast
  25. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  26. from ...processing_utils import Unpack
  27. from ...utils import (
  28. logging,
  29. )
  30. from ...utils.deprecation import deprecate_kwarg
  31. from ...utils.import_utils import (
  32. is_causal_conv1d_available,
  33. is_mamba_ssm_available,
  34. )
  35. from ..llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
  36. from ..mamba2.modeling_mamba2 import pad_tensor_by_size, reshape_into_chunks, segment_sum
  37. from ..zamba.modeling_zamba import (
  38. ZambaAttention,
  39. ZambaAttentionDecoderLayer,
  40. ZambaForCausalLM,
  41. ZambaForSequenceClassification,
  42. ZambaHybridDynamicCache,
  43. ZambaHybridLayer,
  44. ZambaMambaDecoderLayer,
  45. ZambaModel,
  46. ZambaRMSNorm,
  47. eager_attention_forward,
  48. )
  49. from .configuration_zamba2 import Zamba2Config
  50. if is_mamba_ssm_available():
  51. from mamba_ssm.ops.triton.selective_state_update import selective_state_update
  52. from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
  53. else:
  54. selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined = None, None, None
  55. if is_causal_conv1d_available():
  56. from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
  57. else:
  58. causal_conv1d_update, causal_conv1d_fn = None, None
  59. is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
  60. _CONFIG_FOR_DOC = "Zyphra/Zamba2-2.7B"
  61. logger = logging.get_logger(__name__)
  62. class Zamba2RMSNormGated(torch.nn.Module):
  63. def __init__(self, hidden_size, group_size, eps=1e-6):
  64. super().__init__()
  65. self.weight = nn.Parameter(torch.ones(hidden_size))
  66. self.variance_epsilon = eps
  67. self.group_size = group_size
  68. def forward(self, hidden_states, gate=None):
  69. input_dtype = hidden_states.dtype
  70. hidden_states = hidden_states.to(torch.float32)
  71. if gate is not None:
  72. hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
  73. *prefix_dims, last_dim = hidden_states.shape
  74. group_count = last_dim // self.group_size
  75. hidden_states_group = hidden_states.view(*prefix_dims, group_count, self.group_size)
  76. variance = hidden_states_group.pow(2).mean(-1, keepdim=True)
  77. hidden_states_group = hidden_states_group * torch.rsqrt(variance + self.variance_epsilon)
  78. hidden_states = hidden_states_group.view(*prefix_dims, group_count * self.group_size)
  79. return self.weight * hidden_states.to(input_dtype)
  80. class Zamba2RMSNorm(ZambaRMSNorm):
  81. pass
  82. class Zamba2HybridDynamicCache(ZambaHybridDynamicCache):
  83. """
  84. A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
  85. (which has a constant shape regardless of seq_len).
  86. This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
  87. and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor
  88. For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
  89. while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
  90. For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
  91. while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`,
  92. and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`.
  93. """
  94. def __init__(
  95. self, config: Zamba2Config, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None
  96. ):
  97. self.dtype = dtype
  98. self.layers_block_type = config.layers_block_type
  99. self.has_previous_state = False
  100. self.intermediate_size = int(config.mamba_expand * config.hidden_size)
  101. self.ssm_state_size = config.mamba_d_state
  102. self.conv_kernel_size = config.mamba_d_conv
  103. self.n_mamba_heads = config.n_mamba_heads
  104. self.transformer_layers = []
  105. self._modules = {}
  106. self._parameters = {}
  107. self._buffers = {}
  108. self.conv_states = {}
  109. self.ssm_states = {}
  110. for i in range(config.num_hidden_layers):
  111. self.conv_states[i] = torch.zeros(
  112. batch_size,
  113. self.intermediate_size + 2 * config.mamba_ngroups * config.mamba_d_state,
  114. self.conv_kernel_size,
  115. device=device,
  116. dtype=dtype,
  117. )
  118. self.ssm_states[i] = torch.zeros(
  119. batch_size, self.n_mamba_heads, config.mamba_headdim, self.ssm_state_size, device=device, dtype=dtype
  120. )
  121. if self.layers_block_type[i] == "hybrid":
  122. self.transformer_layers.append(i)
  123. self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
  124. self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
  125. def update_conv_state(
  126. self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
  127. ) -> torch.Tensor:
  128. conv_state = self.conv_states[layer_idx]
  129. cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
  130. conv_state = conv_state.roll(shifts=-1, dims=-1)
  131. conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
  132. self.conv_states[layer_idx].zero_()
  133. self.conv_states[layer_idx] += conv_state
  134. return self.conv_states[layer_idx]
  135. def reset(self):
  136. self.conv_states.zero_()
  137. self.ssm_states.zero_()
  138. def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
  139. """Returns the sequence length of the cached states. A layer index can be optionally passed."""
  140. # take any layer that contains cache and not empty tensor
  141. layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx
  142. if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].numel() == 0:
  143. return 0
  144. return self.key_cache[layer_idx].shape[-2]
  145. class Zamba2RotaryEmbedding(LlamaRotaryEmbedding):
  146. pass
  147. class Zamba2Attention(ZambaAttention):
  148. """
  149. Multi-headed attention from 'Attention Is All You Need' paper.
  150. Adapted from transformers.models.mistral.modeling_mistral.MistralAttention:
  151. The input dimension here is attention_hidden_size = 2 * hidden_size, and head_dim = attention_hidden_size // num_heads.
  152. The extra factor of 2 comes from the input being the concatenation of original_hidden_states with the output of the previous (mamba) layer
  153. (see fig. 2 in https://huggingface.co/papers/2405.16712).
  154. Additionally, replaced
  155. attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) with
  156. attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim/2)
  157. Finally, this attention layer contributes to tied transformer blocks aimed to increasing compute without increasing model size. Because this
  158. layer is tied, un-tied adapters (formally the same as LoRA but used in the base model) modules are added to the q, k, v projectors to increase
  159. expressivity with a small memory overhead (see Fig. 2 of https://huggingface.co/papers/2411.15242).
  160. """
  161. def __init__(
  162. self,
  163. config: Zamba2Config,
  164. layer_idx: Optional[int] = None,
  165. num_fwd_mem_blocks: Optional[int] = None,
  166. block_id: Optional[int] = None,
  167. ):
  168. super().__init__(config, layer_idx)
  169. self.num_fwd_mem_blocks = num_fwd_mem_blocks
  170. self.layer_block_map = config.hybrid_layer_ids
  171. self.block_id = block_id
  172. if config.use_shared_attention_adapter:
  173. self.linear_q_adapter_list = nn.ModuleList([])
  174. self.linear_k_adapter_list = nn.ModuleList([])
  175. self.linear_v_adapter_list = nn.ModuleList([])
  176. for i in range(self.num_fwd_mem_blocks):
  177. if i % config.num_mem_blocks == block_id:
  178. linear_q_adapter = nn.Sequential(
  179. nn.Linear(self.attention_hidden_size, self.config.adapter_rank, bias=False),
  180. nn.Linear(self.config.adapter_rank, self.attention_hidden_size, bias=False),
  181. )
  182. linear_k_adapter = nn.Sequential(
  183. nn.Linear(self.attention_hidden_size, self.config.adapter_rank, bias=False),
  184. nn.Linear(self.config.adapter_rank, self.attention_hidden_size, bias=False),
  185. )
  186. linear_v_adapter = nn.Sequential(
  187. nn.Linear(self.attention_hidden_size, self.config.adapter_rank, bias=False),
  188. nn.Linear(self.config.adapter_rank, self.attention_hidden_size, bias=False),
  189. )
  190. else:
  191. linear_q_adapter = nn.Identity()
  192. linear_k_adapter = nn.Identity()
  193. linear_v_adapter = nn.Identity()
  194. self.linear_q_adapter_list.append(linear_q_adapter)
  195. self.linear_k_adapter_list.append(linear_k_adapter)
  196. self.linear_v_adapter_list.append(linear_v_adapter)
  197. self.layer_dic = {value: index for index, value in enumerate(self.layer_block_map)}
  198. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  199. def forward(
  200. self,
  201. hidden_states: torch.Tensor,
  202. layer_idx: int,
  203. attention_mask: Optional[torch.Tensor] = None,
  204. past_key_values: Optional[Zamba2HybridDynamicCache] = None,
  205. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
  206. **kwargs: Unpack[FlashAttentionKwargs],
  207. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  208. input_shape = hidden_states.shape[:-1]
  209. hidden_shape = (*input_shape, -1, self.head_dim)
  210. query_states = self.q_proj(hidden_states)
  211. key_states = self.k_proj(hidden_states)
  212. value_states = self.v_proj(hidden_states)
  213. if self.config.use_shared_attention_adapter:
  214. adapter_layer_idx = self.layer_dic[layer_idx]
  215. query_states = query_states + self.linear_q_adapter_list[adapter_layer_idx](hidden_states)
  216. key_states = key_states + self.linear_k_adapter_list[adapter_layer_idx](hidden_states)
  217. value_states = value_states + self.linear_v_adapter_list[adapter_layer_idx](hidden_states)
  218. query_states = query_states.view(hidden_shape).transpose(1, 2)
  219. key_states = key_states.view(hidden_shape).transpose(1, 2)
  220. value_states = value_states.view(hidden_shape).transpose(1, 2)
  221. if self.config.use_mem_rope:
  222. cos, sin = position_embeddings
  223. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  224. if past_key_values is not None:
  225. key_states, value_states = past_key_values.update(key_states, value_states, layer_idx)
  226. attention_interface: Callable = eager_attention_forward
  227. if self.config._attn_implementation != "eager":
  228. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  229. attn_output, attn_weights = attention_interface(
  230. self,
  231. query_states,
  232. key_states,
  233. value_states,
  234. attention_mask,
  235. dropout=0.0 if not self.training else self.attention_dropout,
  236. scaling=self.scaling,
  237. **kwargs,
  238. )
  239. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  240. attn_output = self.o_proj(attn_output)
  241. return attn_output, attn_weights
  242. class Zamba2MambaMixer(nn.Module):
  243. """
  244. Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
  245. A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
  246. ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
  247. and is why Mamba is called **selective** state spaces)
  248. """
  249. def __init__(self, config: Zamba2Config, layer_idx: Optional[int] = None):
  250. super().__init__()
  251. self.config = config
  252. self.hidden_size = config.hidden_size
  253. self.ssm_state_size = config.mamba_d_state
  254. self.conv_kernel_size = config.mamba_d_conv
  255. self.intermediate_size = int(config.mamba_expand * self.hidden_size)
  256. self.layer_idx = layer_idx
  257. self.use_conv_bias = config.use_conv_bias
  258. self.activation = "silu"
  259. self.act = nn.SiLU()
  260. self.use_mem_eff_path = config.use_mem_eff_path
  261. self.n_groups = config.mamba_ngroups
  262. self.head_dim = config.mamba_headdim
  263. self.num_heads = self.config.n_mamba_heads
  264. self.chunk_size = config.chunk_size
  265. self.time_step_limit = config.time_step_limit
  266. self.time_step_min = config.time_step_min
  267. self.time_step_max = config.time_step_max
  268. self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
  269. self.conv1d = nn.Conv1d(
  270. in_channels=self.conv_dim,
  271. out_channels=self.conv_dim,
  272. bias=True,
  273. kernel_size=config.mamba_d_conv,
  274. groups=self.conv_dim,
  275. padding=config.mamba_d_conv - 1,
  276. )
  277. # projection of the input hidden states
  278. projection_size = self.intermediate_size + self.conv_dim + self.num_heads
  279. self.in_proj = nn.Linear(
  280. self.hidden_size,
  281. projection_size,
  282. bias=config.add_bias_linear,
  283. )
  284. # selective projection used to make dt, B and C input dependent
  285. # time step projection (discretization)
  286. # instantiate once and copy inv_dt in init_weights of PretrainedModel
  287. self.dt_bias = nn.Parameter(torch.ones(self.num_heads))
  288. # S4D real initialization. These are not discretized!
  289. # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
  290. A = torch.arange(1, self.num_heads + 1)
  291. self.A_log = nn.Parameter(torch.log(A))
  292. self.norm = Zamba2RMSNormGated(
  293. self.intermediate_size, group_size=self.intermediate_size // self.n_groups, eps=1e-5
  294. )
  295. self.D = nn.Parameter(torch.ones(self.num_heads))
  296. self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.add_bias_linear)
  297. if not is_fast_path_available:
  298. logger.warning_once(
  299. "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
  300. " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and"
  301. " https://github.com/Dao-AILab/causal-conv1d"
  302. )
  303. def cuda_kernels_forward(
  304. self,
  305. hidden_states: torch.Tensor,
  306. cache_params: Optional[Zamba2HybridDynamicCache] = None,
  307. attention_mask: Optional[torch.Tensor] = None,
  308. ):
  309. # set up dimensions for reshapes later
  310. batch_size, seq_len, _ = hidden_states.shape
  311. groups_time_state_size = self.n_groups * self.ssm_state_size
  312. d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads
  313. # getting projected states from cache if it exists
  314. if cache_params is not None and cache_params.has_previous_state:
  315. in_projected_states = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
  316. d_mlp = (in_projected_states.shape[-1] - d_to_remove) // 2
  317. split_projection_dim = [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads]
  318. _, _, gate, hidden_states_B_C, dt = torch.split(in_projected_states, split_projection_dim, dim=-1)
  319. hidden_states_B_C = causal_conv1d_update(
  320. hidden_states_B_C,
  321. cache_params.conv_states[self.layer_idx],
  322. self.conv1d.weight.squeeze(1),
  323. self.conv1d.bias,
  324. self.activation,
  325. )
  326. hidden_states, B, C = torch.split(
  327. hidden_states_B_C,
  328. [self.intermediate_size, groups_time_state_size, groups_time_state_size],
  329. dim=-1,
  330. )
  331. A = -torch.exp(self.A_log.float()) # (nheads,)
  332. A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
  333. dt = dt[:, :, None].expand(-1, -1, self.head_dim)
  334. dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
  335. D = self.D[:, None, ...].expand(-1, self.head_dim)
  336. B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
  337. C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
  338. hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim)
  339. hidden_states = selective_state_update(
  340. cache_params.ssm_states[self.layer_idx],
  341. hidden_states_reshaped,
  342. dt,
  343. A,
  344. B,
  345. C,
  346. D,
  347. z=None,
  348. dt_bias=dt_bias,
  349. dt_softplus=True,
  350. )
  351. hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim)
  352. hidden_states = self.norm(hidden_states, gate)
  353. out = self.out_proj(hidden_states)[:, None, ...]
  354. # if no cache is found, calling the kernel
  355. else:
  356. if attention_mask is not None and not torch.all(attention_mask == 1):
  357. # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
  358. dtype = hidden_states.dtype
  359. hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
  360. # 1. Gated MLP's linear projection
  361. projected_states = self.in_proj(hidden_states)
  362. A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size)
  363. dt_limit_kwargs = {} if self.time_step_limit is None else {"dt_limit": self.time_step_limit}
  364. if attention_mask is not None:
  365. input_not_masked = torch.all(attention_mask == 1)
  366. else:
  367. input_not_masked = True
  368. if self.use_mem_eff_path and self.training and cache_params is None and input_not_masked:
  369. out, ssm_state = mamba_split_conv1d_scan_combined(
  370. projected_states,
  371. self.conv1d.weight.squeeze(1),
  372. self.conv1d.bias,
  373. self.dt_bias,
  374. A,
  375. D=self.D,
  376. chunk_size=self.chunk_size,
  377. seq_idx=None,
  378. activation=self.activation,
  379. rmsnorm_weight=self.norm.weight,
  380. rmsnorm_eps=self.norm.variance_epsilon,
  381. outproj_weight=self.out_proj.weight,
  382. outproj_bias=self.out_proj.bias,
  383. headdim=self.head_dim,
  384. ngroups=self.n_groups,
  385. norm_before_gate=False,
  386. return_final_states=True,
  387. **dt_limit_kwargs,
  388. )
  389. else:
  390. gate, hidden_states_B_C, time_step = torch.split(
  391. projected_states,
  392. [self.intermediate_size, self.conv_dim, self.num_heads],
  393. dim=-1,
  394. )
  395. # 1D Convolution
  396. if cache_params is not None:
  397. hidden_states_B_C_t = hidden_states_B_C.transpose(1, 2)
  398. conv_state = nn.functional.pad(
  399. hidden_states_B_C_t, (self.conv_kernel_size - hidden_states_B_C_t.shape[-1], 0)
  400. )
  401. cache_params.conv_states[self.layer_idx].copy_(conv_state)
  402. if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
  403. hidden_states_B_C = self.act(
  404. self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[:, :seq_len]
  405. ) # (B, L, self.d_inner + 2 * ngroups * d_state)
  406. else:
  407. hidden_states_B_C = causal_conv1d_fn(
  408. x=hidden_states_B_C.transpose(1, 2),
  409. weight=self.conv1d.weight.squeeze(1),
  410. bias=self.conv1d.bias,
  411. activation=self.activation,
  412. ).transpose(1, 2)[:, :seq_len]
  413. hidden_states, B, C = torch.split(
  414. hidden_states_B_C,
  415. [self.intermediate_size, groups_time_state_size, groups_time_state_size],
  416. dim=-1,
  417. )
  418. if attention_mask is not None and not torch.all(attention_mask == 1):
  419. # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
  420. dtype = hidden_states.dtype
  421. hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
  422. scan_output, ssm_state = mamba_chunk_scan_combined(
  423. hidden_states.view(batch_size, seq_len, -1, self.head_dim),
  424. time_step,
  425. A,
  426. B.view(batch_size, seq_len, self.n_groups, -1),
  427. C.view(batch_size, seq_len, self.n_groups, -1),
  428. chunk_size=self.chunk_size,
  429. D=self.D,
  430. z=None,
  431. seq_idx=None,
  432. return_final_states=True,
  433. dt_bias=self.dt_bias,
  434. dt_softplus=True,
  435. **dt_limit_kwargs,
  436. )
  437. if ssm_state is not None and cache_params is not None:
  438. cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
  439. scan_output = scan_output.view(batch_size, seq_len, -1)
  440. # Multiply "gate" branch and apply extra normalization layer
  441. scan_output = self.norm(scan_output, gate)
  442. out = self.out_proj(scan_output)
  443. return out
  444. # fmt: off
  445. def torch_forward(self, input_states, cache_params: Optional[Zamba2HybridDynamicCache]=None, attention_mask: Optional[torch.Tensor]=None):
  446. batch_size, seq_len, _ = input_states.shape
  447. dtype = input_states.dtype
  448. # Gated MLP's linear projection
  449. if cache_params is not None and cache_params.has_previous_state:
  450. projected_states = self.in_proj(input_states.squeeze(1))
  451. else:
  452. if attention_mask is not None and not torch.all(attention_mask==1):
  453. # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
  454. input_states = (input_states * attention_mask[:, :, None]).to(dtype)
  455. projected_states = self.in_proj(input_states)
  456. d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size- self.num_heads) // 2
  457. _, _, gate, hidden_states, dt = projected_states.split(
  458. [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
  459. )
  460. # Convolution sequence transformation
  461. if cache_params is not None:
  462. ssm_state = cache_params.ssm_states[self.layer_idx].clone()
  463. ssm_state = ssm_state.to(hidden_states.device)
  464. if cache_params.has_previous_state:
  465. gate = gate.unsqueeze(1)
  466. conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size]
  467. conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
  468. # handle batched generation - states are copied through
  469. conv_state[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states
  470. cache_params.conv_states[self.layer_idx].copy_(conv_state)
  471. hidden_states = torch.sum(conv_state.to(projected_states.device) * self.conv1d.weight[:, 0, :], dim=-1)
  472. if self.use_conv_bias:
  473. hidden_states += self.conv1d.bias
  474. hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] # [batch, 1, intermediate_size] : decoding
  475. else:
  476. hidden_states = hidden_states.transpose(1,2)
  477. conv_state = nn.functional.pad(
  478. hidden_states,
  479. (self.conv_kernel_size - hidden_states.shape[-1], 0)
  480. )
  481. cache_params.conv_states[self.layer_idx].copy_(conv_state)
  482. hidden_states = self.act(self.conv1d(hidden_states).transpose(1,2))[:, :seq_len, :] # [batch, intermediate_size, seq_len]
  483. if attention_mask is not None and not torch.all(attention_mask==1):
  484. dtype = hidden_states.dtype
  485. # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
  486. hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
  487. else:
  488. ssm_state = torch.zeros(
  489. (batch_size, self.num_heads, self.head_dim, self.ssm_state_size),
  490. device=hidden_states.device, dtype=dtype
  491. )
  492. hidden_states = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2))
  493. hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], dim=-1)
  494. A = -torch.exp(self.A_log.float()) # [num_heads]
  495. if cache_params is not None and cache_params.has_previous_state:
  496. # Note: there is no need to pad parameter matrices here, as there is just one new token
  497. # for batched generation
  498. dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...]
  499. dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim)
  500. # [num_heads] -> [num_heads, head_dim]
  501. dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim)
  502. dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype))
  503. dt = torch.clamp(dt, self.time_step_min) #, self.time_step_max)
  504. A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
  505. # [bsz, num_heads, head_dim, state_size]
  506. dA = torch.exp(dt[..., None] * A)
  507. # Discretize B
  508. # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] ->
  509. # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size]
  510. B = B.reshape(batch_size, self.n_groups, -1)[..., None, :]
  511. B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous()
  512. B = B.reshape(batch_size, -1, B.shape[-1])
  513. # [bsz, num_heads, head_dim, state_size]
  514. dB = dt[..., None] * B[..., None, :]
  515. # Discretize x into dB
  516. # [bsz, intermediate_size] -> [bsz, num_heads, head_dim]
  517. hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim)
  518. dBx = dB * hidden_states[..., None]
  519. # State calculation
  520. cache_params.ssm_states[self.layer_idx].copy_(
  521. cache_params.ssm_states[self.layer_idx] * dA + dBx
  522. )
  523. # Subsequent output
  524. # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size]
  525. C = C.reshape(batch_size, self.n_groups, -1)[..., None, :]
  526. C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous()
  527. C = C.reshape(batch_size, -1, C.shape[-1])
  528. # [bsz, num_heads, head_dim]
  529. ssm_states = cache_params.ssm_states[self.layer_idx].to(C.dtype) # Shape: [b, h, d, n]
  530. # Reshape ssm_states to merge the first two dimensions
  531. ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n]
  532. C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1]
  533. y = torch.bmm(ssm_states_reshaped, C_reshaped)
  534. y = y.view(batch_size, self.num_heads, self.head_dim)
  535. # D skip connection
  536. # [num_heads] -> [num_heads, head_dim]
  537. D = self.D[..., None].expand(self.D.shape[0], self.head_dim)
  538. y = (y + hidden_states * D).to(y.dtype)
  539. # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size]
  540. y = y.reshape(batch_size, -1)[:, None, ...]
  541. else:
  542. # begin ssd naive implementation without einsums
  543. dt = nn.functional.softplus(dt + self.dt_bias)
  544. dt = torch.clamp(dt, self.time_step_min)
  545. hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
  546. B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
  547. C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
  548. B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
  549. C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
  550. pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size
  551. D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)
  552. # Discretize x and A
  553. hidden_states = hidden_states * dt[..., None]
  554. A = A.to(hidden_states.dtype) * dt
  555. # Rearrange into blocks/chunks
  556. hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)]
  557. # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size]
  558. A = A.permute(0, 3, 1, 2)
  559. A_cumsum = torch.cumsum(A, dim=-1)
  560. # 1. Compute the output for each intra-chunk (diagonal blocks)
  561. # This is the analog of a causal mask
  562. L = torch.exp(segment_sum(A))
  563. # First, contraction of C and B to get G (attention-weights like)
  564. G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, : ,:] # shape: (b, c, l, s, h, n)
  565. G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h)
  566. # Step 2: Compute M, equivalent to applying attention mask to weights
  567. M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None]
  568. M = M_intermediate.sum(dim=-1)
  569. # Step 3: Compute Y_diag (apply to values)
  570. Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3)
  571. # (right term of low-rank factorization of off-diagonal blocks; B terms)
  572. decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
  573. B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None]
  574. # permute back B * decay states
  575. states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3)
  576. if cache_params is not None and cache_params.has_previous_state:
  577. previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...]
  578. else:
  579. previous_states = torch.zeros_like(states[:, :1])
  580. states = torch.cat([previous_states, states], dim=1)
  581. decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))
  582. states_permuted = states.permute(0, 2, 1, 3, 4)
  583. result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=2)
  584. new_states = result.permute(0, 2, 1, 3, 4)
  585. states, ssm_state = new_states[:, :-1], new_states[:, -1]
  586. # Compute state -> output conversion per chunk
  587. # (left term of low-rank factorization of off-diagonal blocks; C terms)
  588. state_decay_out = torch.exp(A_cumsum)
  589. # compute Yoff
  590. C_times_states = (C[..., None, :] * states[:, :, None, ...])
  591. state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1)
  592. Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None])
  593. # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
  594. y = Y_diag + Y_off
  595. # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim]
  596. y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)
  597. y = y + D_residual
  598. # Cutting off padded chunks
  599. if pad_size > 0:
  600. y = y[:, :seq_len, :, :]
  601. y = y.reshape(batch_size, seq_len, -1)
  602. if ssm_state is not None and cache_params is not None:
  603. cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
  604. scan_output = self.norm(y, gate)
  605. # end ssd naive
  606. # 4. Final linear projection
  607. contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size]
  608. return contextualized_states
  609. # fmt: on
  610. def forward(
  611. self,
  612. hidden_states,
  613. cache_params: Optional[Zamba2HybridDynamicCache] = None,
  614. attention_mask: Optional[torch.Tensor] = None,
  615. ):
  616. if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
  617. return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask)
  618. return self.torch_forward(hidden_states, cache_params, attention_mask)
  619. class Zamba2MLP(nn.Module):
  620. def __init__(self, config: Zamba2Config, num_fwd_mem_blocks=None, block_id: Optional[int] = None):
  621. """
  622. This MLP layer contributes to tied transformer blocks aimed to increasing compute without increasing model size. Because this layer
  623. is tied, un-tied adapter modules (formally same as LoRA, but used in the base model) are added to the up and gate projectors to increase expressivity with a small memory overhead.
  624. """
  625. super().__init__()
  626. self.config = config
  627. self.hidden_size = config.hidden_size
  628. self.intermediate_size = config.intermediate_size
  629. self.num_fwd_mem_blocks = num_fwd_mem_blocks
  630. self.block_id = block_id
  631. self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=config.add_bias_linear)
  632. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.add_bias_linear)
  633. self.act_fn = ACT2FN[config.hidden_act]
  634. self.gate_up_proj_adapter_list = nn.ModuleList([])
  635. for i in range(self.num_fwd_mem_blocks):
  636. if i % config.num_mem_blocks == block_id:
  637. gate_up_proj_adapter = nn.Sequential(
  638. nn.Linear(self.config.hidden_size, self.config.adapter_rank, bias=False),
  639. nn.Linear(self.config.adapter_rank, 2 * self.intermediate_size, bias=False),
  640. )
  641. else:
  642. gate_up_proj_adapter = nn.Identity()
  643. self.gate_up_proj_adapter_list.append(gate_up_proj_adapter)
  644. layer_block_map = config.hybrid_layer_ids
  645. self.layer_dic = {value: index for index, value in enumerate(layer_block_map)}
  646. def forward(self, hidden_state, layer_idx=None):
  647. gate_up_state = self.gate_up_proj(hidden_state)
  648. layer_idx = self.layer_dic[layer_idx]
  649. gate_up_state = gate_up_state + self.gate_up_proj_adapter_list[layer_idx](hidden_state)
  650. gate_up_state = torch.chunk(gate_up_state, 2, dim=-1)
  651. hidden_state = self.act_fn(gate_up_state[0]) * gate_up_state[1]
  652. output = self.down_proj(hidden_state)
  653. return output
  654. class Zamba2AttentionDecoderLayer(ZambaAttentionDecoderLayer):
  655. def __init__(self, config: Zamba2Config, block_id: Optional[int] = None, layer_idx: Optional[int] = None):
  656. self.block_id = block_id
  657. num_gs = len(config.hybrid_layer_ids)
  658. super().__init__(config, layer_idx)
  659. self.self_attn = Zamba2Attention(config, layer_idx=-1, num_fwd_mem_blocks=num_gs, block_id=block_id)
  660. self.feed_forward = Zamba2MLP(config, num_fwd_mem_blocks=num_gs, block_id=block_id)
  661. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  662. def forward(
  663. self,
  664. hidden_states: torch.Tensor,
  665. original_hidden_states: torch.Tensor,
  666. layer_idx: int,
  667. attention_mask: Optional[torch.Tensor] = None,
  668. past_key_values: Optional[Zamba2HybridDynamicCache] = None,
  669. output_attentions: Optional[bool] = False,
  670. position_embeddings: Optional[torch.LongTensor] = None,
  671. **kwargs: Unpack[FlashAttentionKwargs],
  672. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  673. """
  674. Args:
  675. hidden_states (`torch.FloatTensor`): output of previous Mamba layer of shape `(batch, seq_len, embed_dim)`
  676. original_hidden_states (`torch.FloatTensor`): word embedding output of shape `(batch, seq_len, embed_dim)`.
  677. This is concatenated with `hidden_states` (which is the output of the previous (mamba) layer). The
  678. concatenated tensor is then used as input of the pre-attention RMSNorm
  679. (see fig. 2 in https://huggingface.co/papers/2405.16712).
  680. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
  681. `(batch, sequence_length)` where padding elements are indicated by 0.
  682. past_key_values (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states
  683. output_attentions (`bool`, *optional*):
  684. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  685. returned tensors for more detail.
  686. use_cache (`bool`, *optional*):
  687. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  688. (see `past_key_values`).
  689. position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
  690. Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
  691. with `head_dim` being the embedding dimension of each attention head.
  692. """
  693. hidden_states = torch.concatenate([hidden_states, original_hidden_states], dim=-1)
  694. hidden_states = self.input_layernorm(hidden_states)
  695. hidden_states, self_attn_weights = self.self_attn(
  696. hidden_states=hidden_states,
  697. layer_idx=layer_idx,
  698. attention_mask=attention_mask,
  699. past_key_values=past_key_values,
  700. output_attentions=output_attentions,
  701. position_embeddings=position_embeddings,
  702. **kwargs,
  703. )
  704. hidden_states = self.pre_ff_layernorm(hidden_states)
  705. hidden_states = self.feed_forward(hidden_states, layer_idx)
  706. outputs = (hidden_states,)
  707. if output_attentions:
  708. outputs += (self_attn_weights,)
  709. return outputs
  710. class Zamba2MambaDecoderLayer(ZambaMambaDecoderLayer):
  711. def __init__(self, config: Zamba2Config, layer_idx: int):
  712. super().__init__(config, layer_idx)
  713. self.mamba = Zamba2MambaMixer(config=config, layer_idx=layer_idx)
  714. self.input_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  715. class Zamba2HybridLayer(ZambaHybridLayer):
  716. def __init__(
  717. self, shared_transformer: Zamba2AttentionDecoderLayer, linear: nn.Linear, mamba: Zamba2MambaDecoderLayer
  718. ):
  719. super().__init__(shared_transformer, linear, mamba)
  720. del self.shared_transf
  721. self.shared_transformer = shared_transformer
  722. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  723. def forward(
  724. self,
  725. hidden_states: torch.Tensor,
  726. original_hidden_states: Optional[torch.Tensor] = None,
  727. layer_idx: Optional[int] = None,
  728. attention_mask: Optional[torch.Tensor] = None,
  729. causal_mask: Optional[torch.Tensor] = None,
  730. past_key_values: Optional[Zamba2HybridDynamicCache] = None,
  731. output_attentions: Optional[bool] = False,
  732. use_cache: Optional[bool] = False,
  733. position_embeddings: Optional[torch.LongTensor] = None,
  734. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  735. """
  736. Args:
  737. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  738. original_hidden_states (`torch.FloatTensor`): word embedding output that will be concatenated with
  739. hidden activations to form the input of the shared transformer layer.
  740. layer_idx (`int`): layer number.
  741. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
  742. `(batch, sequence_length)` where padding elements are indicated by 0.
  743. past_key_values (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states
  744. output_attentions (`bool`, *optional*):
  745. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  746. returned tensors for more detail.
  747. use_cache (`bool`, *optional*):
  748. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  749. (see `past_key_values`).
  750. position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
  751. Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
  752. with `head_dim` being the embedding dimension of each attention head.
  753. """
  754. layer_outputs = self.shared_transformer(
  755. hidden_states,
  756. original_hidden_states=original_hidden_states,
  757. layer_idx=layer_idx,
  758. attention_mask=causal_mask,
  759. past_key_values=past_key_values,
  760. output_attentions=output_attentions,
  761. position_embeddings=position_embeddings,
  762. )
  763. transformer_hidden_states = layer_outputs[0]
  764. if output_attentions:
  765. self_attn_weights = layer_outputs[1]
  766. transformer_hidden_states = self.linear(transformer_hidden_states)
  767. layer_outputs = self.mamba_decoder(
  768. hidden_states,
  769. transformer_hidden_states=transformer_hidden_states,
  770. attention_mask=attention_mask,
  771. past_key_values=past_key_values,
  772. output_attentions=output_attentions,
  773. use_cache=use_cache,
  774. position_embeddings=position_embeddings,
  775. )
  776. if output_attentions:
  777. layer_outputs = (layer_outputs[0], self_attn_weights) + layer_outputs[2:]
  778. return layer_outputs
  779. class Zamba2PreTrainedModel(PreTrainedModel):
  780. config: Zamba2Config
  781. base_model_prefix = "model"
  782. supports_gradient_checkpointing = True
  783. _no_split_modules = ["Zamba2AttentionDecoderLayer", "Zamba2MambaDecoderLayer"]
  784. _skip_keys_device_placement = "past_key_values"
  785. _supports_flash_attn = True
  786. _supports_flex_attn = True
  787. _supports_sdpa = True
  788. # Note: only supports Zamba2HybridDynamicCache
  789. _is_stateful = True
  790. def _init_weights(self, module):
  791. super()._init_weights(module)
  792. if isinstance(module, Zamba2MambaMixer):
  793. dt = torch.exp(
  794. torch.rand(self.config.n_mamba_heads)
  795. * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
  796. + math.log(self.config.time_step_min)
  797. ).clamp(min=self.config.time_step_floor)
  798. # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
  799. inv_dt = dt + torch.log(-torch.expm1(-dt))
  800. module.dt_bias.data.copy_(inv_dt)
  801. A = torch.arange(1, module.num_heads + 1)
  802. module.A_log.data.copy_(torch.log(A))
  803. module.D.data.fill_(1.0)
  804. class Zamba2Model(ZambaModel, Zamba2PreTrainedModel):
  805. """
  806. Model consisting of *config.num_hidden_layers* layers.
  807. Args:
  808. config: Zamba2Config
  809. """
  810. def __init__(self, config: Zamba2Config):
  811. Zamba2PreTrainedModel.__init__(self, config)
  812. self.config = config
  813. self.padding_idx = config.pad_token_id
  814. self.vocab_size = config.vocab_size
  815. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  816. blocks = [Zamba2AttentionDecoderLayer(config, block_id=k) for k in range(config.num_mem_blocks)]
  817. mamba_layers = []
  818. linear_layers = []
  819. self.layers_block_type = config.layers_block_type
  820. for i in range(config.num_hidden_layers):
  821. if config.layers_block_type[i] == "mamba":
  822. mamba_layers.append(Zamba2MambaDecoderLayer(config, layer_idx=i))
  823. elif config.layers_block_type[i] == "hybrid":
  824. linear_layers.append(nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False))
  825. mamba_layers.append(Zamba2MambaDecoderLayer(config, layer_idx=i))
  826. mamba_layers = iter(mamba_layers)
  827. linear_layers = iter(linear_layers)
  828. blocks = cycle(blocks)
  829. layers = self.get_layers(blocks, linear_layers, mamba_layers)
  830. self.layers = nn.ModuleList(layers)
  831. self._attn_implementation = config._attn_implementation
  832. self.final_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  833. if config.use_mem_rope:
  834. if config.use_long_context:
  835. logger.warning_once(
  836. "`use_long_context` set to `True`: using rescaled `rope_theta` and extended `max_position_embeddings`."
  837. )
  838. self.rotary_emb = Zamba2RotaryEmbedding(config)
  839. self.gradient_checkpointing = False
  840. # Initialize weights and apply final processing
  841. self.post_init()
  842. def get_layers(self, blocks, linear_layers, mamba_layers):
  843. layers = []
  844. self._tied_weights_keys = []
  845. self.first_transformer_layer_id = 0
  846. for layer_id, layer_type in enumerate(self.layers_block_type):
  847. if layer_type == "hybrid":
  848. if self.first_transformer_layer_id == 0:
  849. self.first_transformer_layer_id = layer_id
  850. block = next(blocks)
  851. if self.config.num_mem_blocks * len(self.config.hybrid_layer_ids) > 1:
  852. prefix_pattern = rf"^layers\.{layer_id}\.shared_transformer\."
  853. main_keys_pattern = re.compile(
  854. prefix_pattern
  855. + r"(?:"
  856. + r"self_attn\.(?:q_proj|k_proj|v_proj|o_proj)\.weight|"
  857. + r"feed_forward\.(?:gate_up_proj|down_proj)\.weight|"
  858. + r"(?:input_layernorm|pre_ff_layernorm)\.weight"
  859. + r")$"
  860. )
  861. self._tied_weights_keys.append(main_keys_pattern)
  862. adapter_id = 0
  863. for _layer_type in self.layers_block_type:
  864. if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id:
  865. adapter_pattern = re.compile(
  866. r"^shared_transformer\.feed_forward\.gate_up_proj_adapter_list\."
  867. + str(adapter_id)
  868. + r"\.(?:0|1)\.weight$"
  869. )
  870. self._tied_weights_keys.append(adapter_pattern)
  871. adapter_id += 1
  872. if self.config.use_shared_attention_adapter:
  873. adapter_id = 0
  874. for _layer_type in self.layers_block_type:
  875. if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id:
  876. attn_adapter_pattern = re.compile(
  877. r"^shared_transformer\.self_attn\."
  878. + r"(?:linear_q_adapter_list|linear_k_adapter_list|linear_v_adapter_list)\."
  879. + str(adapter_id)
  880. + r"\.(?:0|1)\.weight$"
  881. )
  882. self._tied_weights_keys.append(attn_adapter_pattern)
  883. adapter_id += 1
  884. layers.append(Zamba2HybridLayer(block, next(linear_layers), next(mamba_layers)))
  885. else:
  886. layers.append(next(mamba_layers))
  887. return layers
  888. def forward(
  889. self,
  890. input_ids: Optional[torch.LongTensor] = None,
  891. attention_mask: Optional[torch.Tensor] = None,
  892. position_ids: Optional[torch.LongTensor] = None,
  893. past_key_values: Optional[Zamba2HybridDynamicCache] = None,
  894. inputs_embeds: Optional[torch.FloatTensor] = None,
  895. use_cache: Optional[bool] = None,
  896. output_attentions: Optional[bool] = None,
  897. output_hidden_states: Optional[bool] = None,
  898. return_dict: Optional[bool] = None,
  899. cache_position: Optional[torch.LongTensor] = None,
  900. ) -> Union[tuple, BaseModelOutputWithPast]:
  901. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  902. output_hidden_states = (
  903. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  904. )
  905. use_cache = use_cache if use_cache is not None else self.config.use_cache
  906. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  907. if (input_ids is None) ^ (inputs_embeds is not None):
  908. raise ValueError(
  909. "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
  910. )
  911. if self.gradient_checkpointing and self.training and use_cache:
  912. logger.warning_once(
  913. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
  914. )
  915. use_cache = False
  916. if inputs_embeds is None:
  917. inputs_embeds = self.embed_tokens(input_ids)
  918. hidden_states = inputs_embeds
  919. original_hidden_states = torch.clone(inputs_embeds)
  920. # original_hidden_states: word embedding output that will be concatenated with hidden activations to form the input of the shared transformer layer
  921. if use_cache and past_key_values is None:
  922. batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
  923. past_key_values = Zamba2HybridDynamicCache(self.config, batch_size, dtype=self.dtype, device=self.device)
  924. if cache_position is None:
  925. past_seen_tokens = (
  926. past_key_values.get_seq_length(layer_idx=self.first_transformer_layer_id)
  927. if past_key_values is not None
  928. else 0
  929. )
  930. cache_position = torch.arange(
  931. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  932. )
  933. if position_ids is None:
  934. position_ids = cache_position.unsqueeze(0)
  935. causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
  936. # create position embeddings to be shared across the decoder layers
  937. if self.config.use_mem_rope:
  938. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  939. else:
  940. position_embeddings = None
  941. all_hidden_states = () if output_hidden_states else None
  942. all_self_attns = () if output_attentions else None
  943. for layer_idx, layer in enumerate(self.layers):
  944. if output_hidden_states:
  945. all_hidden_states += (hidden_states,)
  946. if self.gradient_checkpointing and self.training:
  947. layer_outputs = self._gradient_checkpointing_func(
  948. layer.__call__,
  949. hidden_states,
  950. original_hidden_states,
  951. layer_idx,
  952. attention_mask,
  953. causal_mask,
  954. past_key_values,
  955. output_attentions,
  956. use_cache,
  957. position_embeddings,
  958. )
  959. else:
  960. layer_outputs = layer(
  961. hidden_states,
  962. original_hidden_states=original_hidden_states,
  963. layer_idx=layer_idx,
  964. attention_mask=attention_mask,
  965. causal_mask=causal_mask,
  966. past_key_values=past_key_values,
  967. output_attentions=output_attentions,
  968. use_cache=use_cache,
  969. position_embeddings=position_embeddings,
  970. )
  971. hidden_states = layer_outputs[0]
  972. if output_attentions:
  973. if layer_outputs[1] is not None:
  974. # append attentions only of attention layers. Mamba layers return `None` as the attention weights
  975. all_self_attns += (layer_outputs[1],)
  976. hidden_states = self.final_layernorm(hidden_states)
  977. # add hidden states from the last decoder layer
  978. if output_hidden_states:
  979. all_hidden_states += (hidden_states,)
  980. if past_key_values is not None and not past_key_values.has_previous_state:
  981. past_key_values.has_previous_state = True
  982. output = BaseModelOutputWithPast(
  983. last_hidden_state=hidden_states,
  984. past_key_values=past_key_values if use_cache else None,
  985. hidden_states=all_hidden_states,
  986. attentions=all_self_attns,
  987. )
  988. return output if return_dict else output.to_tuple()
  989. class Zamba2ForCausalLM(ZambaForCausalLM):
  990. pass
  991. class Zamba2ForSequenceClassification(ZambaForSequenceClassification):
  992. pass
# Public names exported by this module (e.g. via `from ... import *`).
__all__ = [
    "Zamba2ForCausalLM",
    "Zamba2ForSequenceClassification",
    "Zamba2Model",
    "Zamba2PreTrainedModel",
]