modeling_starcoder2.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/starcoder2/modular_starcoder2.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_starcoder2.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved.
  9. #
  10. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  11. # and OPT implementations in this library. It has been modified from its
  12. # original forms to accommodate minor architectural differences compared
  13. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  14. #
  15. # Licensed under the Apache License, Version 2.0 (the "License");
  16. # you may not use this file except in compliance with the License.
  17. # You may obtain a copy of the License at
  18. #
  19. # http://www.apache.org/licenses/LICENSE-2.0
  20. #
  21. # Unless required by applicable law or agreed to in writing, software
  22. # distributed under the License is distributed on an "AS IS" BASIS,
  23. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  24. # See the License for the specific language governing permissions and
  25. # limitations under the License.
  26. from typing import Callable, Optional, Union
  27. import torch
  28. from torch import nn
  29. from transformers.utils.generic import check_model_inputs
  30. from ...activations import ACT2FN
  31. from ...cache_utils import Cache, DynamicCache
  32. from ...generation import GenerationMixin
  33. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  34. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  35. from ...modeling_layers import (
  36. GenericForSequenceClassification,
  37. GenericForTokenClassification,
  38. GradientCheckpointingLayer,
  39. )
  40. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  41. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  42. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  43. from ...processing_utils import Unpack
  44. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  45. from ...utils.deprecation import deprecate_kwarg
  46. from .configuration_starcoder2 import Starcoder2Config
  47. class Starcoder2MLP(nn.Module):
  48. def __init__(self, config: Starcoder2Config):
  49. super().__init__()
  50. embed_dim = config.hidden_size
  51. self.c_fc = nn.Linear(embed_dim, config.intermediate_size, bias=config.use_bias)
  52. self.c_proj = nn.Linear(config.intermediate_size, embed_dim, bias=config.use_bias)
  53. self.act = ACT2FN[config.hidden_act]
  54. self.residual_dropout = config.residual_dropout
  55. def forward(self, hidden_states: Optional[tuple[torch.FloatTensor]]) -> torch.FloatTensor:
  56. hidden_states = self.c_fc(hidden_states)
  57. hidden_states = self.act(hidden_states)
  58. hidden_states = self.c_proj(hidden_states)
  59. hidden_states = nn.functional.dropout(hidden_states, p=self.residual_dropout, training=self.training)
  60. return hidden_states
  61. def rotate_half(x):
  62. """Rotates half the hidden dims of the input."""
  63. x1 = x[..., : x.shape[-1] // 2]
  64. x2 = x[..., x.shape[-1] // 2 :]
  65. return torch.cat((-x2, x1), dim=-1)
  66. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  67. """Applies Rotary Position Embedding to the query and key tensors.
  68. Args:
  69. q (`torch.Tensor`): The query tensor.
  70. k (`torch.Tensor`): The key tensor.
  71. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  72. sin (`torch.Tensor`): The sine part of the rotary embedding.
  73. position_ids (`torch.Tensor`, *optional*):
  74. Deprecated and unused.
  75. unsqueeze_dim (`int`, *optional*, defaults to 1):
  76. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  77. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  78. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  79. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  80. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  81. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  82. Returns:
  83. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  84. """
  85. cos = cos.unsqueeze(unsqueeze_dim)
  86. sin = sin.unsqueeze(unsqueeze_dim)
  87. q_embed = (q * cos) + (rotate_half(q) * sin)
  88. k_embed = (k * cos) + (rotate_half(k) * sin)
  89. return q_embed, k_embed
  90. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  91. """
  92. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  93. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  94. """
  95. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  96. if n_rep == 1:
  97. return hidden_states
  98. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  99. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  100. def eager_attention_forward(
  101. module: nn.Module,
  102. query: torch.Tensor,
  103. key: torch.Tensor,
  104. value: torch.Tensor,
  105. attention_mask: Optional[torch.Tensor],
  106. scaling: float,
  107. dropout: float = 0.0,
  108. **kwargs: Unpack[TransformersKwargs],
  109. ):
  110. key_states = repeat_kv(key, module.num_key_value_groups)
  111. value_states = repeat_kv(value, module.num_key_value_groups)
  112. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  113. if attention_mask is not None:
  114. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  115. attn_weights = attn_weights + causal_mask
  116. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  117. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  118. attn_output = torch.matmul(attn_weights, value_states)
  119. attn_output = attn_output.transpose(1, 2).contiguous()
  120. return attn_output, attn_weights
  121. class Starcoder2Attention(nn.Module):
  122. """Multi-headed attention from 'Attention Is All You Need' paper"""
  123. def __init__(self, config: Starcoder2Config, layer_idx: Optional[int] = None):
  124. super().__init__()
  125. self.config = config
  126. self.layer_idx = layer_idx
  127. self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  128. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  129. self.scaling = self.head_dim**-0.5
  130. self.attention_dropout = config.attention_dropout
  131. self.is_causal = True
  132. self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.use_bias)
  133. self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
  134. self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
  135. self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias)
  136. self.residual_dropout = config.residual_dropout
  137. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  138. def forward(
  139. self,
  140. hidden_states: torch.Tensor,
  141. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  142. attention_mask: Optional[torch.Tensor],
  143. past_key_values: Optional[Cache] = None,
  144. cache_position: Optional[torch.LongTensor] = None,
  145. **kwargs: Unpack[FlashAttentionKwargs],
  146. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  147. input_shape = hidden_states.shape[:-1]
  148. hidden_shape = (*input_shape, -1, self.head_dim)
  149. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  150. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  151. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  152. cos, sin = position_embeddings
  153. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  154. if past_key_values is not None:
  155. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  156. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  157. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  158. attention_interface: Callable = eager_attention_forward
  159. if self.config._attn_implementation != "eager":
  160. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  161. attn_output, attn_weights = attention_interface(
  162. self,
  163. query_states,
  164. key_states,
  165. value_states,
  166. attention_mask,
  167. dropout=0.0 if not self.training else self.attention_dropout,
  168. scaling=self.scaling,
  169. sliding_window=getattr(self.config, "sliding_window", None), # diff with Llama
  170. **kwargs,
  171. )
  172. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  173. attn_output = self.o_proj(attn_output)
  174. attn_output = nn.functional.dropout(
  175. attn_output, p=self.residual_dropout, training=self.training
  176. ) # diff with Llama
  177. return attn_output, attn_weights
  178. class Starcoder2DecoderLayer(GradientCheckpointingLayer):
  179. def __init__(self, config: Starcoder2Config, layer_idx: int):
  180. super().__init__()
  181. self.hidden_size = config.hidden_size
  182. self.self_attn = Starcoder2Attention(config=config, layer_idx=layer_idx)
  183. self.mlp = Starcoder2MLP(config)
  184. self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
  185. self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
  186. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  187. def forward(
  188. self,
  189. hidden_states: torch.Tensor,
  190. attention_mask: Optional[torch.Tensor] = None,
  191. position_ids: Optional[torch.LongTensor] = None,
  192. past_key_values: Optional[Cache] = None,
  193. use_cache: Optional[bool] = False,
  194. cache_position: Optional[torch.LongTensor] = None,
  195. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
  196. **kwargs: Unpack[TransformersKwargs],
  197. ) -> torch.Tensor:
  198. residual = hidden_states
  199. hidden_states = self.input_layernorm(hidden_states)
  200. # Self Attention
  201. hidden_states, _ = self.self_attn(
  202. hidden_states=hidden_states,
  203. attention_mask=attention_mask,
  204. position_ids=position_ids,
  205. past_key_values=past_key_values,
  206. use_cache=use_cache,
  207. cache_position=cache_position,
  208. position_embeddings=position_embeddings,
  209. **kwargs,
  210. )
  211. hidden_states = residual + hidden_states
  212. # Fully Connected
  213. residual = hidden_states
  214. hidden_states = self.post_attention_layernorm(hidden_states)
  215. hidden_states = self.mlp(hidden_states)
  216. hidden_states = residual + hidden_states
  217. return hidden_states
  218. class Starcoder2RotaryEmbedding(nn.Module):
  219. inv_freq: torch.Tensor # fix linting for `register_buffer`
  220. def __init__(self, config: Starcoder2Config, device=None):
  221. super().__init__()
  222. # BC: "rope_type" was originally "type"
  223. if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
  224. self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
  225. else:
  226. self.rope_type = "default"
  227. self.max_seq_len_cached = config.max_position_embeddings
  228. self.original_max_seq_len = config.max_position_embeddings
  229. self.config = config
  230. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  231. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
  232. self.register_buffer("inv_freq", inv_freq, persistent=False)
  233. self.original_inv_freq = self.inv_freq
  234. @torch.no_grad()
  235. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  236. def forward(self, x, position_ids):
  237. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  238. position_ids_expanded = position_ids[:, None, :].float()
  239. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  240. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  241. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  242. emb = torch.cat((freqs, freqs), dim=-1)
  243. cos = emb.cos() * self.attention_scaling
  244. sin = emb.sin() * self.attention_scaling
  245. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  246. @auto_docstring
  247. class Starcoder2PreTrainedModel(PreTrainedModel):
  248. config: Starcoder2Config
  249. base_model_prefix = "model"
  250. supports_gradient_checkpointing = True
  251. _no_split_modules = ["Starcoder2DecoderLayer"]
  252. _skip_keys_device_placement = ["past_key_values"]
  253. _supports_flash_attn = True
  254. _supports_sdpa = True
  255. _supports_flex_attn = True
  256. _can_compile_fullgraph = True
  257. _supports_attention_backend = True
  258. _can_record_outputs = {
  259. "hidden_states": Starcoder2DecoderLayer,
  260. "attentions": Starcoder2Attention,
  261. }
  262. @auto_docstring
  263. class Starcoder2Model(Starcoder2PreTrainedModel):
  264. def __init__(self, config: Starcoder2Config):
  265. super().__init__(config)
  266. self.padding_idx = config.pad_token_id
  267. self.vocab_size = config.vocab_size
  268. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  269. self.layers = nn.ModuleList(
  270. [Starcoder2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  271. )
  272. self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
  273. self.rotary_emb = Starcoder2RotaryEmbedding(config=config)
  274. self.gradient_checkpointing = False
  275. self.embedding_dropout = config.embedding_dropout
  276. # Initialize weights and apply final processing
  277. self.post_init()
  278. @check_model_inputs()
  279. def forward(
  280. self,
  281. input_ids: Optional[torch.LongTensor] = None,
  282. attention_mask: Optional[torch.Tensor] = None,
  283. position_ids: Optional[torch.LongTensor] = None,
  284. past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
  285. inputs_embeds: Optional[torch.FloatTensor] = None,
  286. use_cache: Optional[bool] = None,
  287. cache_position: Optional[torch.LongTensor] = None,
  288. **kwargs: Unpack[TransformersKwargs],
  289. ) -> BaseModelOutputWithPast:
  290. if (input_ids is None) ^ (inputs_embeds is not None):
  291. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  292. if inputs_embeds is None:
  293. inputs_embeds = self.embed_tokens(input_ids)
  294. if use_cache and past_key_values is None:
  295. past_key_values = DynamicCache(config=self.config)
  296. if cache_position is None:
  297. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  298. cache_position = torch.arange(
  299. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  300. )
  301. if position_ids is None:
  302. position_ids = cache_position.unsqueeze(0)
  303. mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
  304. causal_mask = mask_function(
  305. config=self.config,
  306. input_embeds=inputs_embeds,
  307. attention_mask=attention_mask,
  308. cache_position=cache_position,
  309. past_key_values=past_key_values,
  310. position_ids=position_ids,
  311. )
  312. hidden_states = inputs_embeds
  313. hidden_states = nn.functional.dropout(
  314. hidden_states, p=self.embedding_dropout, training=self.training
  315. ) # main diff with Llama
  316. # create position embeddings to be shared across the decoder layers
  317. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  318. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  319. hidden_states = decoder_layer(
  320. hidden_states,
  321. attention_mask=causal_mask,
  322. position_ids=position_ids,
  323. past_key_values=past_key_values,
  324. use_cache=use_cache,
  325. cache_position=cache_position,
  326. position_embeddings=position_embeddings,
  327. **kwargs,
  328. )
  329. hidden_states = self.norm(hidden_states)
  330. return BaseModelOutputWithPast(
  331. last_hidden_state=hidden_states,
  332. past_key_values=past_key_values if use_cache else None,
  333. )
  334. @auto_docstring
  335. class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin):
  336. _tied_weights_keys = ["lm_head.weight"]
  337. _tp_plan = {"lm_head": "colwise_rep"}
  338. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  339. def __init__(self, config):
  340. super().__init__(config)
  341. self.model = Starcoder2Model(config)
  342. self.vocab_size = config.vocab_size
  343. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  344. # Initialize weights and apply final processing
  345. self.post_init()
  346. @can_return_tuple
  347. @auto_docstring
  348. def forward(
  349. self,
  350. input_ids: Optional[torch.LongTensor] = None,
  351. attention_mask: Optional[torch.Tensor] = None,
  352. position_ids: Optional[torch.LongTensor] = None,
  353. past_key_values: Optional[Cache] = None,
  354. inputs_embeds: Optional[torch.FloatTensor] = None,
  355. labels: Optional[torch.LongTensor] = None,
  356. use_cache: Optional[bool] = None,
  357. cache_position: Optional[torch.LongTensor] = None,
  358. logits_to_keep: Union[int, torch.Tensor] = 0,
  359. **kwargs: Unpack[TransformersKwargs],
  360. ) -> CausalLMOutputWithPast:
  361. r"""
  362. Example:
  363. ```python
  364. >>> from transformers import AutoTokenizer, Starcoder2ForCausalLM
  365. >>> model = Starcoder2ForCausalLM.from_pretrained("meta-starcoder2/Starcoder2-2-7b-hf")
  366. >>> tokenizer = AutoTokenizer.from_pretrained("meta-starcoder2/Starcoder2-2-7b-hf")
  367. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  368. >>> inputs = tokenizer(prompt, return_tensors="pt")
  369. >>> # Generate
  370. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  371. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  372. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  373. ```"""
  374. outputs: BaseModelOutputWithPast = self.model(
  375. input_ids=input_ids,
  376. attention_mask=attention_mask,
  377. position_ids=position_ids,
  378. past_key_values=past_key_values,
  379. inputs_embeds=inputs_embeds,
  380. use_cache=use_cache,
  381. cache_position=cache_position,
  382. **kwargs,
  383. )
  384. hidden_states = outputs.last_hidden_state
  385. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  386. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  387. logits = self.lm_head(hidden_states[:, slice_indices, :])
  388. loss = None
  389. if labels is not None:
  390. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  391. return CausalLMOutputWithPast(
  392. loss=loss,
  393. logits=logits,
  394. past_key_values=outputs.past_key_values,
  395. hidden_states=outputs.hidden_states,
  396. attentions=outputs.attentions,
  397. )
  398. class Starcoder2ForSequenceClassification(GenericForSequenceClassification, Starcoder2PreTrainedModel):
  399. pass
  400. class Starcoder2ForTokenClassification(GenericForTokenClassification, Starcoder2PreTrainedModel):
  401. pass
  402. __all__ = [
  403. "Starcoder2ForCausalLM",
  404. "Starcoder2Model",
  405. "Starcoder2PreTrainedModel",
  406. "Starcoder2ForSequenceClassification",
  407. "Starcoder2ForTokenClassification",
  408. ]