modular_lfm2.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520
  1. # Copyright 2025 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Callable, Optional, Union
  15. import torch
  16. import torch.nn.functional as F
  17. from torch import nn
  18. from ...masking_utils import create_causal_mask
  19. from ...modeling_layers import GradientCheckpointingLayer
  20. from ...modeling_outputs import BaseModelOutputWithPast
  21. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  22. from ...processing_utils import Unpack
  23. from ...utils import TransformersKwargs, logging
  24. from ...utils.deprecation import deprecate_kwarg
  25. from ...utils.import_utils import is_causal_conv1d_available
  26. from ..bamba.modeling_bamba import apply_mask_to_padding_states
  27. from ..llama.modeling_llama import (
  28. LlamaAttention,
  29. LlamaForCausalLM,
  30. LlamaModel,
  31. LlamaPreTrainedModel,
  32. LlamaRMSNorm,
  33. LlamaRotaryEmbedding,
  34. apply_rotary_pos_emb,
  35. eager_attention_forward,
  36. )
  37. from .configuration_lfm2 import Lfm2Config
  38. if is_causal_conv1d_available():
  39. from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
  40. else:
  41. causal_conv1d_fn, causal_conv1d_update = None, None
  42. kernel_modules = (causal_conv1d_fn, causal_conv1d_update)
  43. is_fast_path_available = all(kernel_modules)
  44. logger = logging.get_logger(__name__)
  45. class Lfm2RMSNorm(LlamaRMSNorm):
  46. pass
  47. class Lfm2RotaryEmbedding(LlamaRotaryEmbedding):
  48. pass
  49. class Lfm2MLP(nn.Module):
  50. def __init__(self, config: Lfm2Config):
  51. super().__init__()
  52. intermediate_size = config.intermediate_size
  53. if config.block_auto_adjust_ff_dim:
  54. intermediate_size = int(2 * intermediate_size / 3)
  55. # custom dim factor multiplier
  56. if config.block_ffn_dim_multiplier is not None:
  57. intermediate_size = int(config.block_ffn_dim_multiplier * intermediate_size)
  58. intermediate_size = config.block_multiple_of * (
  59. (intermediate_size + config.block_multiple_of - 1) // config.block_multiple_of
  60. )
  61. self.w1 = nn.Linear(config.hidden_size, intermediate_size, bias=False)
  62. self.w3 = nn.Linear(config.hidden_size, intermediate_size, bias=False)
  63. self.w2 = nn.Linear(intermediate_size, config.hidden_size, bias=False)
  64. def forward(self, x):
  65. return self.w2(F.silu(self.w1(x)) * self.w3(x))
  66. class Lfm2HybridConvCache:
  67. """
  68. Attention and conv cache for Lfm2.
  69. It stores the Key and Value states as a list of tensors, one for each layer.
  70. Attention layer cache shape: `[batch_size, num_heads, seq_len, head_dim]`.
  71. Conv layer cache shape: `[batch_size, hidden_size, L_cache-1]`.
  72. """
  73. # Override @property existing in Cache
  74. max_batch_size = None
  75. is_compileable = False
  76. key_cache = None
  77. value_cache = None
  78. def __init__(
  79. self,
  80. config: Lfm2Config,
  81. max_batch_size: int,
  82. dtype: torch.dtype = torch.float32,
  83. device: Union[torch.device, str, None] = None,
  84. ):
  85. self.key_cache = []
  86. self.value_cache = []
  87. self.max_batch_size = max_batch_size
  88. self.layer_types = config.layer_types
  89. self.first_attention_layer = self.layer_types.index("full_attention")
  90. self.conv_L_cache = config.conv_L_cache
  91. self._dtype = dtype
  92. self.conv_cache: list[torch.Tensor] = []
  93. device = torch.device(device) if device is not None else None
  94. for _ in range(config.num_hidden_layers):
  95. conv_state = torch.zeros(
  96. self.max_batch_size,
  97. config.hidden_size,
  98. self.conv_L_cache,
  99. dtype=self._dtype,
  100. device=device,
  101. )
  102. torch._dynamo.mark_static_address(conv_state)
  103. self.conv_cache.append(conv_state)
  104. def update(
  105. self,
  106. key_states: torch.Tensor,
  107. value_states: torch.Tensor,
  108. layer_idx: int,
  109. cache_kwargs: Optional[dict[str, Any]] = None,
  110. ) -> tuple[torch.Tensor, torch.Tensor]:
  111. """
  112. Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
  113. Parameters:
  114. key_states (`torch.Tensor`):
  115. The new key states to cache.
  116. value_states (`torch.Tensor`):
  117. The new value states to cache.
  118. layer_idx (`int`):
  119. The index of the layer to cache the states for.
  120. cache_kwargs (`Dict[str, Any]`, `optional`):
  121. Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
  122. Return:
  123. A tuple containing the updated key and value states.
  124. """
  125. # Update the cache
  126. if key_states is not None:
  127. if len(self.key_cache) <= layer_idx:
  128. # There may be skipped layers, fill them with empty lists
  129. for _ in range(len(self.key_cache), layer_idx):
  130. self.key_cache.append(torch.tensor([]))
  131. self.value_cache.append(torch.tensor([]))
  132. self.key_cache.append(key_states)
  133. self.value_cache.append(value_states)
  134. elif (
  135. not self.key_cache[layer_idx].numel() # prefers not t.numel() to len(t) == 0 to export the model
  136. ): # fills previously skipped layers; checking for tensor causes errors
  137. self.key_cache[layer_idx] = key_states
  138. self.value_cache[layer_idx] = value_states
  139. else:
  140. self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
  141. self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
  142. return self.key_cache[layer_idx], self.value_cache[layer_idx]
  143. def reorder_cache(self, beam_idx: torch.LongTensor):
  144. """Reorders the cache for beam search, given the selected beam indices."""
  145. for layer_idx in range(len(self.key_cache)):
  146. device = self.key_cache[layer_idx].device
  147. self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
  148. device = self.value_cache[layer_idx].device
  149. self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
  150. device = self.conv_cache[layer_idx].device
  151. self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device))
  152. def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
  153. """Returns the sequence length of the cached states. A layer index can be optionally passed."""
  154. # take any layer that contains cache and not empty tensor
  155. layer_idx = self.first_attention_layer if self.layer_types[layer_idx] != "full_attention" else layer_idx
  156. if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].numel() == 0:
  157. return 0
  158. return self.key_cache[layer_idx].shape[-2]
  159. def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
  160. """
  161. Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
  162. the given layer at `layer_idx`.
  163. The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
  164. for each layer.
  165. """
  166. full_mask_kv_offset = 0
  167. query_length = cache_position.shape[0]
  168. past_seen_tokens = self.get_seq_length()
  169. kv_length = query_length + past_seen_tokens
  170. return kv_length, full_mask_kv_offset
  171. def crop(self, max_length: int):
  172. """Crop the cache to the given length"""
  173. if max_length < 0:
  174. max_length = self.get_seq_length() - abs(max_length)
  175. if self.get_seq_length() <= max_length:
  176. return
  177. for idx in range(len(self.key_cache)):
  178. if self.key_cache[idx].numel():
  179. self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
  180. self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
  181. def __len__(self) -> int:
  182. return len(self.key_cache)
  183. def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
  184. return self.key_cache[layer_idx], self.value_cache[layer_idx]
  185. def reset(self):
  186. for layer_idx in range(len(self.conv_cache)):
  187. # In-place ops prevent breaking the static address
  188. self.conv_cache[layer_idx].zero_()
  189. class Lfm2Attention(LlamaAttention):
  190. def __init__(self, config: Lfm2Config, layer_idx: int):
  191. super().__init__(config, layer_idx)
  192. self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
  193. self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
  194. self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
  195. self.out_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
  196. self.q_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps)
  197. self.k_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps)
  198. del self.o_proj
  199. del self.attention_dropout
  200. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  201. def forward(
  202. self,
  203. hidden_states: torch.Tensor,
  204. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  205. attention_mask: Optional[torch.Tensor],
  206. past_key_values: Optional[Lfm2HybridConvCache] = None,
  207. cache_position: Optional[torch.LongTensor] = None,
  208. **kwargs,
  209. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  210. input_shape = hidden_states.shape[:-1]
  211. hidden_shape = (*input_shape, -1, self.head_dim)
  212. query_states = self.q_layernorm(self.q_proj(hidden_states).view(*hidden_shape)).transpose(1, 2)
  213. key_states = self.k_layernorm(self.k_proj(hidden_states).view(*hidden_shape)).transpose(1, 2)
  214. value_states = self.v_proj(hidden_states).view(*hidden_shape).transpose(1, 2)
  215. cos, sin = position_embeddings
  216. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  217. if past_key_values is not None:
  218. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  219. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  220. attention_interface: Callable = eager_attention_forward
  221. if self.config._attn_implementation != "eager":
  222. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  223. attn_output, attn_weights = attention_interface(
  224. self,
  225. query_states,
  226. key_states,
  227. value_states,
  228. attention_mask,
  229. dropout=0.0,
  230. scaling=self.scaling,
  231. **kwargs,
  232. )
  233. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  234. output = self.out_proj(attn_output)
  235. return output, attn_weights
  236. class Lfm2ShortConv(nn.Module):
  237. def __init__(
  238. self,
  239. config: Lfm2Config,
  240. layer_idx: int,
  241. ):
  242. super().__init__()
  243. self.config = config
  244. self.layer_idx = layer_idx
  245. self.L_cache = config.conv_L_cache
  246. self.bias = config.conv_bias
  247. self.conv = nn.Conv1d(
  248. in_channels=config.hidden_size,
  249. out_channels=config.hidden_size,
  250. kernel_size=self.L_cache,
  251. groups=config.hidden_size,
  252. bias=self.bias,
  253. padding=self.L_cache - 1,
  254. )
  255. self.in_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=self.bias)
  256. self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=self.bias)
  257. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  258. def cuda_kernels_forward(
  259. self,
  260. x: torch.Tensor,
  261. past_key_values: Optional[Lfm2HybridConvCache] = None,
  262. cache_position: Optional[torch.LongTensor] = None,
  263. attention_mask: Optional[torch.Tensor] = None,
  264. ):
  265. x = apply_mask_to_padding_states(x, attention_mask)
  266. BCx = self.in_proj(x).transpose(-1, -2)
  267. B, C, x = BCx.chunk(3, dim=-2)
  268. Bx = B * x
  269. conv_weights = self.conv.weight.view(self.conv.weight.size(0), self.conv.weight.size(2))
  270. if past_key_values is not None and cache_position[0] > 0:
  271. conv_out = causal_conv1d_update(
  272. Bx.squeeze(-1),
  273. past_key_values.conv_cache[self.layer_idx],
  274. conv_weights,
  275. self.conv.bias,
  276. None,
  277. )
  278. conv_out = conv_out.unsqueeze(-1)
  279. else:
  280. if past_key_values is not None:
  281. conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0))
  282. past_key_values.conv_cache[self.layer_idx].copy_(conv_state)
  283. conv_out = causal_conv1d_fn(Bx, conv_weights, self.conv.bias, activation=None)
  284. y = C * conv_out
  285. y = self.out_proj(y.transpose(-1, -2).contiguous())
  286. return y
  287. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  288. def slow_forward(
  289. self,
  290. x: torch.Tensor,
  291. past_key_values: Optional[Lfm2HybridConvCache] = None,
  292. cache_position: Optional[torch.LongTensor] = None,
  293. attention_mask: Optional[torch.Tensor] = None,
  294. ):
  295. seqlen = x.shape[1]
  296. x = apply_mask_to_padding_states(x, attention_mask)
  297. BCx = self.in_proj(x).transpose(-1, -2)
  298. B, C, x = BCx.chunk(3, dim=-2)
  299. Bx = B * x
  300. if past_key_values is not None and cache_position[0] > 0:
  301. conv_state = past_key_values.conv_cache[self.layer_idx]
  302. cache_position = cache_position.clamp(0, self.L_cache - 1)
  303. conv_state = conv_state.roll(shifts=-1, dims=-1)
  304. conv_state[:, :, cache_position] = Bx.to(device=conv_state.device, dtype=conv_state.dtype)
  305. past_key_values.conv_cache[self.layer_idx].copy_(conv_state)
  306. conv_out = torch.sum(conv_state.to(Bx.device) * self.conv.weight[:, 0, :], dim=-1)
  307. if self.bias:
  308. conv_out += self.conv.bias
  309. conv_out = conv_out.unsqueeze(-1)
  310. else:
  311. if past_key_values is not None:
  312. conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0))
  313. past_key_values.conv_cache[self.layer_idx].copy_(conv_state)
  314. conv_out = self.conv(Bx)[..., :seqlen]
  315. y = C * conv_out
  316. y = y.transpose(-1, -2).contiguous()
  317. y = self.out_proj(y)
  318. return y
  319. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  320. def forward(
  321. self,
  322. hidden_states: torch.Tensor,
  323. past_key_values: Optional[Lfm2HybridConvCache] = None,
  324. cache_position: Optional[torch.LongTensor] = None,
  325. attention_mask: Optional[torch.Tensor] = None,
  326. ):
  327. if is_fast_path_available and "cuda" in hidden_states.device.type and not torch._dynamo.is_compiling():
  328. return self.cuda_kernels_forward(hidden_states, past_key_values, cache_position, attention_mask)
  329. return self.slow_forward(hidden_states, past_key_values, cache_position, attention_mask)
  330. class Lfm2DecoderLayer(GradientCheckpointingLayer):
  331. def __init__(self, config: Lfm2Config, layer_idx: int):
  332. super().__init__()
  333. self.is_attention_layer = config.layer_types[layer_idx] == "full_attention"
  334. if self.is_attention_layer:
  335. self.self_attn = Lfm2Attention(config, layer_idx)
  336. else:
  337. self.conv = Lfm2ShortConv(config, layer_idx)
  338. self.feed_forward = Lfm2MLP(config)
  339. self.operator_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps)
  340. self.ffn_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps)
  341. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  342. def forward(
  343. self,
  344. hidden_states: torch.Tensor,
  345. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  346. attention_mask: Optional[torch.Tensor] = None,
  347. position_ids: Optional[torch.LongTensor] = None,
  348. past_key_values: Optional[Lfm2HybridConvCache] = None,
  349. cache_position: Optional[torch.LongTensor] = None,
  350. **kwargs,
  351. ) -> torch.Tensor:
  352. residual = hidden_states
  353. if self.is_attention_layer:
  354. hidden_states, _ = self.self_attn(
  355. hidden_states=self.operator_norm(hidden_states),
  356. position_embeddings=position_embeddings,
  357. attention_mask=attention_mask,
  358. position_ids=position_ids,
  359. past_key_values=past_key_values,
  360. cache_position=cache_position,
  361. **kwargs,
  362. )
  363. else:
  364. hidden_states = self.conv(
  365. hidden_states=self.operator_norm(hidden_states),
  366. past_key_values=past_key_values,
  367. cache_position=cache_position,
  368. attention_mask=attention_mask,
  369. )
  370. hidden_states = hidden_states + residual
  371. hidden_states = hidden_states + self.feed_forward(self.ffn_norm(hidden_states))
  372. return hidden_states
  373. class Lfm2PreTrainedModel(LlamaPreTrainedModel):
  374. _can_compile_fullgraph = False
  375. class Lfm2Model(LlamaModel):
  376. def __init__(self, config: Lfm2Config):
  377. super().__init__(config)
  378. self.pos_emb = Lfm2RotaryEmbedding(config)
  379. self.embedding_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps)
  380. del self.norm
  381. del self.rotary_emv
  382. def forward(
  383. self,
  384. input_ids: Optional[torch.LongTensor] = None,
  385. attention_mask: Optional[torch.Tensor] = None,
  386. position_ids: Optional[torch.LongTensor] = None,
  387. past_key_values: Optional[Lfm2HybridConvCache] = None,
  388. inputs_embeds: Optional[torch.FloatTensor] = None,
  389. use_cache: Optional[bool] = None,
  390. cache_position: Optional[torch.LongTensor] = None,
  391. **kwargs: Unpack[TransformersKwargs],
  392. ) -> BaseModelOutputWithPast:
  393. if (input_ids is None) ^ (inputs_embeds is not None):
  394. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  395. if inputs_embeds is None:
  396. inputs_embeds = self.embed_tokens(input_ids)
  397. if use_cache and past_key_values is None:
  398. batch_size = inputs_embeds.shape[0]
  399. past_key_values = Lfm2HybridConvCache(
  400. config=self.config, max_batch_size=batch_size, dtype=self.dtype, device=self.device
  401. )
  402. if cache_position is None:
  403. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  404. cache_position = torch.arange(
  405. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  406. )
  407. if position_ids is None:
  408. position_ids = cache_position.unsqueeze(0)
  409. causal_mask = create_causal_mask(
  410. config=self.config,
  411. input_embeds=inputs_embeds,
  412. attention_mask=attention_mask,
  413. cache_position=cache_position,
  414. past_key_values=past_key_values,
  415. position_ids=position_ids,
  416. )
  417. hidden_states = inputs_embeds
  418. position_embeddings = self.pos_emb(hidden_states, position_ids)
  419. # decoder layers
  420. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  421. hidden_states = decoder_layer(
  422. hidden_states,
  423. attention_mask=causal_mask,
  424. position_ids=position_ids,
  425. past_key_values=past_key_values,
  426. cache_position=cache_position,
  427. position_embeddings=position_embeddings,
  428. **kwargs,
  429. )
  430. hidden_states = self.embedding_norm(hidden_states)
  431. return BaseModelOutputWithPast(
  432. last_hidden_state=hidden_states,
  433. past_key_values=past_key_values,
  434. )
  435. class Lfm2ForCausalLM(LlamaForCausalLM):
  436. pass
  437. __all__ = ["Lfm2ForCausalLM", "Lfm2Model", "Lfm2PreTrainedModel"]