modular_phi.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. from typing import Callable, Optional
  2. import torch
  3. import torch.nn as nn
  4. from ...cache_utils import Cache, DynamicCache
  5. from ...masking_utils import create_causal_mask
  6. from ...modeling_layers import GradientCheckpointingLayer
  7. from ...modeling_outputs import (
  8. BaseModelOutputWithPast,
  9. )
  10. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  11. from ...processing_utils import Unpack
  12. from ...utils import TransformersKwargs, logging
  13. from ...utils.deprecation import deprecate_kwarg
  14. from ..clip.modeling_clip import CLIPMLP
  15. from ..llama.modeling_llama import (
  16. LlamaAttention,
  17. LlamaForCausalLM,
  18. LlamaForSequenceClassification,
  19. LlamaForTokenClassification,
  20. LlamaModel,
  21. LlamaRotaryEmbedding,
  22. apply_rotary_pos_emb,
  23. eager_attention_forward, # copied from Llama
  24. )
  25. from .configuration_phi import PhiConfig
  26. logger = logging.get_logger(__name__)
  27. _CHECKPOINT_FOR_DOC = "microsoft/phi-1"
  28. _CONFIG_FOR_DOC = "PhiConfig"
  29. class PhiAttention(LlamaAttention):
  30. def __init__(self, config: PhiConfig, layer_idx: int):
  31. super().__init__(config, layer_idx)
  32. self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True)
  33. self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
  34. self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
  35. self.dense = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=True)
  36. del self.o_proj
  37. self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor)
  38. self.qk_layernorm = config.qk_layernorm
  39. if self.qk_layernorm:
  40. self.q_layernorm = nn.LayerNorm(
  41. config.hidden_size // config.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=True
  42. )
  43. self.k_layernorm = nn.LayerNorm(
  44. config.hidden_size // config.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=True
  45. )
  46. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  47. def forward(
  48. self,
  49. hidden_states: torch.Tensor,
  50. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  51. attention_mask: Optional[torch.Tensor],
  52. past_key_values: Optional[Cache] = None,
  53. cache_position: Optional[torch.LongTensor] = None,
  54. **kwargs,
  55. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  56. input_shape = hidden_states.shape[:-1]
  57. hidden_shape = (*input_shape, -1, self.head_dim)
  58. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  59. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  60. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  61. if self.qk_layernorm:
  62. query_states = self.q_layernorm(query_states)
  63. key_states = self.k_layernorm(key_states)
  64. cos, sin = position_embeddings
  65. # Partial rotary embedding
  66. query_rot, query_pass = (
  67. query_states[..., : self.rotary_ndims],
  68. query_states[..., self.rotary_ndims :],
  69. )
  70. key_rot, key_pass = (
  71. key_states[..., : self.rotary_ndims],
  72. key_states[..., self.rotary_ndims :],
  73. )
  74. # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
  75. query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
  76. # [batch_size, seq_length, num_heads, head_dim]
  77. query_states = torch.cat((query_rot, query_pass), dim=-1)
  78. key_states = torch.cat((key_rot, key_pass), dim=-1)
  79. if past_key_values is not None:
  80. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  81. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  82. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  83. attention_interface: Callable = eager_attention_forward
  84. if self.config._attn_implementation != "eager":
  85. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  86. attn_output, attn_weights = attention_interface(
  87. self,
  88. query_states,
  89. key_states,
  90. value_states,
  91. attention_mask,
  92. dropout=0.0 if not self.training else self.attention_dropout,
  93. scaling=self.scaling,
  94. **kwargs,
  95. )
  96. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  97. attn_output = self.dense(attn_output)
  98. return attn_output, attn_weights
  99. class PhiMLP(CLIPMLP):
  100. pass
  101. class PhiDecoderLayer(GradientCheckpointingLayer):
  102. def __init__(self, config: PhiConfig, layer_idx: int):
  103. super().__init__()
  104. self.self_attn = PhiAttention(config, layer_idx=layer_idx)
  105. self.mlp = PhiMLP(config)
  106. self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  107. self.resid_dropout = nn.Dropout(config.resid_pdrop)
  108. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  109. def forward(
  110. self,
  111. hidden_states: torch.Tensor,
  112. attention_mask: Optional[torch.Tensor] = None,
  113. position_ids: Optional[torch.LongTensor] = None,
  114. past_key_values: Optional[Cache] = None,
  115. output_attentions: Optional[bool] = False,
  116. use_cache: Optional[bool] = False,
  117. cache_position: Optional[torch.LongTensor] = None,
  118. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
  119. **kwargs,
  120. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  121. residual = hidden_states
  122. hidden_states = self.input_layernorm(hidden_states)
  123. # Self Attention
  124. attn_outputs, self_attn_weights = self.self_attn(
  125. hidden_states=hidden_states,
  126. attention_mask=attention_mask,
  127. position_ids=position_ids,
  128. past_key_values=past_key_values,
  129. output_attentions=output_attentions,
  130. use_cache=use_cache,
  131. cache_position=cache_position,
  132. position_embeddings=position_embeddings,
  133. **kwargs,
  134. )
  135. attn_outputs = self.resid_dropout(attn_outputs)
  136. feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
  137. hidden_states = attn_outputs + feed_forward_hidden_states + residual
  138. outputs = (hidden_states,)
  139. if output_attentions:
  140. outputs += (self_attn_weights,)
  141. return outputs
  142. class PhiRotaryEmbedding(LlamaRotaryEmbedding):
  143. pass
class PhiModel(LlamaModel):
    """Phi decoder stack built on ``LlamaModel``.

    Compared with the parent as visible here, it:
      * overrides ``layers`` with ``PhiDecoderLayer`` instances,
      * applies ``embed_dropout`` to the input embeddings, and
      * ends with ``final_layernorm`` (``nn.LayerNorm``), removing the parent's ``norm``.
    """

    def __init__(self, config: PhiConfig):
        super().__init__(config)
        # Override the decoder layers set up by the parent constructor with Phi layers.
        self.layers = nn.ModuleList(
            [PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # Dropout on the input embeddings, applied in forward after the causal mask is built.
        self.embed_dropout = nn.Dropout(config.embd_pdrop)
        self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # `final_layernorm` replaces the parent's final norm module.
        del self.norm

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        """Run the decoder stack.

        Args:
            input_ids: Token ids to embed. Exactly one of ``input_ids`` / ``inputs_embeds``
                must be given.
            attention_mask: Passed to ``create_causal_mask`` to build the per-layer mask.
            position_ids: Defaults to ``cache_position.unsqueeze(0)``.
            past_key_values: Cache to read and extend. When ``use_cache`` is truthy and this is
                ``None``, a ``DynamicCache`` is created.
            inputs_embeds: Precomputed embeddings. Assumed (batch, seq_len, hidden_size), since
                ``shape[1]`` is used as the new-token count — TODO confirm.
            use_cache / output_attentions / output_hidden_states: ``None`` falls back to the
                matching ``self.config`` attribute.
            cache_position: Defaults to positions following the tokens already in the cache.
            **kwargs: Forwarded to every decoder layer.

        Returns:
            ``BaseModelOutputWithPast``. ``hidden_states`` (when requested) holds the input to
            each layer plus the final post-``final_layernorm`` output. ``past_key_values`` is
            returned only when ``use_cache`` is truthy.

        Raises:
            ValueError: if both or neither of ``input_ids`` and ``inputs_embeds`` are given.
        """
        # Resolve unset flags from the config.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        # XOR is True when both or neither are provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # Create the cache before deriving cache_position so its (empty) length is read below.
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)
        if cache_position is None:
            # New tokens are placed right after the tokens already stored in the cache.
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
        # Built from the pre-dropout embeddings.
        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )
        inputs_embeds = self.embed_dropout(inputs_embeds)  # diff with Llama
        hidden_states = inputs_embeds
        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)
        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            # Record each layer's input (the first entry is the post-dropout embeddings).
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attns += (layer_outputs[1],)
        hidden_states = self.final_layernorm(hidden_states)  # diff with Llama
        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
  231. class PhiForCausalLM(LlamaForCausalLM):
  232. def __init__(self, config):
  233. super().__init__(config)
  234. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
  235. class PhiForSequenceClassification(LlamaForSequenceClassification):
  236. pass
  237. class PhiForTokenClassification(LlamaForTokenClassification):
  238. pass
# Public API of this module. `PhiPreTrainedModel` is not defined in this file, hence the
# `noqa: F822` (undefined name in __all__). Presumably the class appears in the modeling
# file generated from this modular file — confirm.
__all__ = [
    "PhiPreTrainedModel",  # noqa: F822
    "PhiModel",
    "PhiForCausalLM",
    "PhiForSequenceClassification",
    "PhiForTokenClassification",
]