# modeling_phi.py (22 KB)
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/phi/modular_phi.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_phi.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. from typing import Callable, Optional, Union
  8. import torch
  9. import torch.nn as nn
  10. from ...activations import ACT2FN
  11. from ...cache_utils import Cache, DynamicCache
  12. from ...generation import GenerationMixin
  13. from ...masking_utils import create_causal_mask
  14. from ...modeling_layers import (
  15. GenericForSequenceClassification,
  16. GenericForTokenClassification,
  17. GradientCheckpointingLayer,
  18. )
  19. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  20. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  21. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  22. from ...processing_utils import Unpack
  23. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  24. from ...utils.deprecation import deprecate_kwarg
  25. from ...utils.generic import check_model_inputs
  26. from .configuration_phi import PhiConfig
  27. logger = logging.get_logger(__name__)
  28. def rotate_half(x):
  29. """Rotates half the hidden dims of the input."""
  30. x1 = x[..., : x.shape[-1] // 2]
  31. x2 = x[..., x.shape[-1] // 2 :]
  32. return torch.cat((-x2, x1), dim=-1)
  33. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  34. """Applies Rotary Position Embedding to the query and key tensors.
  35. Args:
  36. q (`torch.Tensor`): The query tensor.
  37. k (`torch.Tensor`): The key tensor.
  38. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  39. sin (`torch.Tensor`): The sine part of the rotary embedding.
  40. position_ids (`torch.Tensor`, *optional*):
  41. Deprecated and unused.
  42. unsqueeze_dim (`int`, *optional*, defaults to 1):
  43. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  44. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  45. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  46. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  47. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  48. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  49. Returns:
  50. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  51. """
  52. cos = cos.unsqueeze(unsqueeze_dim)
  53. sin = sin.unsqueeze(unsqueeze_dim)
  54. q_embed = (q * cos) + (rotate_half(q) * sin)
  55. k_embed = (k * cos) + (rotate_half(k) * sin)
  56. return q_embed, k_embed
  57. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  58. """
  59. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  60. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  61. """
  62. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  63. if n_rep == 1:
  64. return hidden_states
  65. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  66. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  67. def eager_attention_forward(
  68. module: nn.Module,
  69. query: torch.Tensor,
  70. key: torch.Tensor,
  71. value: torch.Tensor,
  72. attention_mask: Optional[torch.Tensor],
  73. scaling: float,
  74. dropout: float = 0.0,
  75. **kwargs: Unpack[TransformersKwargs],
  76. ):
  77. key_states = repeat_kv(key, module.num_key_value_groups)
  78. value_states = repeat_kv(value, module.num_key_value_groups)
  79. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  80. if attention_mask is not None:
  81. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  82. attn_weights = attn_weights + causal_mask
  83. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  84. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  85. attn_output = torch.matmul(attn_weights, value_states)
  86. attn_output = attn_output.transpose(1, 2).contiguous()
  87. return attn_output, attn_weights
  88. class PhiAttention(nn.Module):
  89. """Multi-headed attention from 'Attention Is All You Need' paper"""
  90. def __init__(self, config: PhiConfig, layer_idx: int):
  91. super().__init__()
  92. self.config = config
  93. self.layer_idx = layer_idx
  94. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  95. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  96. self.scaling = self.head_dim**-0.5
  97. self.attention_dropout = config.attention_dropout
  98. self.is_causal = True
  99. self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True)
  100. self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
  101. self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
  102. self.dense = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=True)
  103. self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor)
  104. self.qk_layernorm = config.qk_layernorm
  105. if self.qk_layernorm:
  106. self.q_layernorm = nn.LayerNorm(
  107. config.hidden_size // config.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=True
  108. )
  109. self.k_layernorm = nn.LayerNorm(
  110. config.hidden_size // config.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=True
  111. )
  112. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  113. def forward(
  114. self,
  115. hidden_states: torch.Tensor,
  116. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  117. attention_mask: Optional[torch.Tensor],
  118. past_key_values: Optional[Cache] = None,
  119. cache_position: Optional[torch.LongTensor] = None,
  120. **kwargs,
  121. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  122. input_shape = hidden_states.shape[:-1]
  123. hidden_shape = (*input_shape, -1, self.head_dim)
  124. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  125. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  126. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  127. if self.qk_layernorm:
  128. query_states = self.q_layernorm(query_states)
  129. key_states = self.k_layernorm(key_states)
  130. cos, sin = position_embeddings
  131. # Partial rotary embedding
  132. query_rot, query_pass = (
  133. query_states[..., : self.rotary_ndims],
  134. query_states[..., self.rotary_ndims :],
  135. )
  136. key_rot, key_pass = (
  137. key_states[..., : self.rotary_ndims],
  138. key_states[..., self.rotary_ndims :],
  139. )
  140. # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
  141. query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
  142. # [batch_size, seq_length, num_heads, head_dim]
  143. query_states = torch.cat((query_rot, query_pass), dim=-1)
  144. key_states = torch.cat((key_rot, key_pass), dim=-1)
  145. if past_key_values is not None:
  146. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  147. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  148. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  149. attention_interface: Callable = eager_attention_forward
  150. if self.config._attn_implementation != "eager":
  151. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  152. attn_output, attn_weights = attention_interface(
  153. self,
  154. query_states,
  155. key_states,
  156. value_states,
  157. attention_mask,
  158. dropout=0.0 if not self.training else self.attention_dropout,
  159. scaling=self.scaling,
  160. **kwargs,
  161. )
  162. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  163. attn_output = self.dense(attn_output)
  164. return attn_output, attn_weights
  165. class PhiMLP(nn.Module):
  166. def __init__(self, config):
  167. super().__init__()
  168. self.config = config
  169. self.activation_fn = ACT2FN[config.hidden_act]
  170. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  171. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  172. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  173. hidden_states = self.fc1(hidden_states)
  174. hidden_states = self.activation_fn(hidden_states)
  175. hidden_states = self.fc2(hidden_states)
  176. return hidden_states
  177. class PhiDecoderLayer(GradientCheckpointingLayer):
  178. def __init__(self, config: PhiConfig, layer_idx: int):
  179. super().__init__()
  180. self.self_attn = PhiAttention(config, layer_idx=layer_idx)
  181. self.mlp = PhiMLP(config)
  182. self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  183. self.resid_dropout = nn.Dropout(config.resid_pdrop)
  184. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  185. def forward(
  186. self,
  187. hidden_states: torch.Tensor,
  188. attention_mask: Optional[torch.Tensor] = None,
  189. position_ids: Optional[torch.LongTensor] = None,
  190. past_key_values: Optional[Cache] = None,
  191. output_attentions: Optional[bool] = False,
  192. use_cache: Optional[bool] = False,
  193. cache_position: Optional[torch.LongTensor] = None,
  194. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
  195. **kwargs,
  196. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  197. residual = hidden_states
  198. hidden_states = self.input_layernorm(hidden_states)
  199. # Self Attention
  200. attn_outputs, self_attn_weights = self.self_attn(
  201. hidden_states=hidden_states,
  202. attention_mask=attention_mask,
  203. position_ids=position_ids,
  204. past_key_values=past_key_values,
  205. output_attentions=output_attentions,
  206. use_cache=use_cache,
  207. cache_position=cache_position,
  208. position_embeddings=position_embeddings,
  209. **kwargs,
  210. )
  211. attn_outputs = self.resid_dropout(attn_outputs)
  212. feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
  213. hidden_states = attn_outputs + feed_forward_hidden_states + residual
  214. outputs = (hidden_states,)
  215. if output_attentions:
  216. outputs += (self_attn_weights,)
  217. return outputs
  218. class PhiRotaryEmbedding(nn.Module):
  219. inv_freq: torch.Tensor # fix linting for `register_buffer`
  220. def __init__(self, config: PhiConfig, device=None):
  221. super().__init__()
  222. # BC: "rope_type" was originally "type"
  223. if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
  224. self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
  225. else:
  226. self.rope_type = "default"
  227. self.max_seq_len_cached = config.max_position_embeddings
  228. self.original_max_seq_len = config.max_position_embeddings
  229. self.config = config
  230. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  231. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
  232. self.register_buffer("inv_freq", inv_freq, persistent=False)
  233. self.original_inv_freq = self.inv_freq
  234. @torch.no_grad()
  235. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  236. def forward(self, x, position_ids):
  237. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  238. position_ids_expanded = position_ids[:, None, :].float()
  239. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  240. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  241. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  242. emb = torch.cat((freqs, freqs), dim=-1)
  243. cos = emb.cos() * self.attention_scaling
  244. sin = emb.sin() * self.attention_scaling
  245. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  246. @auto_docstring
  247. class PhiPreTrainedModel(PreTrainedModel):
  248. config: PhiConfig
  249. base_model_prefix = "model"
  250. supports_gradient_checkpointing = True
  251. _no_split_modules = ["PhiDecoderLayer"]
  252. _skip_keys_device_placement = ["past_key_values"]
  253. _supports_flash_attn = True
  254. _supports_sdpa = True
  255. _supports_flex_attn = True
  256. _can_compile_fullgraph = True
  257. _supports_attention_backend = True
  258. _can_record_outputs = {
  259. "hidden_states": PhiDecoderLayer,
  260. "attentions": PhiAttention,
  261. }
  262. @auto_docstring
  263. class PhiModel(PhiPreTrainedModel):
  264. def __init__(self, config: PhiConfig):
  265. super().__init__(config)
  266. self.padding_idx = config.pad_token_id
  267. self.vocab_size = config.vocab_size
  268. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  269. self.layers = nn.ModuleList(
  270. [PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  271. )
  272. self.rotary_emb = PhiRotaryEmbedding(config=config)
  273. self.gradient_checkpointing = False
  274. self.embed_dropout = nn.Dropout(config.embd_pdrop)
  275. self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  276. # Initialize weights and apply final processing
  277. self.post_init()
  278. @check_model_inputs()
  279. @auto_docstring
  280. def forward(
  281. self,
  282. input_ids: Optional[torch.LongTensor] = None,
  283. attention_mask: Optional[torch.Tensor] = None,
  284. position_ids: Optional[torch.LongTensor] = None,
  285. past_key_values: Optional[Cache] = None,
  286. inputs_embeds: Optional[torch.FloatTensor] = None,
  287. use_cache: Optional[bool] = None,
  288. output_attentions: Optional[bool] = None,
  289. output_hidden_states: Optional[bool] = None,
  290. cache_position: Optional[torch.LongTensor] = None,
  291. **kwargs: Unpack[TransformersKwargs],
  292. ) -> BaseModelOutputWithPast:
  293. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  294. output_hidden_states = (
  295. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  296. )
  297. use_cache = use_cache if use_cache is not None else self.config.use_cache
  298. if (input_ids is None) ^ (inputs_embeds is not None):
  299. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  300. if self.gradient_checkpointing and self.training and use_cache:
  301. logger.warning_once(
  302. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
  303. )
  304. use_cache = False
  305. if inputs_embeds is None:
  306. inputs_embeds = self.embed_tokens(input_ids)
  307. if use_cache and past_key_values is None:
  308. past_key_values = DynamicCache(config=self.config)
  309. if cache_position is None:
  310. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  311. cache_position = torch.arange(
  312. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  313. )
  314. if position_ids is None:
  315. position_ids = cache_position.unsqueeze(0)
  316. causal_mask = create_causal_mask(
  317. config=self.config,
  318. input_embeds=inputs_embeds,
  319. attention_mask=attention_mask,
  320. cache_position=cache_position,
  321. past_key_values=past_key_values,
  322. position_ids=position_ids,
  323. )
  324. inputs_embeds = self.embed_dropout(inputs_embeds) # diff with Llama
  325. hidden_states = inputs_embeds
  326. # create position embeddings to be shared across the decoder layers
  327. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  328. # decoder layers
  329. all_hidden_states = () if output_hidden_states else None
  330. all_self_attns = () if output_attentions else None
  331. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  332. if output_hidden_states:
  333. all_hidden_states += (hidden_states,)
  334. layer_outputs = decoder_layer(
  335. hidden_states,
  336. attention_mask=causal_mask,
  337. position_ids=position_ids,
  338. past_key_values=past_key_values,
  339. output_attentions=output_attentions,
  340. use_cache=use_cache,
  341. cache_position=cache_position,
  342. position_embeddings=position_embeddings,
  343. **kwargs,
  344. )
  345. hidden_states = layer_outputs[0]
  346. if output_attentions:
  347. all_self_attns += (layer_outputs[1],)
  348. hidden_states = self.final_layernorm(hidden_states) # diff with Llama
  349. # add hidden states from the last decoder layer
  350. if output_hidden_states:
  351. all_hidden_states += (hidden_states,)
  352. return BaseModelOutputWithPast(
  353. last_hidden_state=hidden_states,
  354. past_key_values=past_key_values if use_cache else None,
  355. hidden_states=all_hidden_states,
  356. attentions=all_self_attns,
  357. )
  358. @auto_docstring
  359. class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin):
  360. _tied_weights_keys = ["lm_head.weight"]
  361. _tp_plan = {"lm_head": "colwise_rep"}
  362. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  363. def __init__(self, config):
  364. super().__init__(config)
  365. self.model = PhiModel(config)
  366. self.vocab_size = config.vocab_size
  367. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
  368. # Initialize weights and apply final processing
  369. self.post_init()
  370. @can_return_tuple
  371. @auto_docstring
  372. def forward(
  373. self,
  374. input_ids: Optional[torch.LongTensor] = None,
  375. attention_mask: Optional[torch.Tensor] = None,
  376. position_ids: Optional[torch.LongTensor] = None,
  377. past_key_values: Optional[Cache] = None,
  378. inputs_embeds: Optional[torch.FloatTensor] = None,
  379. labels: Optional[torch.LongTensor] = None,
  380. use_cache: Optional[bool] = None,
  381. cache_position: Optional[torch.LongTensor] = None,
  382. logits_to_keep: Union[int, torch.Tensor] = 0,
  383. **kwargs: Unpack[TransformersKwargs],
  384. ) -> CausalLMOutputWithPast:
  385. r"""
  386. Example:
  387. ```python
  388. >>> from transformers import AutoTokenizer, PhiForCausalLM
  389. >>> model = PhiForCausalLM.from_pretrained("meta-phi/Phi-2-7b-hf")
  390. >>> tokenizer = AutoTokenizer.from_pretrained("meta-phi/Phi-2-7b-hf")
  391. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  392. >>> inputs = tokenizer(prompt, return_tensors="pt")
  393. >>> # Generate
  394. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  395. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  396. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  397. ```"""
  398. outputs: BaseModelOutputWithPast = self.model(
  399. input_ids=input_ids,
  400. attention_mask=attention_mask,
  401. position_ids=position_ids,
  402. past_key_values=past_key_values,
  403. inputs_embeds=inputs_embeds,
  404. use_cache=use_cache,
  405. cache_position=cache_position,
  406. **kwargs,
  407. )
  408. hidden_states = outputs.last_hidden_state
  409. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  410. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  411. logits = self.lm_head(hidden_states[:, slice_indices, :])
  412. loss = None
  413. if labels is not None:
  414. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  415. return CausalLMOutputWithPast(
  416. loss=loss,
  417. logits=logits,
  418. past_key_values=outputs.past_key_values,
  419. hidden_states=outputs.hidden_states,
  420. attentions=outputs.attentions,
  421. )
  422. class PhiForSequenceClassification(GenericForSequenceClassification, PhiPreTrainedModel):
  423. pass
  424. class PhiForTokenClassification(GenericForTokenClassification, PhiPreTrainedModel):
  425. pass
  426. __all__ = [
  427. "PhiPreTrainedModel",
  428. "PhiModel",
  429. "PhiForCausalLM",
  430. "PhiForSequenceClassification",
  431. "PhiForTokenClassification",
  432. ]