modular_starcoder2.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. # coding=utf-8
  2. # Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  5. # and OPT implementations in this library. It has been modified from its
  6. # original forms to accommodate minor architectural differences compared
  7. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. """PyTorch Starcoder2 model."""
  21. from typing import Callable, Optional, Union
  22. import torch
  23. from torch import nn
  24. from transformers.utils.generic import check_model_inputs
  25. from ...activations import ACT2FN
  26. from ...cache_utils import Cache, DynamicCache
  27. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  28. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  29. from ...modeling_outputs import BaseModelOutputWithPast
  30. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  31. from ...processing_utils import Unpack
  32. from ...utils import TransformersKwargs, logging
  33. from ...utils.deprecation import deprecate_kwarg
  34. from ..mistral.modeling_mistral import (
  35. MistralAttention,
  36. MistralDecoderLayer,
  37. MistralForCausalLM,
  38. MistralForSequenceClassification,
  39. MistralForTokenClassification,
  40. MistralModel,
  41. MistralRotaryEmbedding,
  42. apply_rotary_pos_emb,
  43. eager_attention_forward,
  44. )
  45. from .configuration_starcoder2 import Starcoder2Config
# Module-level logger named after this module (not referenced by the classes below).
logger = logging.get_logger(__name__)
  47. class Starcoder2MLP(nn.Module):
  48. def __init__(self, config: Starcoder2Config):
  49. super().__init__()
  50. embed_dim = config.hidden_size
  51. self.c_fc = nn.Linear(embed_dim, config.intermediate_size, bias=config.use_bias)
  52. self.c_proj = nn.Linear(config.intermediate_size, embed_dim, bias=config.use_bias)
  53. self.act = ACT2FN[config.hidden_act]
  54. self.residual_dropout = config.residual_dropout
  55. def forward(self, hidden_states: Optional[tuple[torch.FloatTensor]]) -> torch.FloatTensor:
  56. hidden_states = self.c_fc(hidden_states)
  57. hidden_states = self.act(hidden_states)
  58. hidden_states = self.c_proj(hidden_states)
  59. hidden_states = nn.functional.dropout(hidden_states, p=self.residual_dropout, training=self.training)
  60. return hidden_states
  61. class Starcoder2Attention(MistralAttention):
  62. def __init__(self, config: Starcoder2Config, layer_idx: Optional[int] = None):
  63. super().__init__(config=config, layer_idx=layer_idx)
  64. self.residual_dropout = config.residual_dropout
  65. self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.use_bias)
  66. self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
  67. self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias)
  68. self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias)
  69. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  70. def forward(
  71. self,
  72. hidden_states: torch.Tensor,
  73. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  74. attention_mask: Optional[torch.Tensor],
  75. past_key_values: Optional[Cache] = None,
  76. cache_position: Optional[torch.LongTensor] = None,
  77. **kwargs: Unpack[FlashAttentionKwargs],
  78. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  79. input_shape = hidden_states.shape[:-1]
  80. hidden_shape = (*input_shape, -1, self.head_dim)
  81. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  82. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  83. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  84. cos, sin = position_embeddings
  85. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  86. if past_key_values is not None:
  87. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  88. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  89. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  90. attention_interface: Callable = eager_attention_forward
  91. if self.config._attn_implementation != "eager":
  92. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  93. attn_output, attn_weights = attention_interface(
  94. self,
  95. query_states,
  96. key_states,
  97. value_states,
  98. attention_mask,
  99. dropout=0.0 if not self.training else self.attention_dropout,
  100. scaling=self.scaling,
  101. sliding_window=getattr(self.config, "sliding_window", None), # diff with Llama
  102. **kwargs,
  103. )
  104. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  105. attn_output = self.o_proj(attn_output)
  106. attn_output = nn.functional.dropout(
  107. attn_output, p=self.residual_dropout, training=self.training
  108. ) # diff with Llama
  109. return attn_output, attn_weights
  110. class Starcoder2DecoderLayer(MistralDecoderLayer):
  111. def __init__(self, config: Starcoder2Config, layer_idx: int):
  112. super().__init__(config, layer_idx)
  113. self.self_attn = Starcoder2Attention(config=config, layer_idx=layer_idx)
  114. self.mlp = Starcoder2MLP(config)
  115. self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
  116. self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
  117. class Starcoder2RotaryEmbedding(MistralRotaryEmbedding):
  118. pass
  119. class Starcoder2Model(MistralModel):
  120. def __init__(self, config: Starcoder2Config):
  121. super().__init__(config)
  122. self.layers = nn.ModuleList(
  123. [Starcoder2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  124. )
  125. self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
  126. self.embedding_dropout = config.embedding_dropout
  127. @check_model_inputs()
  128. def forward(
  129. self,
  130. input_ids: Optional[torch.LongTensor] = None,
  131. attention_mask: Optional[torch.Tensor] = None,
  132. position_ids: Optional[torch.LongTensor] = None,
  133. past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
  134. inputs_embeds: Optional[torch.FloatTensor] = None,
  135. use_cache: Optional[bool] = None,
  136. cache_position: Optional[torch.LongTensor] = None,
  137. **kwargs: Unpack[TransformersKwargs],
  138. ) -> BaseModelOutputWithPast:
  139. if (input_ids is None) ^ (inputs_embeds is not None):
  140. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  141. if inputs_embeds is None:
  142. inputs_embeds = self.embed_tokens(input_ids)
  143. if use_cache and past_key_values is None:
  144. past_key_values = DynamicCache(config=self.config)
  145. if cache_position is None:
  146. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  147. cache_position = torch.arange(
  148. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  149. )
  150. if position_ids is None:
  151. position_ids = cache_position.unsqueeze(0)
  152. mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
  153. causal_mask = mask_function(
  154. config=self.config,
  155. input_embeds=inputs_embeds,
  156. attention_mask=attention_mask,
  157. cache_position=cache_position,
  158. past_key_values=past_key_values,
  159. position_ids=position_ids,
  160. )
  161. hidden_states = inputs_embeds
  162. hidden_states = nn.functional.dropout(
  163. hidden_states, p=self.embedding_dropout, training=self.training
  164. ) # main diff with Llama
  165. # create position embeddings to be shared across the decoder layers
  166. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  167. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  168. hidden_states = decoder_layer(
  169. hidden_states,
  170. attention_mask=causal_mask,
  171. position_ids=position_ids,
  172. past_key_values=past_key_values,
  173. use_cache=use_cache,
  174. cache_position=cache_position,
  175. position_embeddings=position_embeddings,
  176. **kwargs,
  177. )
  178. hidden_states = self.norm(hidden_states)
  179. return BaseModelOutputWithPast(
  180. last_hidden_state=hidden_states,
  181. past_key_values=past_key_values if use_cache else None,
  182. )
  183. class Starcoder2ForCausalLM(MistralForCausalLM):
  184. pass
  185. class Starcoder2ForSequenceClassification(MistralForSequenceClassification):
  186. pass
  187. class Starcoder2ForTokenClassification(MistralForTokenClassification):
  188. pass
# Public API of this module.
# "Starcoder2PreTrainedModel" is not defined in this file, hence the F822
# suppression. It is presumably produced by the modular code generator from the
# Mistral base classes — TODO confirm.
__all__ = [
    "Starcoder2ForCausalLM",
    "Starcoder2Model",
    "Starcoder2PreTrainedModel",  # noqa: F822
    "Starcoder2ForSequenceClassification",
    "Starcoder2ForTokenClassification",
]