modeling_phi3.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/phi3/modular_phi3.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_phi3.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. from typing import Callable, Optional, Union
  22. import torch
  23. from torch import nn
  24. from transformers.utils.generic import check_model_inputs
  25. from ...activations import ACT2FN
  26. from ...cache_utils import Cache, DynamicCache
  27. from ...generation import GenerationMixin
  28. from ...integrations import use_kernel_forward_from_hub
  29. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  30. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  31. from ...modeling_layers import (
  32. GenericForSequenceClassification,
  33. GenericForTokenClassification,
  34. GradientCheckpointingLayer,
  35. )
  36. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  37. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  38. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  39. from ...processing_utils import Unpack
  40. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  41. from ...utils.deprecation import deprecate_kwarg
  42. from .configuration_phi3 import Phi3Config
  43. class Phi3MLP(nn.Module):
  44. def __init__(self, config):
  45. super().__init__()
  46. self.config = config
  47. self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
  48. self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
  49. self.activation_fn = ACT2FN[config.hidden_act]
  50. def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
  51. up_states = self.gate_up_proj(hidden_states)
  52. gate, up_states = up_states.chunk(2, dim=-1)
  53. up_states = up_states * self.activation_fn(gate)
  54. return self.down_proj(up_states)
  55. def rotate_half(x):
  56. """Rotates half the hidden dims of the input."""
  57. x1 = x[..., : x.shape[-1] // 2]
  58. x2 = x[..., x.shape[-1] // 2 :]
  59. return torch.cat((-x2, x1), dim=-1)
  60. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  61. """
  62. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  63. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  64. """
  65. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  66. if n_rep == 1:
  67. return hidden_states
  68. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  69. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  70. def eager_attention_forward(
  71. module: nn.Module,
  72. query: torch.Tensor,
  73. key: torch.Tensor,
  74. value: torch.Tensor,
  75. attention_mask: Optional[torch.Tensor],
  76. scaling: float,
  77. dropout: float = 0.0,
  78. **kwargs: Unpack[TransformersKwargs],
  79. ):
  80. key_states = repeat_kv(key, module.num_key_value_groups)
  81. value_states = repeat_kv(value, module.num_key_value_groups)
  82. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  83. if attention_mask is not None:
  84. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  85. attn_weights = attn_weights + causal_mask
  86. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  87. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  88. attn_output = torch.matmul(attn_weights, value_states)
  89. attn_output = attn_output.transpose(1, 2).contiguous()
  90. return attn_output, attn_weights
  91. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  92. """Applies Rotary Position Embedding to the query and key tensors.
  93. Args:
  94. q (`torch.Tensor`): The query tensor.
  95. k (`torch.Tensor`): The key tensor.
  96. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  97. sin (`torch.Tensor`): The sine part of the rotary embedding.
  98. position_ids (`torch.Tensor`, *optional*):
  99. Deprecated and unused.
  100. unsqueeze_dim (`int`, *optional*, defaults to 1):
  101. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  102. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  103. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  104. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  105. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  106. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  107. Returns:
  108. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  109. """
  110. cos = cos.unsqueeze(unsqueeze_dim)
  111. sin = sin.unsqueeze(unsqueeze_dim)
  112. rotary_dim = cos.shape[-1]
  113. q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
  114. k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
  115. q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1)
  116. k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1)
  117. return q_embed, k_embed
  118. class Phi3Attention(nn.Module):
  119. """Multi-headed attention from 'Attention Is All You Need' paper"""
  120. def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None):
  121. super().__init__()
  122. self.config = config
  123. self.layer_idx = layer_idx
  124. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  125. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  126. self.num_key_value_heads = config.num_key_value_heads
  127. self.scaling = self.head_dim**-0.5
  128. self.attention_dropout = config.attention_dropout
  129. self.is_causal = True
  130. op_size = config.num_attention_heads * self.head_dim + 2 * (config.num_key_value_heads * self.head_dim)
  131. self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
  132. self.qkv_proj = nn.Linear(config.hidden_size, op_size, bias=False)
  133. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  134. def forward(
  135. self,
  136. hidden_states: torch.Tensor,
  137. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  138. attention_mask: Optional[torch.Tensor],
  139. past_key_values: Optional[Cache] = None,
  140. cache_position: Optional[torch.LongTensor] = None,
  141. **kwargs: Unpack[FlashAttentionKwargs],
  142. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  143. input_shape = hidden_states.shape[:-1]
  144. hidden_shape = (*input_shape, -1, self.head_dim)
  145. qkv = self.qkv_proj(hidden_states)
  146. query_pos = self.config.num_attention_heads * self.head_dim
  147. query_states = qkv[..., :query_pos]
  148. key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
  149. value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
  150. query_states = query_states.view(hidden_shape).transpose(1, 2)
  151. key_states = key_states.view(hidden_shape).transpose(1, 2)
  152. value_states = value_states.view(hidden_shape).transpose(1, 2)
  153. cos, sin = position_embeddings
  154. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  155. if past_key_values is not None:
  156. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  157. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  158. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  159. attention_interface: Callable = eager_attention_forward
  160. if self.config._attn_implementation != "eager":
  161. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  162. attn_output, attn_weights = attention_interface(
  163. self,
  164. query_states,
  165. key_states,
  166. value_states,
  167. attention_mask,
  168. dropout=0.0 if not self.training else self.attention_dropout,
  169. scaling=self.scaling,
  170. sliding_window=getattr(self.config, "sliding_window", None),
  171. **kwargs,
  172. )
  173. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  174. attn_output = self.o_proj(attn_output)
  175. return attn_output, attn_weights
  176. @use_kernel_forward_from_hub("RMSNorm")
  177. class Phi3RMSNorm(nn.Module):
  178. def __init__(self, hidden_size, eps=1e-6):
  179. """
  180. Phi3RMSNorm is equivalent to T5LayerNorm
  181. """
  182. super().__init__()
  183. self.weight = nn.Parameter(torch.ones(hidden_size))
  184. self.variance_epsilon = eps
  185. def forward(self, hidden_states):
  186. input_dtype = hidden_states.dtype
  187. hidden_states = hidden_states.to(torch.float32)
  188. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  189. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  190. return self.weight * hidden_states.to(input_dtype)
  191. def extra_repr(self):
  192. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  193. class Phi3DecoderLayer(GradientCheckpointingLayer):
  194. def __init__(self, config: Phi3Config, layer_idx: int):
  195. super().__init__()
  196. self.hidden_size = config.hidden_size
  197. self.self_attn = Phi3Attention(config=config, layer_idx=layer_idx)
  198. self.mlp = Phi3MLP(config)
  199. self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  200. self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  201. self.config = config
  202. self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
  203. self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
  204. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  205. def forward(
  206. self,
  207. hidden_states: torch.Tensor,
  208. attention_mask: Optional[torch.Tensor] = None,
  209. position_ids: Optional[torch.LongTensor] = None,
  210. past_key_values: Optional[Cache] = None,
  211. use_cache: Optional[bool] = False,
  212. cache_position: Optional[torch.LongTensor] = None,
  213. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
  214. **kwargs: Unpack[FlashAttentionKwargs],
  215. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  216. residual = hidden_states
  217. hidden_states = self.input_layernorm(hidden_states)
  218. hidden_states, self_attn_weights = self.self_attn(
  219. hidden_states=hidden_states,
  220. attention_mask=attention_mask,
  221. position_ids=position_ids,
  222. past_key_values=past_key_values,
  223. use_cache=use_cache,
  224. cache_position=cache_position,
  225. position_embeddings=position_embeddings,
  226. **kwargs,
  227. )
  228. hidden_states = residual + self.resid_attn_dropout(hidden_states) # main diff with Llama
  229. residual = hidden_states
  230. hidden_states = self.post_attention_layernorm(hidden_states)
  231. hidden_states = self.mlp(hidden_states)
  232. hidden_states = residual + self.resid_mlp_dropout(hidden_states) # main diff with Llama
  233. return hidden_states
  234. @auto_docstring
  235. class Phi3PreTrainedModel(PreTrainedModel):
  236. config: Phi3Config
  237. base_model_prefix = "model"
  238. supports_gradient_checkpointing = True
  239. _no_split_modules = ["Phi3DecoderLayer"]
  240. _skip_keys_device_placement = ["past_key_values"]
  241. _supports_flash_attn = True
  242. _supports_sdpa = True
  243. _supports_flex_attn = True
  244. _can_compile_fullgraph = True
  245. _supports_attention_backend = True
  246. _can_record_outputs = {
  247. "hidden_states": Phi3DecoderLayer,
  248. "attentions": Phi3Attention,
  249. }
  250. _version = "0.0.5"
  251. class Phi3RotaryEmbedding(nn.Module):
  252. inv_freq: torch.Tensor # fix linting for `register_buffer`
  253. def __init__(self, config: Phi3Config, device=None):
  254. super().__init__()
  255. # BC: "rope_type" was originally "type"
  256. if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
  257. self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
  258. else:
  259. self.rope_type = "default"
  260. self.max_seq_len_cached = config.max_position_embeddings
  261. self.original_max_seq_len = config.max_position_embeddings
  262. self.config = config
  263. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  264. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
  265. self.register_buffer("inv_freq", inv_freq, persistent=False)
  266. self.original_inv_freq = self.inv_freq
  267. @torch.no_grad()
  268. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  269. def forward(self, x, position_ids):
  270. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  271. position_ids_expanded = position_ids[:, None, :].float()
  272. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  273. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  274. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  275. emb = torch.cat((freqs, freqs), dim=-1)
  276. cos = emb.cos() * self.attention_scaling
  277. sin = emb.sin() * self.attention_scaling
  278. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  279. @auto_docstring
  280. class Phi3Model(Phi3PreTrainedModel):
  281. def __init__(self, config: Phi3Config):
  282. super().__init__(config)
  283. self.padding_idx = config.pad_token_id
  284. self.vocab_size = config.vocab_size
  285. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  286. self.layers = nn.ModuleList(
  287. [Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  288. )
  289. self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  290. self.rotary_emb = Phi3RotaryEmbedding(config=config)
  291. self.gradient_checkpointing = False
  292. # Initialize weights and apply final processing
  293. self.post_init()
  294. @check_model_inputs()
  295. @auto_docstring
  296. def forward(
  297. self,
  298. input_ids: Optional[torch.LongTensor] = None,
  299. attention_mask: Optional[torch.Tensor] = None,
  300. position_ids: Optional[torch.LongTensor] = None,
  301. past_key_values: Optional[Cache] = None,
  302. inputs_embeds: Optional[torch.FloatTensor] = None,
  303. use_cache: Optional[bool] = None,
  304. cache_position: Optional[torch.LongTensor] = None,
  305. **kwargs: Unpack[TransformersKwargs],
  306. ) -> BaseModelOutputWithPast:
  307. if (input_ids is None) ^ (inputs_embeds is not None):
  308. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  309. if inputs_embeds is None:
  310. inputs_embeds = self.embed_tokens(input_ids)
  311. if use_cache and past_key_values is None:
  312. past_key_values = DynamicCache(config=self.config)
  313. if cache_position is None:
  314. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  315. cache_position = torch.arange(
  316. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  317. )
  318. if position_ids is None:
  319. position_ids = cache_position.unsqueeze(0)
  320. mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
  321. causal_mask = mask_function(
  322. config=self.config,
  323. input_embeds=inputs_embeds,
  324. attention_mask=attention_mask,
  325. cache_position=cache_position,
  326. past_key_values=past_key_values,
  327. position_ids=position_ids,
  328. )
  329. hidden_states = inputs_embeds
  330. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  331. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  332. hidden_states = decoder_layer(
  333. hidden_states,
  334. attention_mask=causal_mask,
  335. position_ids=position_ids,
  336. past_key_values=past_key_values,
  337. use_cache=use_cache,
  338. cache_position=cache_position,
  339. position_embeddings=position_embeddings,
  340. **kwargs,
  341. )
  342. hidden_states = self.norm(hidden_states)
  343. return BaseModelOutputWithPast(
  344. last_hidden_state=hidden_states,
  345. past_key_values=past_key_values if use_cache else None,
  346. )
@auto_docstring
class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin):
    # Phi-3 decoder stack (`Phi3Model`) with a bias-free linear language-modeling head.
    # NOTE(review): `_tied_weights_keys` presumably lets `lm_head.weight` be tied to the input
    # embeddings. `_tp_plan` / `_pp_plan` presumably describe how `lm_head` is handled under
    # tensor / pipeline parallelism. Confirm against PreTrainedModel.
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = Phi3Model(config)
        self.vocab_size = config.vocab_size
        # Maps final hidden states (hidden_size) to vocabulary logits (vocab_size).
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Example:

        ```python
        >>> from transformers import AutoTokenizer, Phi3ForCausalLM

        >>> model = Phi3ForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
        >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate (the continuation shown below is illustrative; it depends on the checkpoint)
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss.
        # An int keeps the last `logits_to_keep` positions (0 keeps all, since slice(-0, None) == slice(0, None)).
        # A tensor is used directly as the index along the sequence axis.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        logits_to_keep=None,
        **kwargs,
    ):
        # Overwritten: with `rope_scaling` set, this model may need to switch from the short to the
        # long RoPE factors once the sequence grows past `original_max_position_embeddings`. The
        # keys/values cached so far were rotated with the old factors, so the cache becomes invalid.
        # The first time the input length crosses that threshold (while the cached length is still at
        # or below it), the cache is dropped and the whole prefix is recomputed. That one step is
        # slower, but it avoids mixing keys/values rotated with different factors.
        # NOTE(review): `past_key_values` is tested by truthiness, not `is not None`; keep that as-is.
        if (
            past_key_values
            and self.config.rope_scaling
            and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
        ):
            # Number of tokens already processed (start of the positions for this step).
            past_length = cache_position[0]
            if past_length <= self.config.original_max_position_embeddings:
                past_key_values = None

        model_inputs = super().prepare_inputs_for_generation(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            use_cache=use_cache,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )
        return model_inputs
  447. class Phi3ForSequenceClassification(GenericForSequenceClassification, Phi3PreTrainedModel):
  448. pass
  449. class Phi3ForTokenClassification(GenericForTokenClassification, Phi3PreTrainedModel):
  450. pass
# Public classes exported by this module.
# NOTE(review): transformers' import-structure tooling presumably reads this literal from
# source, so keep it a plain list of string names.
__all__ = [
    "Phi3PreTrainedModel",
    "Phi3Model",
    "Phi3ForCausalLM",
    "Phi3ForSequenceClassification",
    "Phi3ForTokenClassification",
]