modeling_stablelm.py 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996
  1. # coding=utf-8
  2. # Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  5. # and OPT implementations in this library. It has been modified from its
  6. # original forms to accommodate minor architectural differences compared
  7. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. """PyTorch StableLM model."""
  21. import math
  22. from typing import Optional, Union
  23. import torch
  24. from torch import nn
  25. from ...activations import ACT2FN
  26. from ...cache_utils import Cache, DynamicCache
  27. from ...generation import GenerationMixin
  28. from ...modeling_attn_mask_utils import AttentionMaskConverter
  29. from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
  30. from ...modeling_layers import (
  31. GenericForSequenceClassification,
  32. GenericForTokenClassification,
  33. GradientCheckpointingLayer,
  34. )
  35. from ...modeling_outputs import (
  36. BaseModelOutputWithPast,
  37. CausalLMOutputWithPast,
  38. )
  39. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  40. from ...modeling_utils import PreTrainedModel
  41. from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
  42. from ...utils.deprecation import deprecate_kwarg
  43. from .configuration_stablelm import StableLmConfig
  44. if is_torch_flex_attn_available():
  45. from torch.nn.attention.flex_attention import BlockMask
  46. from ...integrations.flex_attention import make_flex_block_causal_mask
  47. if is_flash_attn_available():
  48. from ...modeling_flash_attention_utils import _flash_attention_forward
  49. logger = logging.get_logger(__name__)
  50. # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->StableLm
  51. class StableLmRotaryEmbedding(nn.Module):
  52. inv_freq: torch.Tensor # fix linting for `register_buffer`
  53. def __init__(self, config: StableLmConfig, device=None):
  54. super().__init__()
  55. # BC: "rope_type" was originally "type"
  56. if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
  57. self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
  58. else:
  59. self.rope_type = "default"
  60. self.max_seq_len_cached = config.max_position_embeddings
  61. self.original_max_seq_len = config.max_position_embeddings
  62. self.config = config
  63. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  64. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
  65. self.register_buffer("inv_freq", inv_freq, persistent=False)
  66. self.original_inv_freq = self.inv_freq
  67. @torch.no_grad()
  68. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  69. def forward(self, x, position_ids):
  70. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  71. position_ids_expanded = position_ids[:, None, :].float()
  72. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  73. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  74. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  75. emb = torch.cat((freqs, freqs), dim=-1)
  76. cos = emb.cos() * self.attention_scaling
  77. sin = emb.sin() * self.attention_scaling
  78. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  79. # Copied from transformers.models.llama.modeling_llama.rotate_half
  80. def rotate_half(x):
  81. """Rotates half the hidden dims of the input."""
  82. x1 = x[..., : x.shape[-1] // 2]
  83. x2 = x[..., x.shape[-1] // 2 :]
  84. return torch.cat((-x2, x1), dim=-1)
  85. # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
  86. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  87. """Applies Rotary Position Embedding to the query and key tensors.
  88. Args:
  89. q (`torch.Tensor`): The query tensor.
  90. k (`torch.Tensor`): The key tensor.
  91. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  92. sin (`torch.Tensor`): The sine part of the rotary embedding.
  93. position_ids (`torch.Tensor`, *optional*):
  94. Deprecated and unused.
  95. unsqueeze_dim (`int`, *optional*, defaults to 1):
  96. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  97. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  98. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  99. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  100. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  101. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  102. Returns:
  103. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  104. """
  105. cos = cos.unsqueeze(unsqueeze_dim)
  106. sin = sin.unsqueeze(unsqueeze_dim)
  107. q_embed = (q * cos) + (rotate_half(q) * sin)
  108. k_embed = (k * cos) + (rotate_half(k) * sin)
  109. return q_embed, k_embed
  110. # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->StableLm
  111. class StableLmMLP(nn.Module):
  112. def __init__(self, config):
  113. super().__init__()
  114. self.config = config
  115. self.hidden_size = config.hidden_size
  116. self.intermediate_size = config.intermediate_size
  117. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  118. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  119. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  120. self.act_fn = ACT2FN[config.hidden_act]
  121. def forward(self, x):
  122. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  123. return down_proj
  124. class StableLmLayerNormPerHead(nn.Module):
  125. def __init__(self, dim, num_heads, eps=1e-5, bias=False):
  126. super().__init__()
  127. self.dim = dim
  128. self.num_heads = num_heads
  129. self.norms = nn.ModuleList([nn.LayerNorm(dim, eps=eps, bias=bias) for _ in range(self.num_heads)])
  130. def forward(self, hidden_states: torch.Tensor):
  131. # Split along the num_heads axis to get per-head inputs
  132. # [batch_size, num_heads, seq_len, head_dim] -> [batch_size, 1, seq_len, head_dim] * num_heads
  133. states_per_heads = torch.split(hidden_states, 1, dim=1)
  134. # Normalize and merge the heads back together
  135. return torch.cat([norm(hidden_states) for norm, hidden_states in zip(self.norms, states_per_heads)], dim=1)
  136. # Copied from transformers.models.llama.modeling_llama.repeat_kv
  137. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  138. """
  139. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  140. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  141. """
  142. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  143. if n_rep == 1:
  144. return hidden_states
  145. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  146. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  147. class StableLmAttention(nn.Module):
  148. """Multi-headed attention from 'Attention Is All You Need' paper"""
  149. def __init__(self, config: StableLmConfig, layer_idx: Optional[int] = None):
  150. super().__init__()
  151. self.config = config
  152. self.layer_idx = layer_idx
  153. if layer_idx is None:
  154. logger.warning_once(
  155. f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
  156. "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
  157. "when creating this class."
  158. )
  159. self.hidden_size = config.hidden_size
  160. self.num_heads = config.num_attention_heads
  161. self.head_dim = self.hidden_size // self.num_heads
  162. self.num_key_value_heads = config.num_key_value_heads
  163. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  164. self.rope_theta = config.rope_theta
  165. self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor)
  166. self.is_causal = True
  167. if (self.head_dim * self.num_heads) != self.hidden_size:
  168. raise ValueError(
  169. f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
  170. f" and `num_heads`: {self.num_heads})."
  171. )
  172. self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.use_qkv_bias)
  173. self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.use_qkv_bias)
  174. self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.use_qkv_bias)
  175. self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
  176. self.qk_layernorm = config.qk_layernorm
  177. if self.qk_layernorm:
  178. self.q_layernorm = StableLmLayerNormPerHead(self.head_dim, self.num_heads, eps=config.layer_norm_eps)
  179. self.k_layernorm = StableLmLayerNormPerHead(
  180. self.head_dim, self.num_key_value_heads, eps=config.layer_norm_eps
  181. )
  182. self.attention_dropout = nn.Dropout(config.attention_dropout)
  183. self.rotary_emb = StableLmRotaryEmbedding(config=self.config)
  184. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  185. def forward(
  186. self,
  187. hidden_states: torch.Tensor,
  188. attention_mask: Optional[torch.Tensor] = None,
  189. position_ids: Optional[torch.LongTensor] = None,
  190. past_key_values: Optional[Cache] = None,
  191. output_attentions: bool = False,
  192. use_cache: bool = False,
  193. cache_position: Optional[torch.LongTensor] = None,
  194. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
  195. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  196. bsz, q_len, _ = hidden_states.size()
  197. query_states = self.q_proj(hidden_states)
  198. key_states = self.k_proj(hidden_states)
  199. value_states = self.v_proj(hidden_states)
  200. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  201. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  202. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  203. if self.qk_layernorm:
  204. query_states = self.q_layernorm(query_states)
  205. key_states = self.k_layernorm(key_states)
  206. cos, sin = position_embeddings
  207. # Partial rotary embedding
  208. query_rot, query_pass = (
  209. query_states[..., : self.rotary_ndims],
  210. query_states[..., self.rotary_ndims :],
  211. )
  212. key_rot, key_pass = (
  213. key_states[..., : self.rotary_ndims],
  214. key_states[..., self.rotary_ndims :],
  215. )
  216. # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
  217. query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
  218. # [batch_size, seq_length, num_heads, head_dim]
  219. query_states = torch.cat((query_rot, query_pass), dim=-1)
  220. key_states = torch.cat((key_rot, key_pass), dim=-1)
  221. if past_key_values is not None:
  222. # Specific to RoPE models with partial rotation
  223. cache_kwargs = {
  224. "sin": sin,
  225. "cos": cos,
  226. "partial_rotation_size": self.rotary_ndims,
  227. "cache_position": cache_position,
  228. }
  229. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  230. # Repeat k/v heads if n_kv_heads < n_heads
  231. key_states = repeat_kv(key_states, self.num_key_value_groups)
  232. value_states = repeat_kv(value_states, self.num_key_value_groups)
  233. attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
  234. if attention_mask is not None: # no matter the length, we just slice it
  235. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  236. attn_weights += causal_mask
  237. # upcast attention to fp32
  238. attn_weights = nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query_states.dtype)
  239. attn_weights = self.attention_dropout(attn_weights)
  240. attn_output = torch.matmul(attn_weights, value_states)
  241. if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
  242. raise ValueError(
  243. f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
  244. f" {attn_output.size()}"
  245. )
  246. attn_output = attn_output.transpose(1, 2).contiguous()
  247. attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
  248. attn_output = self.o_proj(attn_output)
  249. if not output_attentions:
  250. attn_weights = None
  251. return attn_output, attn_weights
  252. class StableLmSdpaAttention(StableLmAttention):
  253. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  254. def forward(
  255. self,
  256. hidden_states: torch.Tensor,
  257. attention_mask: Optional[torch.Tensor] = None,
  258. position_ids: Optional[torch.LongTensor] = None,
  259. past_key_values: Optional[Cache] = None,
  260. output_attentions: bool = False,
  261. use_cache: bool = False,
  262. cache_position: Optional[torch.LongTensor] = None,
  263. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
  264. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  265. if output_attentions:
  266. # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
  267. logger.warning_once(
  268. "StableLmModel is using StableLmSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
  269. 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
  270. )
  271. return super().forward(
  272. hidden_states=hidden_states,
  273. attention_mask=attention_mask,
  274. position_ids=position_ids,
  275. past_key_values=past_key_values,
  276. output_attentions=output_attentions,
  277. use_cache=use_cache,
  278. cache_position=cache_position,
  279. position_embeddings=position_embeddings,
  280. )
  281. bsz, q_len, _ = hidden_states.size()
  282. query_states = self.q_proj(hidden_states)
  283. key_states = self.k_proj(hidden_states)
  284. value_states = self.v_proj(hidden_states)
  285. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  286. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  287. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  288. if self.qk_layernorm:
  289. query_states = self.q_layernorm(query_states)
  290. key_states = self.k_layernorm(key_states)
  291. cos, sin = position_embeddings
  292. # Partial rotary embedding
  293. query_rot, query_pass = (
  294. query_states[..., : self.rotary_ndims],
  295. query_states[..., self.rotary_ndims :],
  296. )
  297. key_rot, key_pass = (
  298. key_states[..., : self.rotary_ndims],
  299. key_states[..., self.rotary_ndims :],
  300. )
  301. # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
  302. query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
  303. # [batch_size, seq_length, num_heads, head_dim]
  304. query_states = torch.cat((query_rot, query_pass), dim=-1)
  305. key_states = torch.cat((key_rot, key_pass), dim=-1)
  306. if past_key_values is not None:
  307. # Specific to RoPE models with partial rotation
  308. cache_kwargs = {
  309. "sin": sin,
  310. "cos": cos,
  311. "partial_rotation_size": self.rotary_ndims,
  312. "cache_position": cache_position,
  313. }
  314. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  315. # Repeat k/v heads if n_kv_heads < n_heads
  316. key_states = repeat_kv(key_states, self.num_key_value_groups)
  317. value_states = repeat_kv(value_states, self.num_key_value_groups)
  318. causal_mask = attention_mask
  319. if attention_mask is not None: # no matter the length, we just slice it
  320. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  321. # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
  322. # Reference: https://github.com/pytorch/pytorch/issues/112577.
  323. if query_states.device.type == "cuda" and attention_mask is not None:
  324. query_states = query_states.contiguous()
  325. key_states = key_states.contiguous()
  326. value_states = value_states.contiguous()
  327. # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
  328. # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
  329. # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
  330. is_causal = bool(causal_mask is None and q_len > 1)
  331. attn_output = torch.nn.functional.scaled_dot_product_attention(
  332. query_states,
  333. key_states,
  334. value_states,
  335. attn_mask=causal_mask,
  336. dropout_p=self.attention_dropout.p if self.training else 0.0,
  337. is_causal=is_causal,
  338. )
  339. attn_output = attn_output.transpose(1, 2).contiguous()
  340. attn_output = attn_output.view(bsz, q_len, self.hidden_size)
  341. attn_output = self.o_proj(attn_output)
  342. return attn_output, None
  343. class StableLmFlashAttention2(StableLmAttention):
  344. """
  345. StableLM flash attention module. This module inherits from `StableLmAttention` as the weights of the module stays
  346. untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
  347. flash attention and deal with padding tokens in case the input contains any of them.
  348. """
  349. def __init__(self, *args, **kwargs):
  350. super().__init__(*args, **kwargs)
  351. # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
  352. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
  353. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
  354. self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
  355. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  356. def forward(
  357. self,
  358. hidden_states: torch.Tensor,
  359. attention_mask: Optional[torch.LongTensor] = None,
  360. position_ids: Optional[torch.LongTensor] = None,
  361. past_key_values: Optional[Cache] = None,
  362. output_attentions: bool = False,
  363. use_cache: bool = False,
  364. cache_position: Optional[torch.LongTensor] = None,
  365. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
  366. **kwargs,
  367. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  368. # StableLmFlashAttention2 attention does not support output_attentions
  369. output_attentions = False
  370. bsz, q_len, _ = hidden_states.size()
  371. query_states = self.q_proj(hidden_states)
  372. key_states = self.k_proj(hidden_states)
  373. value_states = self.v_proj(hidden_states)
  374. # Flash attention requires the input to have the shape
  375. # batch_size x seq_length x head_dim x hidden_dim
  376. # therefore we just need to keep the original shape
  377. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  378. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  379. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  380. if self.qk_layernorm:
  381. query_states = self.q_layernorm(query_states)
  382. key_states = self.k_layernorm(key_states)
  383. cos, sin = position_embeddings
  384. # Partial rotary embedding
  385. query_rot, query_pass = (
  386. query_states[..., : self.rotary_ndims],
  387. query_states[..., self.rotary_ndims :],
  388. )
  389. key_rot, key_pass = (
  390. key_states[..., : self.rotary_ndims],
  391. key_states[..., self.rotary_ndims :],
  392. )
  393. query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
  394. # [batch_size, seq_length, num_heads, head_dim]
  395. query_states = torch.cat((query_rot, query_pass), dim=-1)
  396. key_states = torch.cat((key_rot, key_pass), dim=-1)
  397. if past_key_values is not None:
  398. cache_kwargs = {
  399. "sin": sin,
  400. "cos": cos,
  401. "partial_rotation_size": self.rotary_ndims,
  402. "cache_position": cache_position,
  403. }
  404. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  405. # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
  406. # to be able to avoid many of these transpose/reshape/view.
  407. query_states = query_states.transpose(1, 2)
  408. key_states = key_states.transpose(1, 2)
  409. value_states = value_states.transpose(1, 2)
  410. dropout_rate = self.attention_dropout.p if self.training else 0.0
  411. attn_output = _flash_attention_forward(
  412. query_states,
  413. key_states,
  414. value_states,
  415. attention_mask,
  416. q_len,
  417. position_ids=position_ids,
  418. dropout=dropout_rate,
  419. use_top_left_mask=self._flash_attn_uses_top_left_mask,
  420. is_causal=self.is_causal,
  421. )
  422. attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
  423. attn_output = self.o_proj(attn_output)
  424. if not output_attentions:
  425. attn_weights = None
  426. return attn_output, attn_weights
# Attention module class for each supported `config._attn_implementation` value;
# `StableLmDecoderLayer` looks up its self-attention class here.
ATTENTION_CLASSES = {
    "eager": StableLmAttention,
    "sdpa": StableLmSdpaAttention,
    "flash_attention_2": StableLmFlashAttention2,
}
  432. class StableLmDecoderLayer(GradientCheckpointingLayer):
  433. def __init__(self, config: StableLmConfig, layer_idx: int):
  434. super().__init__()
  435. self.use_parallel_residual = config.use_parallel_residual
  436. self.hidden_size = config.hidden_size
  437. self.self_attn = ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
  438. self.mlp = StableLmMLP(config)
  439. self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  440. self.post_attention_layernorm = None
  441. if not self.use_parallel_residual:
  442. self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  443. self.dropout = nn.Dropout(config.hidden_dropout)
  444. def forward(
  445. self,
  446. hidden_states: torch.Tensor,
  447. attention_mask: Optional[torch.Tensor] = None,
  448. position_ids: Optional[torch.LongTensor] = None,
  449. past_key_values: Optional[Cache] = None,
  450. output_attentions: Optional[bool] = False,
  451. use_cache: Optional[bool] = False,
  452. cache_position: Optional[torch.LongTensor] = None,
  453. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
  454. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  455. """
  456. Args:
  457. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  458. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
  459. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  460. position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
  461. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
  462. `[0, config.n_positions - 1]`.
  463. [What are position IDs?](../glossary#position-ids)
  464. past_key_values (`Cache`, *optional*):
  465. cached past key and value projection states
  466. output_attentions (`bool`, *optional*):
  467. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  468. returned tensors for more detail.
  469. use_cache (`bool`, *optional*):
  470. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  471. (see `past_key_values`).
  472. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  473. Indices depicting the position of the input sequence tokens in the sequence
  474. position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
  475. Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
  476. with `head_dim` being the embedding dimension of each attention head.
  477. """
  478. residual = hidden_states
  479. hidden_states = self.input_layernorm(hidden_states)
  480. # Self Attention
  481. self_attn_output, self_attn_weights = self.self_attn(
  482. hidden_states=hidden_states,
  483. attention_mask=attention_mask,
  484. position_ids=position_ids,
  485. past_key_values=past_key_values,
  486. output_attentions=output_attentions,
  487. use_cache=use_cache,
  488. cache_position=cache_position,
  489. position_embeddings=position_embeddings,
  490. )
  491. # copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXLayer.forward
  492. if self.use_parallel_residual:
  493. # x = x + attn(ln1(x)) + mlp(ln1(x))
  494. # Fully Connected
  495. mlp_output = self.mlp(hidden_states)
  496. mlp_output = self.dropout(mlp_output)
  497. hidden_states = residual + self_attn_output + mlp_output
  498. else:
  499. # x = x + attn(ln1(x))
  500. # x = x + mlp(ln2(x))
  501. residual = residual + self_attn_output
  502. # Fully Connected
  503. mlp_output = self.mlp(self.post_attention_layernorm(residual))
  504. mlp_output = self.dropout(mlp_output)
  505. hidden_states = residual + mlp_output
  506. outputs = (hidden_states,)
  507. if output_attentions:
  508. outputs += (self_attn_weights,)
  509. return outputs
  510. @auto_docstring
  511. class StableLmPreTrainedModel(PreTrainedModel):
  512. config: StableLmConfig
  513. base_model_prefix = "model"
  514. supports_gradient_checkpointing = True
  515. _no_split_modules = ["StableLmDecoderLayer"]
  516. _skip_keys_device_placement = "past_key_values"
  517. _supports_flash_attn = True
  518. _supports_sdpa = True
  519. _can_compile_fullgraph = True
  520. def _init_weights(self, module):
  521. std = self.config.initializer_range
  522. if isinstance(module, nn.Linear):
  523. module.weight.data.normal_(mean=0.0, std=std)
  524. if module.bias is not None:
  525. module.bias.data.zero_()
  526. elif isinstance(module, nn.Embedding):
  527. module.weight.data.normal_(mean=0.0, std=std)
  528. if module.padding_idx is not None:
  529. module.weight.data[module.padding_idx].zero_()
  530. elif isinstance(module, nn.LayerNorm):
  531. module.weight.data.fill_(1.0)
  532. module.bias.data.zero_()
@auto_docstring
class StableLmModel(StableLmPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`StableLmDecoderLayer`]

    Args:
        config: StableLmConfig
    """

    def __init__(self, config: StableLmConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Token embedding table; `padding_idx` is forwarded to `nn.Embedding`.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        # One decoder layer per hidden layer; each layer receives its own index.
        self.layers = nn.ModuleList(
            [StableLmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # Final layer norm, applied once after the last decoder layer in `forward`.
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # A single rotary-embedding module; its output is shared by all layers in `forward`.
        self.rotary_emb = StableLmRotaryEmbedding(config=config)

        self._attn_implementation = config._attn_implementation
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> BaseModelOutputWithPast:
        # Per-call flags fall back to the values stored on the config when not given explicitly.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # Exactly one of `input_ids` / `inputs_embeds` must be provided (neither or both is an error).
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        # Caching is switched off (with a one-time warning) while training with gradient checkpointing.
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # Create a fresh cache when caching is requested but the caller did not pass one.
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Default `cache_position`: consecutive indices for the new tokens, starting right after the number of
        # tokens the cache reports as already seen (0 when there is no cache).
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        # Default `position_ids`: the cache positions with a leading dimension of size 1 added.
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # May be `None`, the mask as given, or a 4D mask, depending on the attention implementation
        # (see `_update_causal_mask`).
        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers:
            # Record the *input* of each layer; the first entry is therefore the embedding output.
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

            hidden_states = layer_outputs[0]

            # The layer only returns a second element (the attention weights) when asked to.
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        # (appended after the final norm, so the last recorded entry is the normalized output)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
    def _update_causal_mask(
        self,
        attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Optional[Cache],
        output_attentions: bool = False,
    ):
        """
        Build the attention mask handed to the decoder layers, based on `config._attn_implementation`.

        Args:
            attention_mask (`torch.Tensor` or `BlockMask`, *optional*):
                The mask supplied by the caller, or `None`.
            input_tensor (`torch.Tensor`):
                The input embeddings; its dtype, `shape[0]` (batch size) and `shape[1]` (sequence length) are used.
            cache_position (`torch.Tensor`):
                Positions of the current input tokens, forwarded to the 4D mask builder.
            past_key_values (`Cache`, *optional*):
                The key/value cache, or `None`.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether attention weights were requested; this disables the SDPA shortcuts below.

        Returns:
            `None` when the selected implementation needs no explicit mask, the (possibly converted) input mask for
            the `"flash_attention_2"` and `"flex_attention"` implementations, or a 4D mask otherwise.
        """
        # Flash Attention 2: pass the mask through only when it actually masks something (contains a zero).
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None
        # Flex attention: tensor masks are converted to a block mask; anything else is returned as given.
        if self.config._attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = make_flex_block_causal_mask(attention_mask)
            return attention_mask

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype = input_tensor.dtype
        sequence_length = input_tensor.shape[1]
        # Key/value length of the mask: the cache's maximum shape for a compilable cache, otherwise the width of a
        # tensor mask, otherwise the number of seen tokens plus the new tokens plus one.
        if using_compilable_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu", "npu"]
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: Optional[torch.Tensor],
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`, *optional*):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`. When `None`, only the causal part is built.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`int`):
                Batch size.

        Returns:
            `torch.Tensor`: the 4D mask. In the generated case, allowed positions hold `0` and blocked positions hold
            the minimum value representable in `dtype`.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            # Start fully blocked: every (query, key) entry holds the dtype minimum.
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
            )
            # For more than one query token, keep the blocking value only strictly above the diagonal.
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            # Keep the blocking value only at key indices strictly greater than each query's cache position; all
            # other entries are multiplied by `False` and become 0.
            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
            # Add batch and head dimensions: (batch_size, 1, sequence_length, target_length).
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                # The sum is 0 only where the causal mask allows attention (0) and the 2D mask entry is 0.
                # NOTE(review): presumably the 2D mask holds 1 for real tokens and 0 for padding - confirm.
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                # Block those positions as well.
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask
  742. # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM with PERSIMMON->STABLELM,Persimmon->StableLm
  743. class StableLmForCausalLM(StableLmPreTrainedModel, GenerationMixin):
  744. _tied_weights_keys = ["lm_head.weight"]
  745. # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with LLAMA->STABLELM,Llama->StableLm
  746. def __init__(self, config):
  747. super().__init__(config)
  748. self.model = StableLmModel(config)
  749. self.vocab_size = config.vocab_size
  750. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  751. # Initialize weights and apply final processing
  752. self.post_init()
  753. @can_return_tuple
  754. @auto_docstring
  755. # Ignore copy
  756. def forward(
  757. self,
  758. input_ids: Optional[torch.LongTensor] = None,
  759. attention_mask: Optional[torch.Tensor] = None,
  760. position_ids: Optional[torch.LongTensor] = None,
  761. past_key_values: Optional[Cache] = None,
  762. inputs_embeds: Optional[torch.FloatTensor] = None,
  763. labels: Optional[torch.LongTensor] = None,
  764. use_cache: Optional[bool] = None,
  765. output_attentions: Optional[bool] = None,
  766. output_hidden_states: Optional[bool] = None,
  767. cache_position: Optional[torch.LongTensor] = None,
  768. logits_to_keep: Union[int, torch.Tensor] = 0,
  769. **kwargs,
  770. ) -> CausalLMOutputWithPast:
  771. r"""
  772. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  773. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  774. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  775. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  776. Example:
  777. ```python
  778. >>> from transformers import AutoTokenizer, StableLmForCausalLM
  779. >>> model = StableLmForCausalLM.from_pretrained("adept/persimmon-8b-base")
  780. >>> tokenizer = AutoTokenizer.from_pretrained("adept/persimmon-8b-base")
  781. >>> prompt = "human: Hey, what should I eat for dinner?"
  782. >>> inputs = tokenizer(prompt, return_tensors="pt")
  783. >>> # Generate
  784. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  785. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  786. 'human: Hey, what should I eat for dinner?\n\ncat: 🐱\n\nhuman: 😐\n\n'
  787. ```"""
  788. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  789. output_hidden_states = (
  790. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  791. )
  792. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  793. outputs: BaseModelOutputWithPast = self.model(
  794. input_ids=input_ids,
  795. attention_mask=attention_mask,
  796. position_ids=position_ids,
  797. past_key_values=past_key_values,
  798. inputs_embeds=inputs_embeds,
  799. use_cache=use_cache,
  800. output_attentions=output_attentions,
  801. output_hidden_states=output_hidden_states,
  802. cache_position=cache_position,
  803. )
  804. hidden_states = outputs.last_hidden_state
  805. # No upscaling to float was ever done for StableLm
  806. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  807. logits = self.lm_head(hidden_states[:, slice_indices, :])
  808. loss = None
  809. if labels is not None:
  810. loss = self.loss_function(
  811. logits,
  812. labels,
  813. vocab_size=self.config.vocab_size,
  814. **kwargs,
  815. )
  816. return CausalLMOutputWithPast(
  817. loss=loss,
  818. logits=logits,
  819. past_key_values=outputs.past_key_values,
  820. hidden_states=outputs.hidden_states,
  821. attentions=outputs.attentions,
  822. )
  823. class StableLmForSequenceClassification(GenericForSequenceClassification, StableLmPreTrainedModel): ...
  824. class StableLmForTokenClassification(GenericForTokenClassification, StableLmPreTrainedModel): ...
# Public API of this module: the names made available by `from <module> import *`.
__all__ = [
    "StableLmForCausalLM",
    "StableLmModel",
    "StableLmPreTrainedModel",
    "StableLmForSequenceClassification",
    "StableLmForTokenClassification",
]