modeling_persimmon.py 33 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768
  1. # coding=utf-8
  2. # Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  5. # and OPT implementations in this library. It has been modified from its
  6. # original forms to accommodate minor architectural differences compared
  7. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. """PyTorch Persimmon model."""
  21. from typing import Callable, Optional, Union
  22. import torch
  23. from torch import nn
  24. from ...activations import ACT2FN
  25. from ...cache_utils import Cache, DynamicCache
  26. from ...generation import GenerationMixin
  27. from ...modeling_attn_mask_utils import AttentionMaskConverter
  28. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  29. from ...modeling_layers import (
  30. GenericForSequenceClassification,
  31. GenericForTokenClassification,
  32. GradientCheckpointingLayer,
  33. )
  34. from ...modeling_outputs import (
  35. BaseModelOutputWithPast,
  36. CausalLMOutputWithPast,
  37. )
  38. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  39. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  40. from ...processing_utils import Unpack
  41. from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
  42. from ...utils.deprecation import deprecate_kwarg
  43. from .configuration_persimmon import PersimmonConfig
  44. if is_torch_flex_attn_available():
  45. from torch.nn.attention.flex_attention import BlockMask
  46. from ...integrations.flex_attention import make_flex_block_causal_mask
  47. logger = logging.get_logger(__name__)
  48. # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Persimmon
  49. class PersimmonRotaryEmbedding(nn.Module):
  50. inv_freq: torch.Tensor # fix linting for `register_buffer`
  51. def __init__(self, config: PersimmonConfig, device=None):
  52. super().__init__()
  53. # BC: "rope_type" was originally "type"
  54. if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
  55. self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
  56. else:
  57. self.rope_type = "default"
  58. self.max_seq_len_cached = config.max_position_embeddings
  59. self.original_max_seq_len = config.max_position_embeddings
  60. self.config = config
  61. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  62. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
  63. self.register_buffer("inv_freq", inv_freq, persistent=False)
  64. self.original_inv_freq = self.inv_freq
  65. @torch.no_grad()
  66. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  67. def forward(self, x, position_ids):
  68. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  69. position_ids_expanded = position_ids[:, None, :].float()
  70. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  71. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  72. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  73. emb = torch.cat((freqs, freqs), dim=-1)
  74. cos = emb.cos() * self.attention_scaling
  75. sin = emb.sin() * self.attention_scaling
  76. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  77. # Copied from transformers.models.llama.modeling_llama.rotate_half
  78. def rotate_half(x):
  79. """Rotates half the hidden dims of the input."""
  80. x1 = x[..., : x.shape[-1] // 2]
  81. x2 = x[..., x.shape[-1] // 2 :]
  82. return torch.cat((-x2, x1), dim=-1)
  83. # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
  84. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  85. """Applies Rotary Position Embedding to the query and key tensors.
  86. Args:
  87. q (`torch.Tensor`): The query tensor.
  88. k (`torch.Tensor`): The key tensor.
  89. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  90. sin (`torch.Tensor`): The sine part of the rotary embedding.
  91. position_ids (`torch.Tensor`, *optional*):
  92. Deprecated and unused.
  93. unsqueeze_dim (`int`, *optional*, defaults to 1):
  94. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  95. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  96. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  97. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  98. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  99. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  100. Returns:
  101. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  102. """
  103. cos = cos.unsqueeze(unsqueeze_dim)
  104. sin = sin.unsqueeze(unsqueeze_dim)
  105. q_embed = (q * cos) + (rotate_half(q) * sin)
  106. k_embed = (k * cos) + (rotate_half(k) * sin)
  107. return q_embed, k_embed
  108. # Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXMLP with GPTNeoX->Persimmon
  109. class PersimmonMLP(nn.Module):
  110. def __init__(self, config):
  111. super().__init__()
  112. self.dense_h_to_4h = nn.Linear(config.hidden_size, config.intermediate_size)
  113. self.dense_4h_to_h = nn.Linear(config.intermediate_size, config.hidden_size)
  114. self.act = ACT2FN[config.hidden_act]
  115. def forward(self, hidden_states):
  116. hidden_states = self.dense_h_to_4h(hidden_states)
  117. hidden_states = self.act(hidden_states)
  118. hidden_states = self.dense_4h_to_h(hidden_states)
  119. return hidden_states
  120. def eager_attention_forward(
  121. module: nn.Module,
  122. query: torch.Tensor,
  123. key: torch.Tensor,
  124. value: torch.Tensor,
  125. attention_mask: Optional[torch.Tensor],
  126. scaling: float,
  127. dropout: float = 0.0,
  128. **kwargs,
  129. ):
  130. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  131. if attention_mask is not None:
  132. causal_mask = attention_mask[:, :, :, : key.shape[-2]]
  133. attn_weights = attn_weights + causal_mask
  134. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  135. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  136. attn_output = torch.matmul(attn_weights, value)
  137. attn_output = attn_output.transpose(1, 2).contiguous()
  138. return attn_output, attn_weights
  139. class PersimmonAttention(nn.Module):
  140. """Multi-headed attention from 'Attention Is All You Need' paper"""
  141. def __init__(self, config: PersimmonConfig, layer_idx: Optional[int] = None):
  142. super().__init__()
  143. self.config = config
  144. self.layer_idx = layer_idx
  145. if layer_idx is None:
  146. logger.warning_once(
  147. f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
  148. "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
  149. "when creating this class."
  150. )
  151. self.hidden_size = config.hidden_size
  152. self.num_heads = config.num_attention_heads
  153. self.head_dim = self.hidden_size // self.num_heads
  154. self.rope_theta = config.rope_theta
  155. self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor)
  156. self.is_causal = True
  157. if (self.head_dim * self.num_heads) != self.hidden_size:
  158. raise ValueError(
  159. f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
  160. f" and `num_heads`: {self.num_heads})."
  161. )
  162. self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
  163. self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True)
  164. self.qk_layernorm = config.qk_layernorm
  165. self.scaling = self.head_dim**-0.5
  166. if self.qk_layernorm:
  167. self.q_layernorm = nn.LayerNorm(
  168. config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
  169. )
  170. self.k_layernorm = nn.LayerNorm(
  171. config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
  172. )
  173. self.attention_dropout = nn.Dropout(config.attention_dropout)
  174. self.rotary_emb = PersimmonRotaryEmbedding(config=self.config)
  175. def _split_heads(self, fused_qkv: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  176. """
  177. Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
  178. storage as `fused_qkv`
  179. Args:
  180. fused_qkv (`torch.tensor`): [batch_size, seq_length, num_heads * 3 * head_dim]
  181. Returns:
  182. query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
  183. value: [batch_size, seq_length, num_heads, head_dim]
  184. """
  185. batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
  186. fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
  187. return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
  188. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  189. def forward(
  190. self,
  191. hidden_states: torch.Tensor,
  192. attention_mask: Optional[torch.Tensor] = None,
  193. position_ids: Optional[torch.LongTensor] = None,
  194. past_key_values: Optional[Cache] = None,
  195. output_attentions: bool = False,
  196. use_cache: bool = False,
  197. cache_position: Optional[torch.LongTensor] = None,
  198. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
  199. **kwargs: Unpack[FlashAttentionKwargs],
  200. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  201. bsz, q_len, _ = hidden_states.size()
  202. # [batch_size, seq_length, 3 x hidden_size]
  203. fused_qkv = self.query_key_value(hidden_states)
  204. # 3 x [batch_size, seq_length, num_heads, head_dim]
  205. (query_states, key_states, value_states) = self._split_heads(fused_qkv)
  206. if self.qk_layernorm:
  207. query_states = self.q_layernorm(query_states)
  208. key_states = self.k_layernorm(key_states)
  209. # [batch_size, num_heads, seq_length, head_dim] -> [batch_size, seq_length, num_heads, head_dim]
  210. query_states = query_states.transpose(1, 2)
  211. value_states = value_states.transpose(1, 2)
  212. key_states = key_states.transpose(1, 2)
  213. cos, sin = position_embeddings
  214. # Partial rotary embedding
  215. query_rot, query_pass = (
  216. query_states[..., : self.rotary_ndims],
  217. query_states[..., self.rotary_ndims :],
  218. )
  219. key_rot, key_pass = (
  220. key_states[..., : self.rotary_ndims],
  221. key_states[..., self.rotary_ndims :],
  222. )
  223. # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
  224. query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
  225. # [batch_size, seq_length, num_heads, head_dim]
  226. query_states = torch.cat((query_rot, query_pass), dim=-1)
  227. key_states = torch.cat((key_rot, key_pass), dim=-1)
  228. if past_key_values is not None:
  229. # Specific to RoPE models with partial rotation
  230. cache_kwargs = {
  231. "sin": sin,
  232. "cos": cos,
  233. "partial_rotation_size": self.rotary_ndims,
  234. "cache_position": cache_position,
  235. }
  236. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  237. attention_interface: Callable = eager_attention_forward
  238. if self.config._attn_implementation != "eager":
  239. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  240. attn_output, attn_weights = attention_interface(
  241. self,
  242. query_states,
  243. key_states,
  244. value_states,
  245. attention_mask,
  246. dropout=0.0 if not self.training else self.config.attention_dropout,
  247. scaling=self.scaling,
  248. **kwargs,
  249. )
  250. attn_output = attn_output.reshape(bsz, q_len, -1)
  251. attn_output = self.dense(attn_output)
  252. if not output_attentions:
  253. attn_weights = None
  254. return attn_output, attn_weights
  255. class PersimmonDecoderLayer(GradientCheckpointingLayer):
  256. def __init__(self, config: PersimmonConfig, layer_idx: int):
  257. super().__init__()
  258. self.hidden_size = config.hidden_size
  259. self.self_attn = PersimmonAttention(config=config, layer_idx=layer_idx)
  260. self.mlp = PersimmonMLP(config)
  261. self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  262. self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  263. self.dropout = nn.Dropout(config.hidden_dropout)
  264. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  265. def forward(
  266. self,
  267. hidden_states: torch.Tensor,
  268. attention_mask: Optional[torch.Tensor] = None,
  269. position_ids: Optional[torch.LongTensor] = None,
  270. past_key_values: Optional[Cache] = None,
  271. output_attentions: Optional[bool] = False,
  272. use_cache: Optional[bool] = False,
  273. cache_position: Optional[torch.LongTensor] = None,
  274. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
  275. **kwargs: Unpack[FlashAttentionKwargs],
  276. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  277. """
  278. Args:
  279. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  280. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
  281. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  282. position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
  283. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
  284. `[0, config.n_positions - 1]`.
  285. [What are position IDs?](../glossary#position-ids)
  286. past_key_values (`Cache`, *optional*):
  287. cached past key and value projection states
  288. output_attentions (`bool`, *optional*):
  289. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  290. returned tensors for more detail.
  291. use_cache (`bool`, *optional*):
  292. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  293. (see `past_key_values`).
  294. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  295. Indices depicting the position of the input sequence tokens in the sequence
  296. position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
  297. Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
  298. with `head_dim` being the embedding dimension of each attention head.
  299. """
  300. residual = hidden_states
  301. hidden_states = self.input_layernorm(hidden_states)
  302. # Self Attention
  303. hidden_states, self_attn_weights = self.self_attn(
  304. hidden_states=hidden_states,
  305. attention_mask=attention_mask,
  306. position_ids=position_ids,
  307. past_key_values=past_key_values,
  308. output_attentions=output_attentions,
  309. use_cache=use_cache,
  310. cache_position=cache_position,
  311. position_embeddings=position_embeddings,
  312. **kwargs,
  313. )
  314. hidden_states = residual + hidden_states
  315. # Fully Connected
  316. residual = hidden_states
  317. hidden_states = self.post_attention_layernorm(hidden_states)
  318. hidden_states = self.mlp(hidden_states)
  319. hidden_states = self.dropout(hidden_states)
  320. hidden_states = hidden_states + residual
  321. outputs = (hidden_states,)
  322. if output_attentions:
  323. outputs += (self_attn_weights,)
  324. return outputs
  325. @auto_docstring
  326. class PersimmonPreTrainedModel(PreTrainedModel):
  327. config: PersimmonConfig
  328. base_model_prefix = "model"
  329. supports_gradient_checkpointing = True
  330. _no_split_modules = ["PersimmonDecoderLayer"]
  331. _skip_keys_device_placement = "past_key_values"
  332. _can_compile_fullgraph = True
  333. _supports_sdpa = True
  334. _supports_flash_attn = True
  335. _supports_attention_backend = True
  336. def _init_weights(self, module):
  337. std = self.config.initializer_range
  338. if isinstance(module, nn.Linear):
  339. module.weight.data.normal_(mean=0.0, std=std)
  340. if module.bias is not None:
  341. module.bias.data.zero_()
  342. elif isinstance(module, nn.Embedding):
  343. module.weight.data.normal_(mean=0.0, std=std)
  344. if module.padding_idx is not None:
  345. module.weight.data[module.padding_idx].zero_()
  346. elif isinstance(module, nn.LayerNorm):
  347. module.weight.data.fill_(1.0)
  348. module.bias.data.zero_()
  349. @auto_docstring
  350. class PersimmonModel(PersimmonPreTrainedModel):
  351. """
  352. Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PersimmonDecoderLayer`]
  353. Args:
  354. config: PersimmonConfig
  355. """
  356. def __init__(self, config: PersimmonConfig):
  357. super().__init__(config)
  358. self.padding_idx = config.pad_token_id
  359. self.vocab_size = config.vocab_size
  360. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  361. self.layers = nn.ModuleList(
  362. [PersimmonDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  363. )
  364. self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  365. self.rotary_emb = PersimmonRotaryEmbedding(config=config)
  366. self.gradient_checkpointing = False
  367. # Initialize weights and apply final processing
  368. self.post_init()
  369. @can_return_tuple
  370. @auto_docstring
  371. def forward(
  372. self,
  373. input_ids: Optional[torch.LongTensor] = None,
  374. attention_mask: Optional[torch.Tensor] = None,
  375. position_ids: Optional[torch.LongTensor] = None,
  376. past_key_values: Optional[Cache] = None,
  377. inputs_embeds: Optional[torch.FloatTensor] = None,
  378. use_cache: Optional[bool] = None,
  379. output_attentions: Optional[bool] = None,
  380. output_hidden_states: Optional[bool] = None,
  381. cache_position: Optional[torch.LongTensor] = None,
  382. **kwargs: Unpack[FlashAttentionKwargs],
  383. ) -> BaseModelOutputWithPast:
  384. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  385. output_hidden_states = (
  386. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  387. )
  388. use_cache = use_cache if use_cache is not None else self.config.use_cache
  389. if (input_ids is None) ^ (inputs_embeds is not None):
  390. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  391. if self.gradient_checkpointing and self.training:
  392. if use_cache:
  393. logger.warning_once(
  394. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  395. )
  396. use_cache = False
  397. if use_cache and past_key_values is None:
  398. past_key_values = DynamicCache(config=self.config)
  399. if inputs_embeds is None:
  400. inputs_embeds = self.embed_tokens(input_ids)
  401. if cache_position is None:
  402. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  403. cache_position = torch.arange(
  404. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  405. )
  406. if position_ids is None:
  407. position_ids = cache_position.unsqueeze(0)
  408. causal_mask = self._update_causal_mask(
  409. attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
  410. )
  411. hidden_states = inputs_embeds
  412. # create position embeddings to be shared across the decoder layers
  413. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  414. # decoder layers
  415. all_hidden_states = () if output_hidden_states else None
  416. all_self_attns = () if output_attentions else None
  417. for decoder_layer in self.layers:
  418. if output_hidden_states:
  419. all_hidden_states += (hidden_states,)
  420. layer_outputs = decoder_layer(
  421. hidden_states,
  422. attention_mask=causal_mask,
  423. position_ids=position_ids,
  424. past_key_values=past_key_values,
  425. output_attentions=output_attentions,
  426. use_cache=use_cache,
  427. cache_position=cache_position,
  428. position_embeddings=position_embeddings,
  429. **kwargs,
  430. )
  431. hidden_states = layer_outputs[0]
  432. if output_attentions:
  433. all_self_attns += (layer_outputs[1],)
  434. hidden_states = self.final_layernorm(hidden_states)
  435. # add hidden states from the last decoder layer
  436. if output_hidden_states:
  437. all_hidden_states += (hidden_states,)
  438. return BaseModelOutputWithPast(
  439. last_hidden_state=hidden_states,
  440. past_key_values=past_key_values,
  441. hidden_states=all_hidden_states,
  442. attentions=all_self_attns,
  443. )
  444. # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
  445. def _update_causal_mask(
  446. self,
  447. attention_mask: Union[torch.Tensor, "BlockMask"],
  448. input_tensor: torch.Tensor,
  449. cache_position: torch.Tensor,
  450. past_key_values: Cache,
  451. output_attentions: bool = False,
  452. ):
  453. if self.config._attn_implementation == "flash_attention_2":
  454. if attention_mask is not None and (attention_mask == 0.0).any():
  455. return attention_mask
  456. return None
  457. if self.config._attn_implementation == "flex_attention":
  458. if isinstance(attention_mask, torch.Tensor):
  459. attention_mask = make_flex_block_causal_mask(attention_mask)
  460. return attention_mask
  461. # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
  462. # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
  463. # to infer the attention mask.
  464. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  465. using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
  466. # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
  467. if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
  468. if AttentionMaskConverter._ignore_causal_mask_sdpa(
  469. attention_mask,
  470. inputs_embeds=input_tensor,
  471. past_key_values_length=past_seen_tokens,
  472. is_training=self.training,
  473. ):
  474. return None
  475. dtype = input_tensor.dtype
  476. sequence_length = input_tensor.shape[1]
  477. if using_compilable_cache:
  478. target_length = past_key_values.get_max_cache_shape()
  479. else:
  480. target_length = (
  481. attention_mask.shape[-1]
  482. if isinstance(attention_mask, torch.Tensor)
  483. else past_seen_tokens + sequence_length + 1
  484. )
  485. # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
  486. causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
  487. attention_mask,
  488. sequence_length=sequence_length,
  489. target_length=target_length,
  490. dtype=dtype,
  491. cache_position=cache_position,
  492. batch_size=input_tensor.shape[0],
  493. )
  494. if (
  495. self.config._attn_implementation == "sdpa"
  496. and attention_mask is not None
  497. and attention_mask.device.type in ["cuda", "xpu", "npu"]
  498. and not output_attentions
  499. ):
  500. # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
  501. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
  502. # Details: https://github.com/pytorch/pytorch/issues/110213
  503. min_dtype = torch.finfo(dtype).min
  504. causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
  505. return causal_mask
  506. @staticmethod
  507. # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
  508. def _prepare_4d_causal_attention_mask_with_cache_position(
  509. attention_mask: torch.Tensor,
  510. sequence_length: int,
  511. target_length: int,
  512. dtype: torch.dtype,
  513. cache_position: torch.Tensor,
  514. batch_size: int,
  515. **kwargs,
  516. ):
  517. """
  518. Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
  519. `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
  520. Args:
  521. attention_mask (`torch.Tensor`):
  522. A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
  523. `(batch_size, 1, query_length, key_value_length)`.
  524. sequence_length (`int`):
  525. The sequence length being processed.
  526. target_length (`int`):
  527. The target length: when generating with static cache, the mask should be as long as the static cache,
  528. to account for the 0 padding, the part of the cache that is not filled yet.
  529. dtype (`torch.dtype`):
  530. The dtype to use for the 4D attention mask.
  531. cache_position (`torch.Tensor`):
  532. Indices depicting the position of the input sequence tokens in the sequence.
  533. batch_size (`torch.Tensor`):
  534. Batch size.
  535. """
  536. if attention_mask is not None and attention_mask.dim() == 4:
  537. # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
  538. causal_mask = attention_mask
  539. else:
  540. min_dtype = torch.finfo(dtype).min
  541. causal_mask = torch.full(
  542. (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
  543. )
  544. if sequence_length != 1:
  545. causal_mask = torch.triu(causal_mask, diagonal=1)
  546. causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
  547. causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
  548. if attention_mask is not None:
  549. causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
  550. mask_length = attention_mask.shape[-1]
  551. padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
  552. causal_mask.device
  553. )
  554. padding_mask = padding_mask == 0
  555. causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
  556. padding_mask, min_dtype
  557. )
  558. return causal_mask
  559. class PersimmonForCausalLM(PersimmonPreTrainedModel, GenerationMixin):
  560. _tied_weights_keys = ["lm_head.weight"]
  561. # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with LLAMA->PERSIMMON,Llama->Persimmon
  562. def __init__(self, config):
  563. super().__init__(config)
  564. self.model = PersimmonModel(config)
  565. self.vocab_size = config.vocab_size
  566. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  567. # Initialize weights and apply final processing
  568. self.post_init()
  569. @can_return_tuple
  570. @auto_docstring
  571. def forward(
  572. self,
  573. input_ids: Optional[torch.LongTensor] = None,
  574. attention_mask: Optional[torch.Tensor] = None,
  575. position_ids: Optional[torch.LongTensor] = None,
  576. past_key_values: Optional[Cache] = None,
  577. inputs_embeds: Optional[torch.FloatTensor] = None,
  578. labels: Optional[torch.LongTensor] = None,
  579. use_cache: Optional[bool] = None,
  580. output_attentions: Optional[bool] = None,
  581. output_hidden_states: Optional[bool] = None,
  582. cache_position: Optional[torch.LongTensor] = None,
  583. logits_to_keep: Union[int, torch.Tensor] = 0,
  584. **kwargs,
  585. ) -> CausalLMOutputWithPast:
  586. r"""
  587. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  588. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  589. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  590. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  591. Example:
  592. ```python
  593. >>> from transformers import AutoTokenizer, PersimmonForCausalLM
  594. >>> model = PersimmonForCausalLM.from_pretrained("adept/persimmon-8b-base")
  595. >>> tokenizer = AutoTokenizer.from_pretrained("adept/persimmon-8b-base")
  596. >>> prompt = "human: Hey, what should I eat for dinner?"
  597. >>> inputs = tokenizer(prompt, return_tensors="pt")
  598. >>> # Generate
  599. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  600. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  601. 'human: Hey, what should I eat for dinner?\n\ncat: 🐱\n\nhuman: 😐\n\n'
  602. ```"""
  603. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  604. output_hidden_states = (
  605. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  606. )
  607. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  608. outputs: BaseModelOutputWithPast = self.model(
  609. input_ids=input_ids,
  610. attention_mask=attention_mask,
  611. position_ids=position_ids,
  612. past_key_values=past_key_values,
  613. inputs_embeds=inputs_embeds,
  614. use_cache=use_cache,
  615. output_attentions=output_attentions,
  616. output_hidden_states=output_hidden_states,
  617. cache_position=cache_position,
  618. **kwargs,
  619. )
  620. hidden_states = outputs.last_hidden_state
  621. # No upscaling to float was ever done for Persimmon
  622. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  623. logits = self.lm_head(hidden_states[:, slice_indices, :])
  624. loss = None
  625. if labels is not None:
  626. loss = self.loss_function(
  627. logits,
  628. labels,
  629. vocab_size=self.config.vocab_size,
  630. **kwargs,
  631. )
  632. return CausalLMOutputWithPast(
  633. loss=loss,
  634. logits=logits,
  635. past_key_values=outputs.past_key_values,
  636. hidden_states=outputs.hidden_states,
  637. attentions=outputs.attentions,
  638. )
  639. class PersimmonForSequenceClassification(GenericForSequenceClassification, PersimmonPreTrainedModel): ...
  640. class PersimmonForTokenClassification(GenericForTokenClassification, PersimmonPreTrainedModel): ...
  641. __all__ = [
  642. "PersimmonForCausalLM",
  643. "PersimmonModel",
  644. "PersimmonPreTrainedModel",
  645. "PersimmonForSequenceClassification",
  646. "PersimmonForTokenClassification",
  647. ]