modeling_zamba.py 62 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316
  1. # coding=utf-8
  2. # Copyright 2024 Zyphra Technologies and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  5. # and OPT implementations in this library. It has been modified from its
  6. # original forms to accommodate minor architectural differences compared
  7. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. """PyTorch Zamba model."""
  21. import math
  22. from typing import Any, Callable, Optional, Union
  23. import torch
  24. from torch import nn
  25. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  26. from ...activations import ACT2FN
  27. from ...cache_utils import Cache
  28. from ...generation import GenerationMixin
  29. from ...modeling_attn_mask_utils import AttentionMaskConverter
  30. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  31. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
  32. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  33. from ...processing_utils import Unpack
  34. from ...utils import auto_docstring, logging
  35. from ...utils.deprecation import deprecate_kwarg
  36. from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available
  37. from .configuration_zamba import ZambaConfig
  38. if is_mamba_ssm_available():
  39. from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
  40. from mamba_ssm.ops.triton.selective_state_update import selective_state_update
  41. else:
  42. selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
  43. if is_causal_conv1d_available():
  44. from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
  45. else:
  46. causal_conv1d_update, causal_conv1d_fn = None, None
  47. is_fast_path_available = all(
  48. (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
  49. )
  50. logger = logging.get_logger(__name__)
  51. # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Zamba
  52. class ZambaRMSNorm(nn.Module):
  53. def __init__(self, hidden_size, eps=1e-6):
  54. """
  55. ZambaRMSNorm is equivalent to T5LayerNorm
  56. """
  57. super().__init__()
  58. self.weight = nn.Parameter(torch.ones(hidden_size))
  59. self.variance_epsilon = eps
  60. def forward(self, hidden_states):
  61. input_dtype = hidden_states.dtype
  62. hidden_states = hidden_states.to(torch.float32)
  63. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  64. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  65. return self.weight * hidden_states.to(input_dtype)
  66. def extra_repr(self):
  67. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  68. # Copied from transformers.models.llama.modeling_llama.repeat_kv
  69. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  70. """
  71. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  72. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  73. """
  74. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  75. if n_rep == 1:
  76. return hidden_states
  77. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  78. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  79. class ZambaHybridDynamicCache:
  80. """
  81. A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
  82. (which has a constant shape regardless of seq_len).
  83. This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
  84. and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor
  85. For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
  86. while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
  87. For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
  88. while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`,
  89. and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`.
  90. """
  91. is_compileable = False
  92. def __init__(self, config, batch_size, dtype=torch.float16, device=None):
  93. self.dtype = dtype
  94. self.is_compileable = False
  95. self.layers_block_type = config.layers_block_type
  96. self.has_previous_state = False # only used by mamba
  97. self.intermediate_size = config.mamba_expand * config.hidden_size
  98. self.ssm_state_size = config.mamba_d_state
  99. self.conv_kernel_size = config.mamba_d_conv
  100. self.n_mamba_heads = config.n_mamba_heads
  101. self.conv_states = []
  102. self.ssm_states = []
  103. self.transformer_layers = []
  104. self._modules = {}
  105. self._parameters = {}
  106. self._buffers = {}
  107. for i in range(config.num_hidden_layers):
  108. self.conv_states += [
  109. torch.zeros(batch_size, self.intermediate_size, self.conv_kernel_size, device=device, dtype=dtype)
  110. ]
  111. cache_shape = (
  112. batch_size,
  113. self.n_mamba_heads,
  114. self.intermediate_size // self.n_mamba_heads,
  115. self.ssm_state_size,
  116. )
  117. self.ssm_states += [torch.zeros(cache_shape, device=device, dtype=dtype)]
  118. if self.layers_block_type[i] == "hybrid":
  119. self.transformer_layers.append(i)
  120. self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
  121. self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
  122. def __len__(self):
  123. return len(self.key_cache)
  124. def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
  125. return self.key_cache[layer_idx], self.value_cache[layer_idx]
  126. # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.update
  127. def update(
  128. self,
  129. key_states: torch.Tensor,
  130. value_states: torch.Tensor,
  131. layer_idx: int,
  132. cache_kwargs: Optional[dict[str, Any]] = None,
  133. ) -> tuple[torch.Tensor, torch.Tensor]:
  134. # Update the cache
  135. if self.key_cache[layer_idx].shape[-1] == 0:
  136. self.key_cache[layer_idx] = key_states
  137. self.value_cache[layer_idx] = value_states
  138. else:
  139. self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2)
  140. self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2)
  141. return self.key_cache[layer_idx], self.value_cache[layer_idx]
  142. # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.reorder_cache
  143. def reorder_cache(self, beam_idx: torch.LongTensor):
  144. """Reorders the cache for beam search, given the selected beam indices."""
  145. for layer_idx in range(len(self.key_cache)):
  146. device = self.key_cache[layer_idx].device
  147. self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
  148. device = self.value_cache[layer_idx].device
  149. self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
  150. device = self.conv_states[layer_idx].device
  151. self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device))
  152. device = self.ssm_states[layer_idx].device
  153. self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device))
  154. # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.get_seq_length
  155. def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
  156. """Returns the sequence length of the cached states. A layer index can be optionally passed."""
  157. # take any layer that contains cache and not empty tensor
  158. layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx
  159. if len(self.key_cache) <= layer_idx:
  160. return 0
  161. return self.key_cache[layer_idx].shape[-2]
  162. def eager_attention_forward(
  163. module: nn.Module,
  164. query: torch.Tensor,
  165. key: torch.Tensor,
  166. value: torch.Tensor,
  167. attention_mask: Optional[torch.Tensor],
  168. scaling: float,
  169. dropout: float = 0.0,
  170. **kwargs,
  171. ):
  172. key_states = repeat_kv(key, module.num_key_value_groups)
  173. value_states = repeat_kv(value, module.num_key_value_groups)
  174. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  175. if attention_mask is not None:
  176. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  177. attn_weights = attn_weights + causal_mask
  178. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  179. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  180. attn_output = torch.matmul(attn_weights, value_states)
  181. attn_output = attn_output.transpose(1, 2).contiguous()
  182. return attn_output, attn_weights
  183. class ZambaAttention(nn.Module):
  184. """
  185. Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
  186. and "Generating Long Sequences with Sparse Transformers".
  187. Adapted from transformers.models.mistral.modeling_mistral.MistralAttention:
  188. The input dimension here is attention_hidden_size = 2 * hidden_size, and head_dim = attention_hidden_size // num_heads.
  189. The extra factor of 2 comes from the input being the concatenation of original_hidden_states with the output of the previous (mamba) layer
  190. (see fig. 2 in https://huggingface.co/papers/2405.16712).
  191. Additionally, replaced
  192. attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) with
  193. attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim/2)
  194. """
  195. def __init__(self, config: ZambaConfig, layer_idx: int):
  196. super().__init__()
  197. self.config = config
  198. self.layer_idx = layer_idx
  199. self.attention_hidden_size = config.attention_hidden_size
  200. self.head_dim = config.attention_head_dim
  201. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  202. self.max_position_embeddings = config.max_position_embeddings
  203. self.scaling = (self.head_dim / 2) ** -0.5
  204. self.is_causal = True
  205. self.attention_dropout = config.attention_dropout
  206. self.q_proj = nn.Linear(config.attention_hidden_size, config.num_attention_heads * self.head_dim, bias=False)
  207. self.k_proj = nn.Linear(config.attention_hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
  208. self.v_proj = nn.Linear(config.attention_hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
  209. self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
  210. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  211. def forward(
  212. self,
  213. hidden_states: torch.Tensor,
  214. layer_idx: int,
  215. attention_mask: Optional[torch.Tensor],
  216. past_key_values: Optional[ZambaHybridDynamicCache] = None,
  217. **kwargs: Unpack[FlashAttentionKwargs],
  218. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  219. input_shape = hidden_states.shape[:-1]
  220. hidden_shape = (*input_shape, -1, self.head_dim)
  221. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  222. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  223. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  224. if past_key_values is not None:
  225. key_states, value_states = past_key_values.update(key_states, value_states, layer_idx)
  226. attention_interface: Callable = eager_attention_forward
  227. if self.config._attn_implementation != "eager":
  228. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  229. attn_output, attn_weights = attention_interface(
  230. self,
  231. query_states,
  232. key_states,
  233. value_states,
  234. attention_mask,
  235. dropout=0.0 if not self.training else self.attention_dropout,
  236. scaling=self.scaling,
  237. **kwargs,
  238. )
  239. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  240. attn_output = self.o_proj(attn_output)
  241. return attn_output, attn_weights
  242. class ZambaMambaMixer(nn.Module):
  243. """
  244. Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
  245. A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
  246. ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
  247. and is why Mamba is called **selective** state spaces)
  248. This module differs from `transformers.models.mamba.modeling_mamba.MambaMixer` in two ways:
  249. - Added multi-head: the output of `self.in_proj` is split into `self.n_mamba_heads` heads, and each head
  250. undergoes an independent forward pass, identical to the original `MambaMixer`, up until the pre-activations of
  251. `self.out_proj`. The pre-activations, coming from different mamba heads, are then concatenated and fed into `self.out_proj`.
  252. """
  253. def __init__(self, config: ZambaConfig, layer_idx):
  254. super().__init__()
  255. self.config = config
  256. self.layer_idx = layer_idx
  257. self.hidden_size = config.hidden_size
  258. self.ssm_state_size = config.mamba_d_state
  259. self.conv_kernel_size = config.mamba_d_conv
  260. self.intermediate_size = config.mamba_expand * config.hidden_size
  261. self.time_step_rank = config.mamba_dt_rank
  262. self.n_mamba_heads = config.n_mamba_heads
  263. self.mamba_head_dim = self.intermediate_size // self.n_mamba_heads
  264. self.use_conv_bias = config.mamba_conv_bias
  265. self.use_bias = config.mamba_proj_bias
  266. self.conv1d = nn.Conv1d(
  267. in_channels=self.intermediate_size,
  268. out_channels=self.intermediate_size,
  269. bias=self.use_conv_bias,
  270. kernel_size=self.conv_kernel_size,
  271. groups=self.intermediate_size,
  272. padding=self.conv_kernel_size - 1,
  273. )
  274. self.activation = config.hidden_mamba_act
  275. self.act = ACT2FN[config.hidden_mamba_act]
  276. self.use_fast_kernels = config.use_mamba_kernels
  277. # projection of the input hidden states
  278. self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=self.use_bias)
  279. # weight associated to the selective projection used to make dt, B and C input dependent
  280. # each mamba head is processed independently
  281. self.x_proj_weight = nn.Parameter(
  282. torch.zeros(
  283. self.n_mamba_heads,
  284. self.time_step_rank + self.ssm_state_size * 2,
  285. self.mamba_head_dim,
  286. )
  287. )
  288. # time step projection (discretization)
  289. self.dt_proj_weight = nn.Parameter(
  290. (torch.zeros(self.n_mamba_heads, self.mamba_head_dim, self.time_step_rank) - 0.5)
  291. * 2
  292. / self.time_step_rank**0.5
  293. )
  294. self.dt_proj_bias = nn.Parameter(torch.zeros(self.n_mamba_heads, self.mamba_head_dim))
  295. # S4D real initialization. These are not discretized!
  296. # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
  297. A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]
  298. A = A.expand(self.intermediate_size, -1).contiguous()
  299. self.A_log = nn.Parameter(torch.log(A).reshape(self.n_mamba_heads, self.mamba_head_dim, -1))
  300. self.D = nn.Parameter(torch.ones(self.n_mamba_heads, self.mamba_head_dim))
  301. self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)
  302. if not is_fast_path_available:
  303. logger.warning_once(
  304. "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
  305. " is None. To install follow https://github.com/state-spaces/mamba/#installation and"
  306. " https://github.com/Dao-AILab/causal-conv1d. If you want to use the naive implementation, set `use_mamba_kernels=False` in the model config"
  307. )
  308. def cuda_kernels_forward(
  309. self, hidden_states: torch.Tensor, cache_params: ZambaHybridDynamicCache = None, attention_mask=None
  310. ):
  311. batch_size, seq_len, _ = hidden_states.shape
  312. use_precomputed_states = cache_params is not None and cache_params.has_previous_state and seq_len == 1
  313. # 1. Gated linear projection
  314. projected_states = self.in_proj(hidden_states).transpose(1, 2)
  315. hidden_states, gate = projected_states.view(batch_size, -1, 2, seq_len).chunk(2, dim=2)
  316. hidden_states = hidden_states.squeeze(2).contiguous()
  317. gate = gate.squeeze(2)
  318. gate = gate.reshape(batch_size, self.n_mamba_heads, -1, seq_len).transpose(0, 1)
  319. # 2. Convolution sequence transformation
  320. conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
  321. if use_precomputed_states:
  322. hidden_states = causal_conv1d_update(
  323. hidden_states.squeeze(-1),
  324. cache_params.conv_states[self.layer_idx],
  325. conv_weights,
  326. self.conv1d.bias,
  327. self.activation,
  328. )
  329. hidden_states = hidden_states.unsqueeze(-1)
  330. else:
  331. if attention_mask is not None and not torch.all(attention_mask == 1):
  332. hidden_states = hidden_states * attention_mask.unsqueeze(1)
  333. if cache_params is not None:
  334. conv_states = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0))
  335. cache_params.conv_states[self.layer_idx].copy_(conv_states)
  336. hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv1d.bias, activation=self.activation)
  337. if attention_mask is not None and not torch.all(attention_mask == 1):
  338. hidden_states = hidden_states * attention_mask.unsqueeze(1)
  339. # 3. SSM sequence transformation
  340. # 3.a. input varying initialization of time_step, B and C
  341. hidden_states = hidden_states.reshape(-1, self.n_mamba_heads, self.mamba_head_dim, seq_len).transpose(0, 1)
  342. ssm_parameters = (self.x_proj_weight[:, None, :, :] @ hidden_states).transpose(-1, -2)
  343. time_step, B, C = torch.split(
  344. ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
  345. )
  346. discrete_time_step = self.dt_proj_weight[:, None] @ time_step.transpose(-1, -2)
  347. A = -torch.exp(self.A_log.float())
  348. # 3.c perform the recurrence y ← SSM(A, B, C)(x)
  349. time_proj_bias = self.dt_proj_bias.float() if self.dt_proj_bias is not None else None
  350. scan_outputs = torch.empty((batch_size, 0, seq_len), device=hidden_states.device, dtype=hidden_states.dtype)
  351. if use_precomputed_states:
  352. for n in range(self.n_mamba_heads):
  353. scan_outputs_ = selective_state_update(
  354. cache_params.ssm_states[self.layer_idx][:, n],
  355. hidden_states[n, ..., 0],
  356. discrete_time_step[n, ..., 0],
  357. A[n],
  358. B[n, :, 0],
  359. C[n, :, 0],
  360. self.D[n],
  361. gate[n, ..., 0],
  362. time_proj_bias[n],
  363. dt_softplus=True,
  364. ).unsqueeze(-1)
  365. scan_outputs = torch.cat((scan_outputs, scan_outputs_), dim=1)
  366. else:
  367. ssm_state = torch.empty(
  368. (batch_size, 0, self.mamba_head_dim, self.ssm_state_size),
  369. device=hidden_states.device,
  370. dtype=hidden_states.dtype,
  371. )
  372. for n in range(self.n_mamba_heads):
  373. scan_outputs_, ssm_state_ = selective_scan_fn(
  374. hidden_states[n],
  375. discrete_time_step[n],
  376. A[n],
  377. B[n].transpose(1, 2),
  378. C[n].transpose(1, 2),
  379. self.D[n].float(),
  380. gate[n],
  381. time_proj_bias[n],
  382. delta_softplus=True,
  383. return_last_state=True,
  384. )
  385. scan_outputs = torch.cat((scan_outputs, scan_outputs_), dim=1).contiguous()
  386. ssm_state = torch.cat((ssm_state, ssm_state_.unsqueeze(1)), dim=1)
  387. if ssm_state is not None and cache_params is not None:
  388. cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
  389. # 4. Final linear projection
  390. contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
  391. return contextualized_states
  392. def slow_forward(self, input_states, cache_params: ZambaHybridDynamicCache = None, attention_mask=None):
  393. batch_size, seq_len, _ = input_states.shape
  394. dtype = input_states.dtype
  395. # 1. Gated linear projection
  396. projected_states = self.in_proj(input_states).transpose(1, 2)
  397. hidden_states, gate = projected_states.view(batch_size, -1, 2, seq_len).chunk(2, dim=2)
  398. hidden_states = hidden_states.squeeze(2).contiguous()
  399. gate = gate.squeeze(2)
  400. gate = gate.reshape(batch_size, self.n_mamba_heads, -1, seq_len).transpose(0, 1)
  401. use_cache = isinstance(cache_params, ZambaHybridDynamicCache)
  402. # 2. Convolution sequence transformation
  403. if use_cache and cache_params.ssm_states[self.layer_idx].shape[0] == batch_size:
  404. if self.training:
  405. # In training mode, we don't want to perform in-place operations on ssm_state so we can compute the backwards pass
  406. ssm_state = cache_params.ssm_states[self.layer_idx].clone()
  407. else:
  408. ssm_state = cache_params.ssm_states[self.layer_idx]
  409. ssm_state = ssm_state.to(hidden_states.device)
  410. if (
  411. cache_params.has_previous_state
  412. and seq_len == 1
  413. and cache_params.conv_states[self.layer_idx].shape[0] == batch_size
  414. ):
  415. conv_state = cache_params.conv_states[self.layer_idx]
  416. conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
  417. conv_state[:, :, -1] = hidden_states[:, :, 0]
  418. cache_params.conv_states[self.layer_idx] = conv_state
  419. hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
  420. if self.use_conv_bias:
  421. hidden_states += self.conv1d.bias
  422. hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1)
  423. else:
  424. if attention_mask is not None and not torch.all(attention_mask == 1):
  425. hidden_states = hidden_states * attention_mask[:, -hidden_states.shape[-1] :].unsqueeze(1)
  426. conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0))
  427. cache_params.conv_states[self.layer_idx] = conv_state
  428. hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])
  429. if attention_mask is not None and not torch.all(attention_mask == 1):
  430. hidden_states = hidden_states * attention_mask[:, -hidden_states.shape[-1] :].unsqueeze(1)
  431. else:
  432. ssm_state = torch.zeros(
  433. (batch_size, self.n_mamba_heads, self.mamba_head_dim, self.ssm_state_size),
  434. device=hidden_states.device,
  435. dtype=dtype,
  436. )
  437. if attention_mask is not None and not torch.all(attention_mask == 1):
  438. hidden_states = hidden_states * attention_mask.unsqueeze(1)
  439. hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])
  440. if attention_mask is not None and not torch.all(attention_mask == 1):
  441. hidden_states = hidden_states * attention_mask.unsqueeze(1)
  442. # 3. State Space Model sequence transformation
  443. # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
  444. hidden_states = hidden_states.reshape(-1, self.n_mamba_heads, self.mamba_head_dim, seq_len).transpose(0, 1)
  445. ssm_parameters = (self.x_proj_weight[:, None, :, :] @ hidden_states).transpose(-1, -2)
  446. time_step, B, C = torch.split(
  447. ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
  448. )
  449. discrete_time_step = (self.dt_proj_weight[:, None] @ time_step.transpose(-1, -2)) + self.dt_proj_bias[
  450. :, None, :, None
  451. ]
  452. discrete_time_step = nn.functional.softplus(discrete_time_step)
  453. # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
  454. A = -torch.exp(self.A_log.float())
  455. discrete_A = torch.exp(A[:, None, :, None, :] * discrete_time_step[:, :, :, :, None])
  456. discrete_B = discrete_time_step[:, :, :, :, None] * B[:, :, None, :, :].float()
  457. deltaB_u = discrete_B * hidden_states[:, :, :, :, None].float()
  458. # 3.c perform the recurrence y ← SSM(A, B, C)(x)
  459. scan_outputs = []
  460. for i in range(seq_len):
  461. ssm_state = discrete_A[:, :, :, i, :].transpose(0, 1) * ssm_state + deltaB_u[:, :, :, i, :].transpose(0, 1)
  462. scan_output = torch.matmul(ssm_state.transpose(0, 1).to(dtype), C[:, :, i, :].unsqueeze(-1))
  463. scan_outputs.append(scan_output[:, :, :, 0])
  464. scan_output = torch.stack(scan_outputs, dim=-1)
  465. scan_output = scan_output + (hidden_states * self.D[:, None, :, None])
  466. scan_output = scan_output * self.act(gate)
  467. if use_cache:
  468. cache_params.ssm_states[self.layer_idx] = ssm_state
  469. # 4. Final linear projection
  470. contextualized_states = self.out_proj(
  471. scan_output.transpose(0, 1).reshape(batch_size, -1, seq_len).transpose(1, 2)
  472. )
  473. return contextualized_states
  474. def forward(self, hidden_states, cache_params: ZambaHybridDynamicCache = None, attention_mask=None):
  475. if self.use_fast_kernels:
  476. if not is_fast_path_available or "cuda" not in self.x_proj_weight.device.type:
  477. raise ValueError(
  478. "Fast Mamba kernels are not available. Make sure to they are installed and that "
  479. "the mamba module is on a CUDA device. lease run 'pip install causal-conv1d>=1.2.0' "
  480. "and 'pip install mamba-ssm', or set use_mamba_kernels=False in the model's config."
  481. )
  482. return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask=attention_mask)
  483. return self.slow_forward(hidden_states, cache_params, attention_mask=attention_mask)
  484. # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Zamba
  485. class ZambaMLP(nn.Module):
  486. def __init__(self, config):
  487. super().__init__()
  488. self.config = config
  489. self.hidden_size = config.hidden_size
  490. self.intermediate_size = config.intermediate_size
  491. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  492. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  493. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  494. self.act_fn = ACT2FN[config.hidden_act]
  495. def forward(self, x):
  496. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  497. return down_proj
  498. class ZambaAttentionDecoderLayer(nn.Module):
  499. def __init__(self, config: ZambaConfig, layer_idx: Optional[int] = None):
  500. super().__init__()
  501. self.self_attn = ZambaAttention(config, layer_idx)
  502. self.feed_forward = ZambaMLP(config)
  503. self.input_layernorm = ZambaRMSNorm(config.attention_hidden_size, eps=config.rms_norm_eps)
  504. self.pre_ff_layernorm = ZambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  505. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  506. def forward(
  507. self,
  508. hidden_states: torch.Tensor,
  509. original_hidden_states: torch.Tensor,
  510. layer_idx: int,
  511. attention_mask: Optional[torch.Tensor] = None,
  512. past_key_values: Optional[ZambaHybridDynamicCache] = None,
  513. output_attentions: Optional[bool] = False,
  514. use_cache: Optional[bool] = False,
  515. **kwargs: Unpack[FlashAttentionKwargs],
  516. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  517. """
  518. Args:
  519. hidden_states (`torch.FloatTensor`): output of previous Mamba layer of shape `(batch, seq_len, embed_dim)`
  520. original_hidden_states (`torch.FloatTensor`): word embedding output of shape `(batch, seq_len, embed_dim)`.
  521. This is concatenated with `hidden_states` (which is the output of the previous (mamba) layer). The
  522. concatenated tensor is then used as input of the pre-attention RMSNorm
  523. (see fig. 2 in https://huggingface.co/papers/2405.16712).
  524. layer_idx (`int`): layer_idx in the forward pass. Used to distinguish Zamba's tied transformer layers.
  525. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
  526. `(batch, sequence_length)` where padding elements are indicated by 0.
  527. past_key_values (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states
  528. output_attentions (`bool`, *optional*):
  529. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  530. returned tensors for more detail.
  531. use_cache (`bool`, *optional*):
  532. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  533. (see `past_key_values`).
  534. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  535. Indices depicting the position of the input sequence tokens in the sequence.
  536. """
  537. hidden_states = torch.concatenate([hidden_states, original_hidden_states], dim=-1)
  538. hidden_states = self.input_layernorm(hidden_states)
  539. hidden_states, self_attn_weights = self.self_attn(
  540. hidden_states=hidden_states,
  541. layer_idx=layer_idx,
  542. attention_mask=attention_mask,
  543. past_key_values=past_key_values,
  544. output_attentions=output_attentions,
  545. use_cache=use_cache,
  546. **kwargs,
  547. )
  548. # feed-forward (MLP)
  549. hidden_states = self.pre_ff_layernorm(hidden_states)
  550. hidden_states = self.feed_forward(hidden_states)
  551. outputs = (hidden_states,)
  552. if output_attentions:
  553. outputs += (self_attn_weights,)
  554. return outputs
  555. class ZambaMambaDecoderLayer(nn.Module):
  556. def __init__(self, config: ZambaConfig, layer_idx: int):
  557. super().__init__()
  558. self.mamba = ZambaMambaMixer(config=config, layer_idx=layer_idx)
  559. self.input_layernorm = ZambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  560. self.layer_idx = layer_idx
  561. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  562. def forward(
  563. self,
  564. hidden_states: torch.Tensor,
  565. original_hidden_states: Optional[torch.Tensor] = None,
  566. layer_idx: Optional[int] = None,
  567. attention_mask: Optional[torch.Tensor] = None,
  568. causal_mask: Optional[torch.Tensor] = None,
  569. past_key_values: Optional[ZambaHybridDynamicCache] = None,
  570. output_attentions: Optional[bool] = False,
  571. use_cache: Optional[bool] = False,
  572. cache_position: Optional[torch.LongTensor] = None,
  573. transformer_hidden_states: Optional[torch.Tensor] = None,
  574. **kwargs,
  575. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  576. """
  577. Args:
  578. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  579. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
  580. `(batch, sequence_length)` where padding elements are indicated by 0.
  581. past_key_values (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states
  582. output_attentions (`bool`, *optional*):
  583. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  584. returned tensors for more detail.
  585. use_cache (`bool`, *optional*):
  586. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  587. (see `past_key_values`).
  588. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  589. Indices depicting the position of the input sequence tokens in the sequence.
  590. """
  591. residual = hidden_states
  592. # `transformer_hidden_states` is the output from shared transformer + linear layer (see fig. 2 in https://huggingface.co/papers/2405.16712).
  593. # `transformer_hidden_states` is then added to the input to the mamba layer below (as described in eq. (6) of https://huggingface.co/papers/2405.16712).
  594. hidden_states = (
  595. hidden_states + transformer_hidden_states if transformer_hidden_states is not None else hidden_states
  596. )
  597. hidden_states = self.input_layernorm(hidden_states)
  598. hidden_states = self.mamba(
  599. hidden_states=hidden_states,
  600. cache_params=past_key_values,
  601. attention_mask=attention_mask,
  602. )
  603. self_attn_weights = None
  604. # residual connection after mamba
  605. hidden_states = residual + hidden_states
  606. outputs = (hidden_states,)
  607. if output_attentions:
  608. outputs += (self_attn_weights,)
  609. if use_cache:
  610. outputs += (past_key_values,)
  611. return outputs
  612. class ZambaHybridLayer(nn.Module):
  613. def __init__(self, shared_transf: ZambaAttentionDecoderLayer, linear: nn.Linear, mamba: ZambaMambaDecoderLayer):
  614. super().__init__()
  615. self.shared_transf = shared_transf
  616. self.linear = linear
  617. self.mamba_decoder = mamba
  618. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  619. def forward(
  620. self,
  621. hidden_states: torch.Tensor,
  622. original_hidden_states: Optional[torch.Tensor] = None,
  623. layer_idx: Optional[int] = None,
  624. attention_mask: Optional[torch.Tensor] = None,
  625. causal_mask: Optional[torch.Tensor] = None,
  626. past_key_values: Optional[ZambaHybridDynamicCache] = None,
  627. output_attentions: Optional[bool] = False,
  628. use_cache: Optional[bool] = False,
  629. cache_position: Optional[torch.LongTensor] = None,
  630. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  631. """
  632. Args:
  633. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  634. original_hidden_states (`torch.FloatTensor`): word embedding output that will be concatenated with
  635. hidden activations to form the input of the shared transformer layer.
  636. layer_idx (`int`): layer number.
  637. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
  638. `(batch, sequence_length)` where padding elements are indicated by 0.
  639. past_key_values (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states
  640. output_attentions (`bool`, *optional*):
  641. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  642. returned tensors for more detail.
  643. use_cache (`bool`, *optional*):
  644. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  645. (see `past_key_values`).
  646. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  647. Indices depicting the position of the input sequence tokens in the sequence.
  648. """
  649. layer_outputs = self.shared_transf(
  650. hidden_states,
  651. original_hidden_states=original_hidden_states,
  652. layer_idx=layer_idx,
  653. attention_mask=causal_mask,
  654. past_key_values=past_key_values,
  655. output_attentions=output_attentions,
  656. use_cache=use_cache,
  657. cache_position=cache_position,
  658. )
  659. transformer_hidden_states = layer_outputs[0]
  660. if output_attentions:
  661. self_attn_weights = layer_outputs[1]
  662. transformer_hidden_states = self.linear(transformer_hidden_states)
  663. layer_outputs = self.mamba_decoder(
  664. hidden_states,
  665. transformer_hidden_states=transformer_hidden_states,
  666. attention_mask=attention_mask,
  667. past_key_values=past_key_values,
  668. output_attentions=output_attentions,
  669. use_cache=use_cache,
  670. cache_position=cache_position,
  671. )
  672. if output_attentions:
  673. layer_outputs = (layer_outputs[0], self_attn_weights) + layer_outputs[2:]
  674. return layer_outputs
  675. @auto_docstring
  676. class ZambaPreTrainedModel(PreTrainedModel):
  677. config: ZambaConfig
  678. base_model_prefix = "model"
  679. supports_gradient_checkpointing = True
  680. _no_split_modules = ["ZambaAttentionDecoderLayer", "ZambaMambaDecoderLayer"]
  681. _skip_keys_device_placement = "past_key_values"
  682. _supports_flash_attn = False
  683. _supports_sdpa = False
  684. # Note: only supports ZambaHybridDynamicCache
  685. _is_stateful = True
  686. def _init_weights(self, module):
  687. std = self.config.initializer_range
  688. if isinstance(module, (nn.Linear, nn.Conv1d)):
  689. module.weight.data.normal_(mean=0.0, std=std)
  690. if module.bias is not None:
  691. module.bias.data.zero_()
  692. elif isinstance(module, nn.Embedding):
  693. module.weight.data.normal_(mean=0.0, std=std)
  694. if module.padding_idx is not None:
  695. module.weight.data[module.padding_idx].zero_()
  696. elif isinstance(module, ZambaRMSNorm):
  697. module.weight.data.fill_(1.0)
  698. elif isinstance(module, ZambaMambaMixer):
  699. module.x_proj_weight.data.normal_(mean=0.0, std=std)
  700. dt_init_std = self.config.mamba_dt_rank**-0.5
  701. nn.init.uniform_(module.dt_proj_weight, -dt_init_std, dt_init_std)
  702. mamba_head_dim = self.config.mamba_expand * self.config.hidden_size // self.config.n_mamba_heads
  703. dt = torch.exp(
  704. torch.rand(self.config.n_mamba_heads, mamba_head_dim)
  705. * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
  706. + math.log(self.config.time_step_min)
  707. ).clamp(min=self.config.time_step_floor)
  708. # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
  709. inv_dt = dt + torch.log(-torch.expm1(-dt))
  710. module.dt_proj_bias.data.copy_(inv_dt)
  711. A = torch.arange(1, module.ssm_state_size + 1, dtype=torch.float32)[None, :]
  712. A = A.expand(module.intermediate_size, -1).contiguous()
  713. module.A_log.data.copy_(torch.log(A).reshape(module.n_mamba_heads, module.mamba_head_dim, -1))
  714. module.D.data.fill_(1.0)
  715. @auto_docstring
  716. class ZambaModel(ZambaPreTrainedModel):
  717. """
  718. Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`ZambaDecoderLayer`]
  719. Args:
  720. config: ZambaConfig
  721. """
  722. def __init__(self, config: ZambaConfig):
  723. super().__init__(config)
  724. self.padding_idx = config.pad_token_id
  725. self.vocab_size = config.vocab_size
  726. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  727. block = ZambaAttentionDecoderLayer(config)
  728. mamba_layers = []
  729. linear_layers = []
  730. self.layers_block_type = config.layers_block_type
  731. for i in range(config.num_hidden_layers):
  732. if config.layers_block_type[i] == "mamba":
  733. mamba_layers.append(ZambaMambaDecoderLayer(config, layer_idx=i))
  734. elif config.layers_block_type[i] == "hybrid":
  735. linear_layers.append(nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False))
  736. mamba_layers.append(ZambaMambaDecoderLayer(config, layer_idx=i))
  737. mamba_layers = iter(mamba_layers)
  738. linear_layers = iter(linear_layers)
  739. layers = []
  740. self._tied_weights_keys = []
  741. for layer_id, layer_type in enumerate(self.layers_block_type):
  742. if layer_type == "hybrid":
  743. prefix_name = f"layers.{layer_id}."
  744. tied_keys = [
  745. "shared_transf.self_attn.q_proj.weight",
  746. "shared_transf.self_attn.k_proj.weight",
  747. "shared_transf.self_attn.v_proj.weight",
  748. "shared_transf.self_attn.o_proj.weight",
  749. "shared_transf.feed_forward.gate_proj.weight",
  750. "shared_transf.feed_forward.up_proj.weight",
  751. "shared_transf.feed_forward.down_proj.weight",
  752. "shared_transf.input_layernorm.weight",
  753. "shared_transf.pre_ff_layernorm.weight",
  754. ]
  755. self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]]
  756. layers.append(ZambaHybridLayer(block, next(linear_layers), next(mamba_layers)))
  757. else:
  758. layers.append(next(mamba_layers))
  759. self.layers = nn.ModuleList(layers)
  760. self._attn_implementation = config._attn_implementation
  761. self.final_layernorm = ZambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  762. self.gradient_checkpointing = False
  763. # Initialize weights and apply final processing
  764. self.post_init()
  765. @auto_docstring
  766. def forward(
  767. self,
  768. input_ids: Optional[torch.LongTensor] = None,
  769. attention_mask: Optional[torch.Tensor] = None,
  770. position_ids: Optional[torch.LongTensor] = None,
  771. past_key_values: Optional[ZambaHybridDynamicCache] = None,
  772. inputs_embeds: Optional[torch.FloatTensor] = None,
  773. use_cache: Optional[bool] = None,
  774. output_attentions: Optional[bool] = None,
  775. output_hidden_states: Optional[bool] = None,
  776. return_dict: Optional[bool] = None,
  777. cache_position: Optional[torch.LongTensor] = None,
  778. ) -> Union[tuple, BaseModelOutputWithPast]:
  779. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  780. output_hidden_states = (
  781. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  782. )
  783. use_cache = use_cache if use_cache is not None else self.config.use_cache
  784. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  785. if (input_ids is None) ^ (inputs_embeds is not None):
  786. raise ValueError(
  787. "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
  788. )
  789. if self.gradient_checkpointing and self.training and use_cache:
  790. logger.warning_once(
  791. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
  792. )
  793. use_cache = False
  794. if inputs_embeds is None:
  795. inputs_embeds = self.embed_tokens(input_ids)
  796. hidden_states = inputs_embeds
  797. original_hidden_states = torch.clone(inputs_embeds)
  798. # original_hidden_states: word embedding output that will be concatenated with hidden activations to form the input of the shared transformer layer
  799. if use_cache and past_key_values is None:
  800. logger.warning_once(
  801. "Zamba requires an initialized `ZambaHybridDynamicCache` to return a cache. None was "
  802. "provided, so no cache will be returned."
  803. )
  804. if cache_position is None:
  805. cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device)
  806. if position_ids is None:
  807. position_ids = cache_position.unsqueeze(0)
  808. causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
  809. all_hidden_states = () if output_hidden_states else None
  810. all_self_attns = () if output_attentions else None
  811. for layer_idx, layer in enumerate(self.layers):
  812. if output_hidden_states:
  813. all_hidden_states += (hidden_states,)
  814. if self.gradient_checkpointing and self.training:
  815. layer_outputs = self._gradient_checkpointing_func(
  816. layer.__call__,
  817. hidden_states,
  818. original_hidden_states,
  819. layer_idx,
  820. attention_mask,
  821. causal_mask,
  822. past_key_values,
  823. output_attentions,
  824. use_cache,
  825. cache_position,
  826. )
  827. else:
  828. layer_outputs = layer(
  829. hidden_states,
  830. original_hidden_states=original_hidden_states,
  831. layer_idx=layer_idx,
  832. attention_mask=attention_mask,
  833. causal_mask=causal_mask,
  834. past_key_values=past_key_values,
  835. output_attentions=output_attentions,
  836. use_cache=use_cache,
  837. cache_position=cache_position,
  838. )
  839. hidden_states = layer_outputs[0]
  840. if output_attentions:
  841. if layer_outputs[1] is not None:
  842. # append attentions only of attention layers. Mamba layers return `None` as the attention weights
  843. all_self_attns += (layer_outputs[1],)
  844. hidden_states = self.final_layernorm(hidden_states)
  845. # add hidden states from the last decoder layer
  846. if output_hidden_states:
  847. all_hidden_states += (hidden_states,)
  848. if past_key_values and not past_key_values.has_previous_state:
  849. past_key_values.has_previous_state = True
  850. output = BaseModelOutputWithPast(
  851. last_hidden_state=hidden_states,
  852. past_key_values=past_key_values if use_cache else None,
  853. hidden_states=all_hidden_states,
  854. attentions=all_self_attns,
  855. )
  856. return output if return_dict else output.to_tuple()
  857. # Copied from transformers.models.jamba.modeling_jamba.JambaModel._update_causal_mask
  858. def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
  859. if self.config._attn_implementation == "flash_attention_2":
  860. if attention_mask is not None and 0.0 in attention_mask:
  861. return attention_mask
  862. return None
  863. dtype, device = input_tensor.dtype, input_tensor.device
  864. min_dtype = torch.finfo(dtype).min
  865. sequence_length = input_tensor.shape[1]
  866. target_length = cache_position[-1] + 1
  867. causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
  868. if sequence_length != 1:
  869. causal_mask = torch.triu(causal_mask, diagonal=1)
  870. causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
  871. causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
  872. if attention_mask is not None:
  873. causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
  874. if attention_mask.dim() == 2:
  875. mask_length = attention_mask.shape[-1]
  876. padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
  877. causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
  878. if (
  879. self.config._attn_implementation == "sdpa"
  880. and attention_mask is not None
  881. and attention_mask.device.type in ["cuda", "xpu", "npu"]
  882. ):
  883. # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
  884. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
  885. # Details: https://github.com/pytorch/pytorch/issues/110213
  886. causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
  887. return causal_mask
  888. # Adapted from transformers.models.jamba.modeling_jamba.JambaForCausalLM with Jamba->Zamba, JAMBA->ZAMBA
  889. class ZambaForCausalLM(ZambaPreTrainedModel, GenerationMixin):
  890. def __init__(self, config: ZambaConfig):
  891. super().__init__(config)
  892. self.model = ZambaModel(config)
  893. self._tied_weights_keys = ["lm_head.weight", *self.model._tied_weights_keys]
  894. self.vocab_size = config.vocab_size
  895. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  896. # Initialize weights and apply final processing
  897. self.post_init()
  898. @auto_docstring
  899. def forward(
  900. self,
  901. input_ids: Optional[torch.LongTensor] = None,
  902. attention_mask: Optional[torch.Tensor] = None,
  903. position_ids: Optional[torch.LongTensor] = None,
  904. past_key_values: Optional[ZambaHybridDynamicCache] = None,
  905. inputs_embeds: Optional[torch.FloatTensor] = None,
  906. labels: Optional[torch.LongTensor] = None,
  907. use_cache: Optional[bool] = None,
  908. output_attentions: Optional[bool] = None,
  909. output_hidden_states: Optional[bool] = None,
  910. return_dict: Optional[bool] = None,
  911. cache_position: Optional[torch.LongTensor] = None,
  912. logits_to_keep: Union[int, torch.Tensor] = 0,
  913. **kwargs,
  914. ) -> Union[tuple, CausalLMOutputWithPast]:
  915. r"""
  916. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  917. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  918. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  919. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  920. Example:
  921. ```python
  922. >>> from transformers import AutoTokenizer, ZambaForCausalLM
  923. >>> model = ZambaForCausalLM.from_pretrained("Zyphra/Zamba-7B-v1")
  924. >>> tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba-7B-v1")
  925. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  926. >>> inputs = tokenizer(prompt, return_tensors="pt")
  927. >>> # Generate
  928. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  929. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  930. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  931. ```"""
  932. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  933. output_hidden_states = (
  934. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  935. )
  936. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  937. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  938. outputs = self.model(
  939. input_ids=input_ids,
  940. attention_mask=attention_mask,
  941. position_ids=position_ids,
  942. past_key_values=past_key_values,
  943. inputs_embeds=inputs_embeds,
  944. use_cache=use_cache,
  945. output_attentions=output_attentions,
  946. output_hidden_states=output_hidden_states,
  947. cache_position=cache_position,
  948. return_dict=return_dict,
  949. )
  950. hidden_states = outputs[0]
  951. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  952. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  953. logits = self.lm_head(hidden_states[:, slice_indices, :])
  954. loss = None
  955. if labels is not None:
  956. loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
  957. if not return_dict:
  958. output = (logits,) + outputs[1:]
  959. return (loss,) + output if loss is not None else output
  960. return CausalLMOutputWithPast(
  961. loss=loss,
  962. logits=logits,
  963. past_key_values=outputs.past_key_values,
  964. hidden_states=outputs.hidden_states,
  965. attentions=outputs.attentions,
  966. )
  967. def prepare_inputs_for_generation(
  968. self,
  969. input_ids,
  970. past_key_values=None,
  971. attention_mask=None,
  972. inputs_embeds=None,
  973. cache_position=None,
  974. position_ids=None,
  975. use_cache=True,
  976. **kwargs,
  977. ):
  978. # Overwritten -- has a unique cache type, `ZambaHybridDynamicCache`
  979. empty_past_kv = past_key_values is None
  980. # Omit tokens covered by past_key_values
  981. if not empty_past_kv:
  982. # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
  983. # Exception 1: when passing input_embeds, input_ids may be missing entries
  984. # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
  985. # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
  986. # (we can't check exception 3 while compiling)
  987. if (
  988. inputs_embeds is not None # Exception 1
  989. or cache_position[-1] >= input_ids.shape[1] # Exception 3
  990. ):
  991. input_ids = input_ids[:, -cache_position.shape[0] :]
  992. elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
  993. input_ids = input_ids[:, cache_position]
  994. else:
  995. past_key_values = ZambaHybridDynamicCache(
  996. self.config, input_ids.shape[0], dtype=self.dtype, device=self.device
  997. )
  998. if attention_mask is not None and position_ids is None:
  999. # create position_ids on the fly for batch generation
  1000. position_ids = attention_mask.long().cumsum(-1) - 1
  1001. position_ids.masked_fill_(attention_mask == 0, 1)
  1002. if not empty_past_kv:
  1003. position_ids = position_ids[:, -input_ids.shape[1] :]
  1004. # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
  1005. if inputs_embeds is not None and empty_past_kv:
  1006. model_inputs = {"inputs_embeds": inputs_embeds}
  1007. else:
  1008. model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
  1009. model_inputs.update(
  1010. {
  1011. "position_ids": position_ids,
  1012. "past_key_values": past_key_values,
  1013. "use_cache": use_cache,
  1014. "attention_mask": attention_mask,
  1015. "logits_to_keep": self.config.num_logits_to_keep,
  1016. "cache_position": cache_position,
  1017. }
  1018. )
  1019. # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
  1020. for key, value in kwargs.items():
  1021. if key not in model_inputs:
  1022. model_inputs[key] = value
  1023. return model_inputs
  1024. @auto_docstring(
  1025. custom_intro="""
  1026. The Zamba Model with a sequence classification head on top (linear layer).
  1027. [`ZambaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
  1028. (e.g. GPT-2) do.
  1029. Since it does classification on the last token, it requires to know the position of the last token. If a
  1030. `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
  1031. no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
  1032. padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
  1033. each row of the batch).
  1034. """
  1035. )
  1036. class ZambaForSequenceClassification(ZambaPreTrainedModel):
  1037. def __init__(self, config):
  1038. super().__init__(config)
  1039. self.num_labels = config.num_labels
  1040. self.model = ZambaModel(config)
  1041. self._tied_weights_keys = self.model._tied_weights_keys
  1042. self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
  1043. # Initialize weights and apply final processing
  1044. self.post_init()
  1045. @auto_docstring
  1046. def forward(
  1047. self,
  1048. input_ids: Optional[torch.LongTensor] = None,
  1049. attention_mask: Optional[torch.Tensor] = None,
  1050. position_ids: Optional[torch.LongTensor] = None,
  1051. past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
  1052. inputs_embeds: Optional[torch.FloatTensor] = None,
  1053. labels: Optional[torch.LongTensor] = None,
  1054. use_cache: Optional[bool] = None,
  1055. output_attentions: Optional[bool] = None,
  1056. output_hidden_states: Optional[bool] = None,
  1057. return_dict: Optional[bool] = None,
  1058. ) -> Union[tuple, SequenceClassifierOutputWithPast]:
  1059. r"""
  1060. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1061. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1062. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1063. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1064. """
  1065. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1066. transformer_outputs = self.model(
  1067. input_ids,
  1068. attention_mask=attention_mask,
  1069. position_ids=position_ids,
  1070. past_key_values=past_key_values,
  1071. inputs_embeds=inputs_embeds,
  1072. use_cache=use_cache,
  1073. output_attentions=output_attentions,
  1074. output_hidden_states=output_hidden_states,
  1075. return_dict=return_dict,
  1076. )
  1077. hidden_states = transformer_outputs[0]
  1078. logits = self.score(hidden_states)
  1079. if input_ids is not None:
  1080. batch_size = input_ids.shape[0]
  1081. else:
  1082. batch_size = inputs_embeds.shape[0]
  1083. if self.config.pad_token_id is None and batch_size != 1:
  1084. raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
  1085. if self.config.pad_token_id is None:
  1086. last_non_pad_token = -1
  1087. elif input_ids is not None:
  1088. # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
  1089. non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
  1090. token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
  1091. last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
  1092. else:
  1093. last_non_pad_token = -1
  1094. logger.warning_once(
  1095. f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
  1096. "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
  1097. )
  1098. pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
  1099. loss = None
  1100. if labels is not None:
  1101. labels = labels.to(logits.device)
  1102. if self.config.problem_type is None:
  1103. if self.num_labels == 1:
  1104. self.config.problem_type = "regression"
  1105. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  1106. self.config.problem_type = "single_label_classification"
  1107. else:
  1108. self.config.problem_type = "multi_label_classification"
  1109. if self.config.problem_type == "regression":
  1110. loss_fct = MSELoss()
  1111. if self.num_labels == 1:
  1112. loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
  1113. else:
  1114. loss = loss_fct(pooled_logits, labels)
  1115. elif self.config.problem_type == "single_label_classification":
  1116. loss_fct = CrossEntropyLoss()
  1117. loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
  1118. elif self.config.problem_type == "multi_label_classification":
  1119. loss_fct = BCEWithLogitsLoss()
  1120. loss = loss_fct(pooled_logits, labels)
  1121. if not return_dict:
  1122. output = (pooled_logits,) + transformer_outputs[1:]
  1123. return ((loss,) + output) if loss is not None else output
  1124. return SequenceClassifierOutputWithPast(
  1125. loss=loss,
  1126. logits=pooled_logits,
  1127. past_key_values=transformer_outputs.past_key_values,
  1128. hidden_states=transformer_outputs.hidden_states,
  1129. attentions=transformer_outputs.attentions,
  1130. )
  1131. __all__ = ["ZambaForCausalLM", "ZambaForSequenceClassification", "ZambaModel", "ZambaPreTrainedModel"]