modeling_lfm2.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/lfm2/modular_lfm2.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_lfm2.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 The HuggingFace Team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. from typing import Any, Callable, Optional, Union
  21. import torch
  22. import torch.nn.functional as F
  23. from torch import nn
  24. from ...cache_utils import Cache
  25. from ...generation import GenerationMixin
  26. from ...integrations import use_kernel_forward_from_hub
  27. from ...masking_utils import create_causal_mask
  28. from ...modeling_layers import GradientCheckpointingLayer
  29. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  30. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  31. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  32. from ...processing_utils import Unpack
  33. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  34. from ...utils.deprecation import deprecate_kwarg
  35. from ...utils.generic import check_model_inputs
  36. from ...utils.import_utils import is_causal_conv1d_available
  37. from .configuration_lfm2 import Lfm2Config
  38. if is_causal_conv1d_available():
  39. from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
  40. else:
  41. causal_conv1d_fn, causal_conv1d_update = None, None
  42. @use_kernel_forward_from_hub("RMSNorm")
  43. class Lfm2RMSNorm(nn.Module):
  44. def __init__(self, hidden_size, eps=1e-6):
  45. """
  46. Lfm2RMSNorm is equivalent to T5LayerNorm
  47. """
  48. super().__init__()
  49. self.weight = nn.Parameter(torch.ones(hidden_size))
  50. self.variance_epsilon = eps
  51. def forward(self, hidden_states):
  52. input_dtype = hidden_states.dtype
  53. hidden_states = hidden_states.to(torch.float32)
  54. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  55. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  56. return self.weight * hidden_states.to(input_dtype)
  57. def extra_repr(self):
  58. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  59. class Lfm2RotaryEmbedding(nn.Module):
  60. inv_freq: torch.Tensor # fix linting for `register_buffer`
  61. def __init__(self, config: Lfm2Config, device=None):
  62. super().__init__()
  63. # BC: "rope_type" was originally "type"
  64. if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
  65. self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
  66. else:
  67. self.rope_type = "default"
  68. self.max_seq_len_cached = config.max_position_embeddings
  69. self.original_max_seq_len = config.max_position_embeddings
  70. self.config = config
  71. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  72. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
  73. self.register_buffer("inv_freq", inv_freq, persistent=False)
  74. self.original_inv_freq = self.inv_freq
  75. @torch.no_grad()
  76. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  77. def forward(self, x, position_ids):
  78. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  79. position_ids_expanded = position_ids[:, None, :].float()
  80. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  81. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  82. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  83. emb = torch.cat((freqs, freqs), dim=-1)
  84. cos = emb.cos() * self.attention_scaling
  85. sin = emb.sin() * self.attention_scaling
  86. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  87. class Lfm2MLP(nn.Module):
  88. def __init__(self, config: Lfm2Config):
  89. super().__init__()
  90. intermediate_size = config.intermediate_size
  91. if config.block_auto_adjust_ff_dim:
  92. intermediate_size = int(2 * intermediate_size / 3)
  93. # custom dim factor multiplier
  94. if config.block_ffn_dim_multiplier is not None:
  95. intermediate_size = int(config.block_ffn_dim_multiplier * intermediate_size)
  96. intermediate_size = config.block_multiple_of * (
  97. (intermediate_size + config.block_multiple_of - 1) // config.block_multiple_of
  98. )
  99. self.w1 = nn.Linear(config.hidden_size, intermediate_size, bias=False)
  100. self.w3 = nn.Linear(config.hidden_size, intermediate_size, bias=False)
  101. self.w2 = nn.Linear(intermediate_size, config.hidden_size, bias=False)
  102. def forward(self, x):
  103. return self.w2(F.silu(self.w1(x)) * self.w3(x))
  104. class Lfm2HybridConvCache:
  105. """
  106. Attention and conv cache for Lfm2.
  107. It stores the Key and Value states as a list of tensors, one for each layer.
  108. Attention layer cache shape: `[batch_size, num_heads, seq_len, head_dim]`.
  109. Conv layer cache shape: `[batch_size, hidden_size, L_cache-1]`.
  110. """
  111. # Override @property existing in Cache
  112. max_batch_size = None
  113. is_compileable = False
  114. key_cache = None
  115. value_cache = None
  116. def __init__(
  117. self,
  118. config: Lfm2Config,
  119. max_batch_size: int,
  120. dtype: torch.dtype = torch.float32,
  121. device: Union[torch.device, str, None] = None,
  122. ):
  123. self.key_cache = []
  124. self.value_cache = []
  125. self.max_batch_size = max_batch_size
  126. self.layer_types = config.layer_types
  127. self.first_attention_layer = self.layer_types.index("full_attention")
  128. self.conv_L_cache = config.conv_L_cache
  129. self._dtype = dtype
  130. self.conv_cache: list[torch.Tensor] = []
  131. device = torch.device(device) if device is not None else None
  132. for _ in range(config.num_hidden_layers):
  133. conv_state = torch.zeros(
  134. self.max_batch_size,
  135. config.hidden_size,
  136. self.conv_L_cache,
  137. dtype=self._dtype,
  138. device=device,
  139. )
  140. torch._dynamo.mark_static_address(conv_state)
  141. self.conv_cache.append(conv_state)
  142. def update(
  143. self,
  144. key_states: torch.Tensor,
  145. value_states: torch.Tensor,
  146. layer_idx: int,
  147. cache_kwargs: Optional[dict[str, Any]] = None,
  148. ) -> tuple[torch.Tensor, torch.Tensor]:
  149. """
  150. Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
  151. Parameters:
  152. key_states (`torch.Tensor`):
  153. The new key states to cache.
  154. value_states (`torch.Tensor`):
  155. The new value states to cache.
  156. layer_idx (`int`):
  157. The index of the layer to cache the states for.
  158. cache_kwargs (`Dict[str, Any]`, `optional`):
  159. Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
  160. Return:
  161. A tuple containing the updated key and value states.
  162. """
  163. # Update the cache
  164. if key_states is not None:
  165. if len(self.key_cache) <= layer_idx:
  166. # There may be skipped layers, fill them with empty lists
  167. for _ in range(len(self.key_cache), layer_idx):
  168. self.key_cache.append(torch.tensor([]))
  169. self.value_cache.append(torch.tensor([]))
  170. self.key_cache.append(key_states)
  171. self.value_cache.append(value_states)
  172. elif (
  173. not self.key_cache[layer_idx].numel() # prefers not t.numel() to len(t) == 0 to export the model
  174. ): # fills previously skipped layers; checking for tensor causes errors
  175. self.key_cache[layer_idx] = key_states
  176. self.value_cache[layer_idx] = value_states
  177. else:
  178. self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
  179. self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
  180. return self.key_cache[layer_idx], self.value_cache[layer_idx]
  181. def reorder_cache(self, beam_idx: torch.LongTensor):
  182. """Reorders the cache for beam search, given the selected beam indices."""
  183. for layer_idx in range(len(self.key_cache)):
  184. device = self.key_cache[layer_idx].device
  185. self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
  186. device = self.value_cache[layer_idx].device
  187. self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
  188. device = self.conv_cache[layer_idx].device
  189. self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device))
  190. def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
  191. """Returns the sequence length of the cached states. A layer index can be optionally passed."""
  192. # take any layer that contains cache and not empty tensor
  193. layer_idx = self.first_attention_layer if self.layer_types[layer_idx] != "full_attention" else layer_idx
  194. if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].numel() == 0:
  195. return 0
  196. return self.key_cache[layer_idx].shape[-2]
  197. def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
  198. """
  199. Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
  200. the given layer at `layer_idx`.
  201. The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
  202. for each layer.
  203. """
  204. full_mask_kv_offset = 0
  205. query_length = cache_position.shape[0]
  206. past_seen_tokens = self.get_seq_length()
  207. kv_length = query_length + past_seen_tokens
  208. return kv_length, full_mask_kv_offset
  209. def crop(self, max_length: int):
  210. """Crop the cache to the given length"""
  211. if max_length < 0:
  212. max_length = self.get_seq_length() - abs(max_length)
  213. if self.get_seq_length() <= max_length:
  214. return
  215. for idx in range(len(self.key_cache)):
  216. if self.key_cache[idx].numel():
  217. self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
  218. self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
  219. def __len__(self) -> int:
  220. return len(self.key_cache)
  221. def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
  222. return self.key_cache[layer_idx], self.value_cache[layer_idx]
  223. def reset(self):
  224. for layer_idx in range(len(self.conv_cache)):
  225. # In-place ops prevent breaking the static address
  226. self.conv_cache[layer_idx].zero_()
  227. def rotate_half(x):
  228. """Rotates half the hidden dims of the input."""
  229. x1 = x[..., : x.shape[-1] // 2]
  230. x2 = x[..., x.shape[-1] // 2 :]
  231. return torch.cat((-x2, x1), dim=-1)
  232. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  233. """Applies Rotary Position Embedding to the query and key tensors.
  234. Args:
  235. q (`torch.Tensor`): The query tensor.
  236. k (`torch.Tensor`): The key tensor.
  237. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  238. sin (`torch.Tensor`): The sine part of the rotary embedding.
  239. position_ids (`torch.Tensor`, *optional*):
  240. Deprecated and unused.
  241. unsqueeze_dim (`int`, *optional*, defaults to 1):
  242. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  243. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  244. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  245. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  246. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  247. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  248. Returns:
  249. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  250. """
  251. cos = cos.unsqueeze(unsqueeze_dim)
  252. sin = sin.unsqueeze(unsqueeze_dim)
  253. q_embed = (q * cos) + (rotate_half(q) * sin)
  254. k_embed = (k * cos) + (rotate_half(k) * sin)
  255. return q_embed, k_embed
  256. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  257. """
  258. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  259. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  260. """
  261. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  262. if n_rep == 1:
  263. return hidden_states
  264. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  265. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  266. def eager_attention_forward(
  267. module: nn.Module,
  268. query: torch.Tensor,
  269. key: torch.Tensor,
  270. value: torch.Tensor,
  271. attention_mask: Optional[torch.Tensor],
  272. scaling: float,
  273. dropout: float = 0.0,
  274. **kwargs: Unpack[TransformersKwargs],
  275. ):
  276. key_states = repeat_kv(key, module.num_key_value_groups)
  277. value_states = repeat_kv(value, module.num_key_value_groups)
  278. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  279. if attention_mask is not None:
  280. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  281. attn_weights = attn_weights + causal_mask
  282. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  283. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  284. attn_output = torch.matmul(attn_weights, value_states)
  285. attn_output = attn_output.transpose(1, 2).contiguous()
  286. return attn_output, attn_weights
  287. class Lfm2Attention(nn.Module):
  288. """Multi-headed attention from 'Attention Is All You Need' paper"""
  289. def __init__(self, config: Lfm2Config, layer_idx: int):
  290. super().__init__()
  291. self.config = config
  292. self.layer_idx = layer_idx
  293. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  294. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  295. self.scaling = self.head_dim**-0.5
  296. self.is_causal = True
  297. self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
  298. self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
  299. self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
  300. self.out_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
  301. self.q_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps)
  302. self.k_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps)
  303. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  304. def forward(
  305. self,
  306. hidden_states: torch.Tensor,
  307. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  308. attention_mask: Optional[torch.Tensor],
  309. past_key_values: Optional[Lfm2HybridConvCache] = None,
  310. cache_position: Optional[torch.LongTensor] = None,
  311. **kwargs,
  312. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  313. input_shape = hidden_states.shape[:-1]
  314. hidden_shape = (*input_shape, -1, self.head_dim)
  315. query_states = self.q_layernorm(self.q_proj(hidden_states).view(*hidden_shape)).transpose(1, 2)
  316. key_states = self.k_layernorm(self.k_proj(hidden_states).view(*hidden_shape)).transpose(1, 2)
  317. value_states = self.v_proj(hidden_states).view(*hidden_shape).transpose(1, 2)
  318. cos, sin = position_embeddings
  319. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  320. if past_key_values is not None:
  321. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  322. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  323. attention_interface: Callable = eager_attention_forward
  324. if self.config._attn_implementation != "eager":
  325. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  326. attn_output, attn_weights = attention_interface(
  327. self,
  328. query_states,
  329. key_states,
  330. value_states,
  331. attention_mask,
  332. dropout=0.0,
  333. scaling=self.scaling,
  334. **kwargs,
  335. )
  336. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  337. output = self.out_proj(attn_output)
  338. return output, attn_weights
  339. def apply_mask_to_padding_states(hidden_states, attention_mask):
  340. """
  341. Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
  342. """
  343. if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
  344. dtype = hidden_states.dtype
  345. hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
  346. return hidden_states
  347. kernel_modules = (causal_conv1d_fn, causal_conv1d_update)
  348. is_fast_path_available = all(kernel_modules)
  349. class Lfm2ShortConv(nn.Module):
  350. def __init__(
  351. self,
  352. config: Lfm2Config,
  353. layer_idx: int,
  354. ):
  355. super().__init__()
  356. self.config = config
  357. self.layer_idx = layer_idx
  358. self.L_cache = config.conv_L_cache
  359. self.bias = config.conv_bias
  360. self.conv = nn.Conv1d(
  361. in_channels=config.hidden_size,
  362. out_channels=config.hidden_size,
  363. kernel_size=self.L_cache,
  364. groups=config.hidden_size,
  365. bias=self.bias,
  366. padding=self.L_cache - 1,
  367. )
  368. self.in_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=self.bias)
  369. self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=self.bias)
  370. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  371. def cuda_kernels_forward(
  372. self,
  373. x: torch.Tensor,
  374. past_key_values: Optional[Lfm2HybridConvCache] = None,
  375. cache_position: Optional[torch.LongTensor] = None,
  376. attention_mask: Optional[torch.Tensor] = None,
  377. ):
  378. x = apply_mask_to_padding_states(x, attention_mask)
  379. BCx = self.in_proj(x).transpose(-1, -2)
  380. B, C, x = BCx.chunk(3, dim=-2)
  381. Bx = B * x
  382. conv_weights = self.conv.weight.view(self.conv.weight.size(0), self.conv.weight.size(2))
  383. if past_key_values is not None and cache_position[0] > 0:
  384. conv_out = causal_conv1d_update(
  385. Bx.squeeze(-1),
  386. past_key_values.conv_cache[self.layer_idx],
  387. conv_weights,
  388. self.conv.bias,
  389. None,
  390. )
  391. conv_out = conv_out.unsqueeze(-1)
  392. else:
  393. if past_key_values is not None:
  394. conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0))
  395. past_key_values.conv_cache[self.layer_idx].copy_(conv_state)
  396. conv_out = causal_conv1d_fn(Bx, conv_weights, self.conv.bias, activation=None)
  397. y = C * conv_out
  398. y = self.out_proj(y.transpose(-1, -2).contiguous())
  399. return y
  400. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  401. def slow_forward(
  402. self,
  403. x: torch.Tensor,
  404. past_key_values: Optional[Lfm2HybridConvCache] = None,
  405. cache_position: Optional[torch.LongTensor] = None,
  406. attention_mask: Optional[torch.Tensor] = None,
  407. ):
  408. seqlen = x.shape[1]
  409. x = apply_mask_to_padding_states(x, attention_mask)
  410. BCx = self.in_proj(x).transpose(-1, -2)
  411. B, C, x = BCx.chunk(3, dim=-2)
  412. Bx = B * x
  413. if past_key_values is not None and cache_position[0] > 0:
  414. conv_state = past_key_values.conv_cache[self.layer_idx]
  415. cache_position = cache_position.clamp(0, self.L_cache - 1)
  416. conv_state = conv_state.roll(shifts=-1, dims=-1)
  417. conv_state[:, :, cache_position] = Bx.to(device=conv_state.device, dtype=conv_state.dtype)
  418. past_key_values.conv_cache[self.layer_idx].copy_(conv_state)
  419. conv_out = torch.sum(conv_state.to(Bx.device) * self.conv.weight[:, 0, :], dim=-1)
  420. if self.bias:
  421. conv_out += self.conv.bias
  422. conv_out = conv_out.unsqueeze(-1)
  423. else:
  424. if past_key_values is not None:
  425. conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0))
  426. past_key_values.conv_cache[self.layer_idx].copy_(conv_state)
  427. conv_out = self.conv(Bx)[..., :seqlen]
  428. y = C * conv_out
  429. y = y.transpose(-1, -2).contiguous()
  430. y = self.out_proj(y)
  431. return y
  432. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  433. def forward(
  434. self,
  435. hidden_states: torch.Tensor,
  436. past_key_values: Optional[Lfm2HybridConvCache] = None,
  437. cache_position: Optional[torch.LongTensor] = None,
  438. attention_mask: Optional[torch.Tensor] = None,
  439. ):
  440. if is_fast_path_available and "cuda" in hidden_states.device.type and not torch._dynamo.is_compiling():
  441. return self.cuda_kernels_forward(hidden_states, past_key_values, cache_position, attention_mask)
  442. return self.slow_forward(hidden_states, past_key_values, cache_position, attention_mask)
  443. class Lfm2DecoderLayer(GradientCheckpointingLayer):
  444. def __init__(self, config: Lfm2Config, layer_idx: int):
  445. super().__init__()
  446. self.is_attention_layer = config.layer_types[layer_idx] == "full_attention"
  447. if self.is_attention_layer:
  448. self.self_attn = Lfm2Attention(config, layer_idx)
  449. else:
  450. self.conv = Lfm2ShortConv(config, layer_idx)
  451. self.feed_forward = Lfm2MLP(config)
  452. self.operator_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps)
  453. self.ffn_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps)
  454. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  455. def forward(
  456. self,
  457. hidden_states: torch.Tensor,
  458. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  459. attention_mask: Optional[torch.Tensor] = None,
  460. position_ids: Optional[torch.LongTensor] = None,
  461. past_key_values: Optional[Lfm2HybridConvCache] = None,
  462. cache_position: Optional[torch.LongTensor] = None,
  463. **kwargs,
  464. ) -> torch.Tensor:
  465. residual = hidden_states
  466. if self.is_attention_layer:
  467. hidden_states, _ = self.self_attn(
  468. hidden_states=self.operator_norm(hidden_states),
  469. position_embeddings=position_embeddings,
  470. attention_mask=attention_mask,
  471. position_ids=position_ids,
  472. past_key_values=past_key_values,
  473. cache_position=cache_position,
  474. **kwargs,
  475. )
  476. else:
  477. hidden_states = self.conv(
  478. hidden_states=self.operator_norm(hidden_states),
  479. past_key_values=past_key_values,
  480. cache_position=cache_position,
  481. attention_mask=attention_mask,
  482. )
  483. hidden_states = hidden_states + residual
  484. hidden_states = hidden_states + self.feed_forward(self.ffn_norm(hidden_states))
  485. return hidden_states
  486. @auto_docstring
  487. class Lfm2PreTrainedModel(PreTrainedModel):
  488. config: Lfm2Config
  489. base_model_prefix = "model"
  490. supports_gradient_checkpointing = True
  491. _no_split_modules = ["Lfm2DecoderLayer"]
  492. _skip_keys_device_placement = ["past_key_values"]
  493. _supports_flash_attn = True
  494. _supports_sdpa = True
  495. _supports_flex_attn = True
  496. _can_compile_fullgraph = False
  497. _supports_attention_backend = True
  498. _can_record_outputs = {
  499. "hidden_states": Lfm2DecoderLayer,
  500. "attentions": Lfm2Attention,
  501. }
  502. @auto_docstring
  503. class Lfm2Model(Lfm2PreTrainedModel):
  504. def __init__(self, config: Lfm2Config):
  505. super().__init__(config)
  506. self.padding_idx = config.pad_token_id
  507. self.vocab_size = config.vocab_size
  508. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  509. self.layers = nn.ModuleList(
  510. [Lfm2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  511. )
  512. self.rotary_emb = Lfm2RotaryEmbedding(config=config)
  513. self.gradient_checkpointing = False
  514. self.pos_emb = Lfm2RotaryEmbedding(config)
  515. self.embedding_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps)
  516. # Initialize weights and apply final processing
  517. self.post_init()
  518. @check_model_inputs()
  519. @auto_docstring
  520. def forward(
  521. self,
  522. input_ids: Optional[torch.LongTensor] = None,
  523. attention_mask: Optional[torch.Tensor] = None,
  524. position_ids: Optional[torch.LongTensor] = None,
  525. past_key_values: Optional[Lfm2HybridConvCache] = None,
  526. inputs_embeds: Optional[torch.FloatTensor] = None,
  527. use_cache: Optional[bool] = None,
  528. cache_position: Optional[torch.LongTensor] = None,
  529. **kwargs: Unpack[TransformersKwargs],
  530. ) -> BaseModelOutputWithPast:
  531. if (input_ids is None) ^ (inputs_embeds is not None):
  532. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  533. if inputs_embeds is None:
  534. inputs_embeds = self.embed_tokens(input_ids)
  535. if use_cache and past_key_values is None:
  536. batch_size = inputs_embeds.shape[0]
  537. past_key_values = Lfm2HybridConvCache(
  538. config=self.config, max_batch_size=batch_size, dtype=self.dtype, device=self.device
  539. )
  540. if cache_position is None:
  541. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  542. cache_position = torch.arange(
  543. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  544. )
  545. if position_ids is None:
  546. position_ids = cache_position.unsqueeze(0)
  547. causal_mask = create_causal_mask(
  548. config=self.config,
  549. input_embeds=inputs_embeds,
  550. attention_mask=attention_mask,
  551. cache_position=cache_position,
  552. past_key_values=past_key_values,
  553. position_ids=position_ids,
  554. )
  555. hidden_states = inputs_embeds
  556. position_embeddings = self.pos_emb(hidden_states, position_ids)
  557. # decoder layers
  558. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  559. hidden_states = decoder_layer(
  560. hidden_states,
  561. attention_mask=causal_mask,
  562. position_ids=position_ids,
  563. past_key_values=past_key_values,
  564. cache_position=cache_position,
  565. position_embeddings=position_embeddings,
  566. **kwargs,
  567. )
  568. hidden_states = self.embedding_norm(hidden_states)
  569. return BaseModelOutputWithPast(
  570. last_hidden_state=hidden_states,
  571. past_key_values=past_key_values,
  572. )
  573. @auto_docstring
  574. class Lfm2ForCausalLM(Lfm2PreTrainedModel, GenerationMixin):
  575. _tied_weights_keys = ["lm_head.weight"]
  576. _tp_plan = {"lm_head": "colwise_rep"}
  577. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  578. def __init__(self, config):
  579. super().__init__(config)
  580. self.model = Lfm2Model(config)
  581. self.vocab_size = config.vocab_size
  582. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  583. # Initialize weights and apply final processing
  584. self.post_init()
  585. @can_return_tuple
  586. @auto_docstring
  587. def forward(
  588. self,
  589. input_ids: Optional[torch.LongTensor] = None,
  590. attention_mask: Optional[torch.Tensor] = None,
  591. position_ids: Optional[torch.LongTensor] = None,
  592. past_key_values: Optional[Cache] = None,
  593. inputs_embeds: Optional[torch.FloatTensor] = None,
  594. labels: Optional[torch.LongTensor] = None,
  595. use_cache: Optional[bool] = None,
  596. cache_position: Optional[torch.LongTensor] = None,
  597. logits_to_keep: Union[int, torch.Tensor] = 0,
  598. **kwargs: Unpack[TransformersKwargs],
  599. ) -> CausalLMOutputWithPast:
  600. r"""
  601. Example:
  602. ```python
  603. >>> from transformers import AutoTokenizer, Lfm2ForCausalLM
  604. >>> model = Lfm2ForCausalLM.from_pretrained("meta-lfm2/Lfm2-2-7b-hf")
  605. >>> tokenizer = AutoTokenizer.from_pretrained("meta-lfm2/Lfm2-2-7b-hf")
  606. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  607. >>> inputs = tokenizer(prompt, return_tensors="pt")
  608. >>> # Generate
  609. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  610. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  611. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  612. ```"""
  613. outputs: BaseModelOutputWithPast = self.model(
  614. input_ids=input_ids,
  615. attention_mask=attention_mask,
  616. position_ids=position_ids,
  617. past_key_values=past_key_values,
  618. inputs_embeds=inputs_embeds,
  619. use_cache=use_cache,
  620. cache_position=cache_position,
  621. **kwargs,
  622. )
  623. hidden_states = outputs.last_hidden_state
  624. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  625. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  626. logits = self.lm_head(hidden_states[:, slice_indices, :])
  627. loss = None
  628. if labels is not None:
  629. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  630. return CausalLMOutputWithPast(
  631. loss=loss,
  632. logits=logits,
  633. past_key_values=outputs.past_key_values,
  634. hidden_states=outputs.hidden_states,
  635. attentions=outputs.attentions,
  636. )
  637. __all__ = ["Lfm2ForCausalLM", "Lfm2Model", "Lfm2PreTrainedModel"]