modular_phi3.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. # coding=utf-8
  2. # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch Phi-3 model."""
  16. from typing import Callable, Optional
  17. import torch
  18. from torch import nn
  19. from ...activations import ACT2FN
  20. from ...cache_utils import Cache
  21. from ...generation import GenerationMixin
  22. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  23. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  24. from ...processing_utils import Unpack
  25. from ...utils import logging
  26. from ...utils.deprecation import deprecate_kwarg
  27. from ..mistral.modeling_mistral import (
  28. MistralDecoderLayer,
  29. MistralForCausalLM,
  30. MistralForSequenceClassification,
  31. MistralForTokenClassification,
  32. MistralPreTrainedModel,
  33. eager_attention_forward,
  34. rotate_half,
  35. )
  36. from .configuration_phi3 import Phi3Config
  # Checkpoint and config class referenced by documentation helpers
  # (presumably consumed by docstring tooling elsewhere — confirm).
  _CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct"
  _CONFIG_FOR_DOC = "Phi3Config"

  # Module-scoped logger obtained from the project's logging utilities.
  logger = logging.get_logger(__name__)
  class Phi3MLP(nn.Module):
      """Gated feed-forward block used by each Phi-3 decoder layer.

      A single fused projection produces both the gate and the up-projection. The
      activated gate scales the up-projection element-wise, and the product is
      projected back to ``config.hidden_size``.
      """

      def __init__(self, config):
          super().__init__()
          self.config = config
          hidden, inner = config.hidden_size, config.intermediate_size
          # One matmul yields [gate | up], each `intermediate_size` wide.
          self.gate_up_proj = nn.Linear(hidden, 2 * inner, bias=False)
          self.down_proj = nn.Linear(inner, hidden, bias=False)
          self.activation_fn = ACT2FN[config.hidden_act]

      def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
          fused = self.gate_up_proj(hidden_states)
          # First half of the last dimension is the gate, second half the up-projection.
          gate, up = fused.chunk(2, dim=-1)
          gated = up * self.activation_fn(gate)
          return self.down_proj(gated)
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
      """Apply (possibly partial) rotary position embedding to the query and key tensors.

      Only the leading ``cos.shape[-1]`` channels of the last dimension of ``q`` and ``k``
      are rotated. Any remaining channels pass through unchanged and are re-appended after
      the rotated ones. When ``cos``/``sin`` span the whole last dimension, every channel
      is rotated.

      Args:
          q (`torch.Tensor`): The query tensor.
          k (`torch.Tensor`): The key tensor.
          cos (`torch.Tensor`): The cosine part of the rotary embedding.
          sin (`torch.Tensor`): The sine part of the rotary embedding.
          position_ids (`torch.Tensor`, *optional*):
              Deprecated and unused.
          unsqueeze_dim (`int`, *optional*, defaults to 1):
              Dimension at which a singleton axis is inserted into ``cos`` and ``sin`` so they
              broadcast against ``q`` and ``k``. For example, if ``cos``/``sin`` have shape
              ``[batch_size, seq_len, rot_dim]`` and ``q``/``k`` have shape
              ``[batch_size, heads, seq_len, head_dim]``, use ``unsqueeze_dim=1``. If ``q``/``k``
              have shape ``[batch_size, seq_len, heads, head_dim]``, use ``unsqueeze_dim=2``.

      Returns:
          `tuple(torch.Tensor)`: the query and key tensors with RoPE applied to their leading channels.
      """
      cos_b = cos.unsqueeze(unsqueeze_dim)
      sin_b = sin.unsqueeze(unsqueeze_dim)
      rot = cos_b.shape[-1]

      def _rotate(t):
          # Split into the channels that receive RoPE and the untouched tail.
          head, tail = t[..., :rot], t[..., rot:]
          rotated = (head * cos_b) + (rotate_half(head) * sin_b)
          return torch.cat([rotated, tail], dim=-1)

      return _rotate(q), _rotate(k)
  class Phi3Attention(nn.Module):
      """Multi-headed attention from 'Attention Is All You Need' paper, with a fused QKV projection.

      Queries, keys and values come from one ``qkv_proj`` matmul. Its output is laid out along
      the last dimension as ``[q (num_attention_heads * head_dim) | k (num_key_value_heads * head_dim)
      | v (num_key_value_heads * head_dim)]``.
      """

      def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None):
          super().__init__()
          self.config = config
          self.layer_idx = layer_idx
          # Fall back to hidden_size // num_heads when the config has no `head_dim` or sets it to None
          # (a None value would otherwise fail at `self.head_dim**-0.5` below).
          self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
          self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
          self.num_key_value_heads = config.num_key_value_heads
          self.scaling = self.head_dim**-0.5
          self.attention_dropout = config.attention_dropout
          self.is_causal = True

          op_size = config.num_attention_heads * self.head_dim + 2 * (config.num_key_value_heads * self.head_dim)
          self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
          self.qkv_proj = nn.Linear(config.hidden_size, op_size, bias=False)

      @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
      def forward(
          self,
          hidden_states: torch.Tensor,
          position_embeddings: tuple[torch.Tensor, torch.Tensor],
          attention_mask: Optional[torch.Tensor],
          past_key_values: Optional[Cache] = None,
          cache_position: Optional[torch.LongTensor] = None,
          **kwargs: Unpack[FlashAttentionKwargs],
      ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
          """Compute self-attention over ``hidden_states``.

          Args:
              hidden_states: Input activations. The last dimension is projected by ``qkv_proj``
                  (assumed ``(batch, seq_len, hidden_size)`` — the ``transpose(1, 2)`` below relies on it).
              position_embeddings: ``(cos, sin)`` pair passed to ``apply_rotary_pos_emb``.
              attention_mask: Mask forwarded unchanged to the attention implementation.
              past_key_values: Optional cache. When given, the new keys/values are passed to
                  ``past_key_values.update``, and the keys/values it returns are used for attention.
              cache_position: Forwarded to the cache update (needed for the static cache).
              **kwargs: Extra arguments forwarded to the attention implementation.

          Returns:
              ``(attn_output, attn_weights)``. ``attn_output`` keeps the leading shape of
              ``hidden_states`` and has last dimension ``hidden_size`` (after ``o_proj``).
              ``attn_weights`` is whatever the selected attention implementation returns.
          """
          input_shape = hidden_states.shape[:-1]
          hidden_shape = (*input_shape, -1, self.head_dim)

          q_size = self.config.num_attention_heads * self.head_dim
          kv_size = self.num_key_value_heads * self.head_dim
          qkv = self.qkv_proj(hidden_states)
          # Views into the fused projection: [q | k | v] along the last dimension.
          query_states, key_states, value_states = qkv.split([q_size, kv_size, kv_size], dim=-1)

          # Split the feature dimension into heads, then move heads ahead of the sequence axis.
          query_states = query_states.view(hidden_shape).transpose(1, 2)
          key_states = key_states.view(hidden_shape).transpose(1, 2)
          value_states = value_states.view(hidden_shape).transpose(1, 2)

          cos, sin = position_embeddings
          query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

          if past_key_values is not None:
              # sin and cos are specific to RoPE models; cache_position needed for the static cache
              cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
              key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

          # The eager path uses the local reference implementation; any other configured
          # implementation is looked up in the shared registry.
          attention_interface: Callable = eager_attention_forward
          if self.config._attn_implementation != "eager":
              attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

          attn_output, attn_weights = attention_interface(
              self,
              query_states,
              key_states,
              value_states,
              attention_mask,
              dropout=0.0 if not self.training else self.attention_dropout,
              scaling=self.scaling,
              sliding_window=getattr(self.config, "sliding_window", None),
              **kwargs,
          )

          attn_output = attn_output.reshape(*input_shape, -1).contiguous()
          attn_output = self.o_proj(attn_output)
          return attn_output, attn_weights
  class Phi3DecoderLayer(MistralDecoderLayer):
      """Pre-norm Phi-3 decoder layer.

      Builds on ``MistralDecoderLayer``. ``input_layernorm`` and ``post_attention_layernorm`` are
      used here but not defined in this class; presumably they are created by
      ``MistralDecoderLayer.__init__`` (confirm). This class swaps in ``Phi3Attention`` and
      ``Phi3MLP``. It also applies residual dropout (``config.resid_pdrop``) to each sub-block's
      output before adding it back to the residual stream.
      """

      def __init__(self, config: Phi3Config, layer_idx: int):
          super().__init__(config, layer_idx)
          self.config = config
          # Replace the attention/MLP modules created by the parent initializer.
          self.self_attn = Phi3Attention(config=config, layer_idx=layer_idx)
          self.mlp = Phi3MLP(config)
          self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
          self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)

      @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          past_key_values: Optional[Cache] = None,
          use_cache: Optional[bool] = False,
          cache_position: Optional[torch.LongTensor] = None,
          position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
          **kwargs: Unpack[FlashAttentionKwargs],
      ) -> torch.Tensor:
          """Run the attention and MLP sub-blocks, each as norm -> sub-block -> dropout -> residual add.

          Args:
              hidden_states: Input activations.
              attention_mask, position_ids, past_key_values, use_cache, cache_position, position_embeddings:
                  Forwarded by keyword to ``self.self_attn``. ``Phi3Attention.forward`` has no
                  ``position_ids``/``use_cache`` parameters, so those land in its ``**kwargs``.
              **kwargs: Also forwarded to ``self.self_attn``.

          Returns:
              The updated hidden states as a single tensor. The attention weights returned by
              ``self.self_attn`` are discarded.
          """
          # Attention sub-block.
          residual = hidden_states
          hidden_states = self.input_layernorm(hidden_states)
          hidden_states, _ = self.self_attn(
              hidden_states=hidden_states,
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=past_key_values,
              use_cache=use_cache,
              cache_position=cache_position,
              position_embeddings=position_embeddings,
              **kwargs,
          )
          hidden_states = residual + self.resid_attn_dropout(hidden_states)  # main diff with Llama

          # Feed-forward sub-block.
          residual = hidden_states
          hidden_states = self.post_attention_layernorm(hidden_states)
          hidden_states = self.mlp(hidden_states)
          hidden_states = residual + self.resid_mlp_dropout(hidden_states)  # main diff with Llama
          return hidden_states
  class Phi3PreTrainedModel(MistralPreTrainedModel):
      """Phi-3 base pretrained-model class.

      Inherits its behavior from ``MistralPreTrainedModel`` and only overrides the
      class-level ``_version`` tag.
      """

      _version = "0.0.5"
  class Phi3ForCausalLM(MistralForCausalLM):
      """Phi-3 causal language model. It only overrides how generation inputs are prepared."""

      def prepare_inputs_for_generation(
          self,
          input_ids,
          past_key_values=None,
          attention_mask=None,
          inputs_embeds=None,
          cache_position=None,
          position_ids=None,
          use_cache=True,
          logits_to_keep=None,
          **kwargs,
      ):
          """Prepare inputs for one generation step, discarding the cache at the short/long RoPE switch.

          Overwritten because this model may need to switch between short and long rope, which
          invalidates the cache. The cache is dropped when all of the following hold:

          - a cache is present;
          - ``config.rope_scaling`` is set;
          - ``input_ids`` is longer than ``config.original_max_position_embeddings``;
          - the first cache position is still at or below that threshold.

          Dropping it forces the prefix to be recomputed. That single step is slower, which the
          authors judged better than the failure seen otherwise.

          All arguments are forwarded by keyword to
          ``GenerationMixin.prepare_inputs_for_generation``, which is called directly rather than
          through ``super()``. Its result is returned unchanged.
          """
          # Conditions are checked in the same order as before, and evaluation short-circuits.
          # `cache_position` is only indexed once the first three checks pass.
          if (
              past_key_values
              and self.config.rope_scaling
              and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
              and cache_position[0] <= self.config.original_max_position_embeddings
          ):
              past_key_values = None

          return GenerationMixin.prepare_inputs_for_generation(
              self,
              input_ids=input_ids,
              past_key_values=past_key_values,
              attention_mask=attention_mask,
              inputs_embeds=inputs_embeds,
              cache_position=cache_position,
              position_ids=position_ids,
              use_cache=use_cache,
              logits_to_keep=logits_to_keep,
              **kwargs,
          )
  class Phi3ForSequenceClassification(MistralForSequenceClassification):
      """Phi-3 sequence-classification model; all behavior comes from ``MistralForSequenceClassification``."""
  class Phi3ForTokenClassification(MistralForTokenClassification):
      """Phi-3 token-classification model; all behavior comes from ``MistralForTokenClassification``."""
  # Public names exported by this module.
  # `Phi3Model` is not defined in this file, hence the F822 ("undefined name in __all__")
  # suppression. Presumably it is provided by the code generated from this modular file — confirm.
  __all__ = [
      "Phi3PreTrainedModel",
      "Phi3Model",  # noqa: F822
      "Phi3ForCausalLM",
      "Phi3ForSequenceClassification",
      "Phi3ForTokenClassification",
  ]