modeling_granite.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/granite/modular_granite.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_granite.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
  9. #
  10. #
  11. # Licensed under the Apache License, Version 2.0 (the "License");
  12. # you may not use this file except in compliance with the License.
  13. # You may obtain a copy of the License at
  14. #
  15. # http://www.apache.org/licenses/LICENSE-2.0
  16. #
  17. # Unless required by applicable law or agreed to in writing, software
  18. # distributed under the License is distributed on an "AS IS" BASIS,
  19. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  20. # See the License for the specific language governing permissions and
  21. # limitations under the License.
  22. from typing import Callable, Optional, Union
  23. import torch
  24. from torch import nn
  25. from ...activations import ACT2FN
  26. from ...cache_utils import Cache, DynamicCache
  27. from ...generation import GenerationMixin
  28. from ...integrations import use_kernel_forward_from_hub
  29. from ...masking_utils import create_causal_mask
  30. from ...modeling_layers import GradientCheckpointingLayer
  31. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  32. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  33. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  34. from ...processing_utils import Unpack
  35. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  36. from ...utils.deprecation import deprecate_kwarg
  37. from ...utils.generic import check_model_inputs
  38. from .configuration_granite import GraniteConfig
  39. logger = logging.get_logger(__name__)
  40. def rotate_half(x):
  41. """Rotates half the hidden dims of the input."""
  42. x1 = x[..., : x.shape[-1] // 2]
  43. x2 = x[..., x.shape[-1] // 2 :]
  44. return torch.cat((-x2, x1), dim=-1)
  45. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  46. """Applies Rotary Position Embedding to the query and key tensors.
  47. Args:
  48. q (`torch.Tensor`): The query tensor.
  49. k (`torch.Tensor`): The key tensor.
  50. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  51. sin (`torch.Tensor`): The sine part of the rotary embedding.
  52. position_ids (`torch.Tensor`, *optional*):
  53. Deprecated and unused.
  54. unsqueeze_dim (`int`, *optional*, defaults to 1):
  55. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  56. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  57. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  58. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  59. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  60. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  61. Returns:
  62. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  63. """
  64. cos = cos.unsqueeze(unsqueeze_dim)
  65. sin = sin.unsqueeze(unsqueeze_dim)
  66. q_embed = (q * cos) + (rotate_half(q) * sin)
  67. k_embed = (k * cos) + (rotate_half(k) * sin)
  68. return q_embed, k_embed
  69. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  70. """
  71. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  72. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  73. """
  74. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  75. if n_rep == 1:
  76. return hidden_states
  77. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  78. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  79. def eager_attention_forward(
  80. module: nn.Module,
  81. query: torch.Tensor,
  82. key: torch.Tensor,
  83. value: torch.Tensor,
  84. attention_mask: Optional[torch.Tensor],
  85. scaling: float,
  86. dropout: float = 0.0,
  87. **kwargs: Unpack[TransformersKwargs],
  88. ):
  89. key_states = repeat_kv(key, module.num_key_value_groups)
  90. value_states = repeat_kv(value, module.num_key_value_groups)
  91. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  92. if attention_mask is not None:
  93. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  94. attn_weights = attn_weights + causal_mask
  95. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  96. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  97. attn_output = torch.matmul(attn_weights, value_states)
  98. attn_output = attn_output.transpose(1, 2).contiguous()
  99. return attn_output, attn_weights
  100. class GraniteAttention(nn.Module):
  101. """Multi-headed attention from 'Attention Is All You Need' paper"""
  102. def __init__(self, config: GraniteConfig, layer_idx: Optional[int] = None):
  103. super().__init__()
  104. self.config = config
  105. self.layer_idx = layer_idx
  106. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  107. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  108. self.scaling = config.attention_multiplier
  109. self.attention_dropout = config.attention_dropout
  110. self.is_causal = True
  111. self.q_proj = nn.Linear(
  112. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  113. )
  114. self.k_proj = nn.Linear(
  115. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  116. )
  117. self.v_proj = nn.Linear(
  118. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  119. )
  120. self.o_proj = nn.Linear(
  121. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  122. )
  123. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  124. def forward(
  125. self,
  126. hidden_states: torch.Tensor,
  127. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  128. attention_mask: Optional[torch.Tensor],
  129. past_key_values: Optional[Cache] = None,
  130. cache_position: Optional[torch.LongTensor] = None,
  131. **kwargs: Unpack[TransformersKwargs],
  132. ) -> tuple[torch.Tensor, torch.Tensor]:
  133. input_shape = hidden_states.shape[:-1]
  134. hidden_shape = (*input_shape, -1, self.head_dim)
  135. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  136. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  137. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  138. cos, sin = position_embeddings
  139. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  140. if past_key_values is not None:
  141. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  142. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  143. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  144. attention_interface: Callable = eager_attention_forward
  145. if self.config._attn_implementation != "eager":
  146. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  147. attn_output, attn_weights = attention_interface(
  148. self,
  149. query_states,
  150. key_states,
  151. value_states,
  152. attention_mask,
  153. dropout=0.0 if not self.training else self.attention_dropout,
  154. scaling=self.scaling,
  155. **kwargs,
  156. )
  157. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  158. attn_output = self.o_proj(attn_output)
  159. return attn_output, attn_weights
  160. @use_kernel_forward_from_hub("RMSNorm")
  161. class GraniteRMSNorm(nn.Module):
  162. def __init__(self, hidden_size, eps=1e-6):
  163. """
  164. GraniteRMSNorm is equivalent to T5LayerNorm
  165. """
  166. super().__init__()
  167. self.weight = nn.Parameter(torch.ones(hidden_size))
  168. self.variance_epsilon = eps
  169. def forward(self, hidden_states):
  170. input_dtype = hidden_states.dtype
  171. hidden_states = hidden_states.to(torch.float32)
  172. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  173. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  174. return self.weight * hidden_states.to(input_dtype)
  175. def extra_repr(self):
  176. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  177. class GraniteMLP(nn.Module):
  178. def __init__(self, config):
  179. super().__init__()
  180. self.config = config
  181. self.hidden_size = config.hidden_size
  182. self.intermediate_size = config.intermediate_size
  183. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  184. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  185. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
  186. self.act_fn = ACT2FN[config.hidden_act]
  187. def forward(self, x):
  188. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  189. return down_proj
  190. class GraniteDecoderLayer(GradientCheckpointingLayer):
  191. def __init__(self, config: GraniteConfig, layer_idx: int):
  192. super().__init__()
  193. self.hidden_size = config.hidden_size
  194. self.self_attn = GraniteAttention(config=config, layer_idx=layer_idx)
  195. self.mlp = GraniteMLP(config)
  196. self.input_layernorm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  197. self.post_attention_layernorm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  198. self.residual_multiplier = config.residual_multiplier
  199. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  200. def forward(
  201. self,
  202. hidden_states: torch.Tensor,
  203. attention_mask: Optional[torch.Tensor] = None,
  204. position_ids: Optional[torch.LongTensor] = None,
  205. past_key_values: Optional[Cache] = None,
  206. output_attentions: Optional[bool] = False,
  207. use_cache: Optional[bool] = False,
  208. cache_position: Optional[torch.LongTensor] = None,
  209. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
  210. **kwargs,
  211. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  212. """
  213. Args:
  214. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  215. attention_mask (`torch.FloatTensor`, *optional*):
  216. attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
  217. query_sequence_length, key_sequence_length)` if default attention is used.
  218. output_attentions (`bool`, *optional*):
  219. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  220. returned tensors for more detail.
  221. use_cache (`bool`, *optional*):
  222. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  223. (see `past_key_values`).
  224. past_key_values (`Cache`, *optional*): cached past key and value projection states
  225. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  226. Indices depicting the position of the input sequence tokens in the sequence
  227. position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
  228. Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
  229. with `head_dim` being the embedding dimension of each attention head.
  230. kwargs (`dict`, *optional*):
  231. Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
  232. into the model
  233. """
  234. residual = hidden_states
  235. hidden_states = self.input_layernorm(hidden_states)
  236. # Self Attention
  237. hidden_states, self_attn_weights = self.self_attn(
  238. hidden_states=hidden_states,
  239. attention_mask=attention_mask,
  240. position_ids=position_ids,
  241. past_key_values=past_key_values,
  242. output_attentions=output_attentions,
  243. use_cache=use_cache,
  244. cache_position=cache_position,
  245. position_embeddings=position_embeddings,
  246. **kwargs,
  247. )
  248. hidden_states = residual + hidden_states * self.residual_multiplier
  249. # Fully Connected
  250. residual = hidden_states
  251. hidden_states = self.post_attention_layernorm(hidden_states)
  252. hidden_states = self.mlp(hidden_states)
  253. hidden_states = residual + hidden_states * self.residual_multiplier # main diff with Llama
  254. outputs = (hidden_states,)
  255. if output_attentions:
  256. outputs += (self_attn_weights,)
  257. return outputs
  258. @auto_docstring
  259. class GranitePreTrainedModel(PreTrainedModel):
  260. config: GraniteConfig
  261. base_model_prefix = "model"
  262. supports_gradient_checkpointing = True
  263. _no_split_modules = ["GraniteDecoderLayer"]
  264. _skip_keys_device_placement = ["past_key_values"]
  265. _supports_flash_attn = True
  266. _supports_sdpa = True
  267. _supports_flex_attn = True
  268. _can_compile_fullgraph = True
  269. _supports_attention_backend = True
  270. _can_record_outputs = {
  271. "hidden_states": GraniteDecoderLayer,
  272. "attentions": GraniteAttention,
  273. }
  274. class GraniteRotaryEmbedding(nn.Module):
  275. inv_freq: torch.Tensor # fix linting for `register_buffer`
  276. def __init__(self, config: GraniteConfig, device=None):
  277. super().__init__()
  278. # BC: "rope_type" was originally "type"
  279. if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
  280. self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
  281. else:
  282. self.rope_type = "default"
  283. self.max_seq_len_cached = config.max_position_embeddings
  284. self.original_max_seq_len = config.max_position_embeddings
  285. self.config = config
  286. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  287. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
  288. self.register_buffer("inv_freq", inv_freq, persistent=False)
  289. self.original_inv_freq = self.inv_freq
  290. @torch.no_grad()
  291. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  292. def forward(self, x, position_ids):
  293. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  294. position_ids_expanded = position_ids[:, None, :].float()
  295. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  296. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  297. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  298. emb = torch.cat((freqs, freqs), dim=-1)
  299. cos = emb.cos() * self.attention_scaling
  300. sin = emb.sin() * self.attention_scaling
  301. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  302. @auto_docstring
  303. class GraniteModel(GranitePreTrainedModel):
  304. def __init__(self, config: GraniteConfig):
  305. super().__init__(config)
  306. self.padding_idx = config.pad_token_id
  307. self.vocab_size = config.vocab_size
  308. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  309. self.layers = nn.ModuleList(
  310. [GraniteDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  311. )
  312. self.norm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  313. self.rotary_emb = GraniteRotaryEmbedding(config=config)
  314. self.gradient_checkpointing = False
  315. self.embedding_multiplier = config.embedding_multiplier
  316. # Initialize weights and apply final processing
  317. self.post_init()
  318. @check_model_inputs()
  319. @auto_docstring
  320. def forward(
  321. self,
  322. input_ids: Optional[torch.LongTensor] = None,
  323. attention_mask: Optional[torch.Tensor] = None,
  324. position_ids: Optional[torch.LongTensor] = None,
  325. past_key_values: Optional[Cache] = None,
  326. inputs_embeds: Optional[torch.FloatTensor] = None,
  327. use_cache: Optional[bool] = None,
  328. output_attentions: Optional[bool] = None,
  329. output_hidden_states: Optional[bool] = None,
  330. cache_position: Optional[torch.LongTensor] = None,
  331. **kwargs: Unpack[TransformersKwargs],
  332. ) -> BaseModelOutputWithPast:
  333. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  334. output_hidden_states = (
  335. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  336. )
  337. use_cache = use_cache if use_cache is not None else self.config.use_cache
  338. if (input_ids is None) ^ (inputs_embeds is not None):
  339. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  340. if self.gradient_checkpointing and self.training and use_cache:
  341. logger.warning_once(
  342. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
  343. )
  344. use_cache = False
  345. if inputs_embeds is None:
  346. inputs_embeds = self.embed_tokens(input_ids)
  347. inputs_embeds = inputs_embeds * self.embedding_multiplier # main diff with Llama
  348. if use_cache and past_key_values is None:
  349. past_key_values = DynamicCache(config=self.config)
  350. if cache_position is None:
  351. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  352. cache_position = torch.arange(
  353. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  354. )
  355. if position_ids is None:
  356. position_ids = cache_position.unsqueeze(0)
  357. causal_mask = create_causal_mask(
  358. config=self.config,
  359. input_embeds=inputs_embeds,
  360. attention_mask=attention_mask,
  361. cache_position=cache_position,
  362. past_key_values=past_key_values,
  363. position_ids=position_ids,
  364. )
  365. hidden_states = inputs_embeds
  366. # create position embeddings to be shared across the decoder layers
  367. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  368. # decoder layers
  369. all_hidden_states = () if output_hidden_states else None
  370. all_self_attns = () if output_attentions else None
  371. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  372. if output_hidden_states:
  373. all_hidden_states += (hidden_states,)
  374. layer_outputs = decoder_layer(
  375. hidden_states,
  376. attention_mask=causal_mask,
  377. position_ids=position_ids,
  378. past_key_values=past_key_values,
  379. output_attentions=output_attentions,
  380. use_cache=use_cache,
  381. cache_position=cache_position,
  382. position_embeddings=position_embeddings,
  383. **kwargs,
  384. )
  385. hidden_states = layer_outputs[0]
  386. if output_attentions:
  387. all_self_attns += (layer_outputs[1],)
  388. hidden_states = self.norm(hidden_states)
  389. # add hidden states from the last decoder layer
  390. if output_hidden_states:
  391. all_hidden_states += (hidden_states,)
  392. return BaseModelOutputWithPast(
  393. last_hidden_state=hidden_states,
  394. past_key_values=past_key_values if use_cache else None,
  395. hidden_states=all_hidden_states,
  396. attentions=all_self_attns,
  397. )
  398. @auto_docstring
  399. class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin):
  400. _tied_weights_keys = ["lm_head.weight"]
  401. _tp_plan = {"lm_head": "colwise_rep"}
  402. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  403. def __init__(self, config):
  404. super().__init__(config)
  405. self.model = GraniteModel(config)
  406. self.vocab_size = config.vocab_size
  407. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  408. # Initialize weights and apply final processing
  409. self.post_init()
  410. @can_return_tuple
  411. @auto_docstring
  412. def forward(
  413. self,
  414. input_ids: Optional[torch.LongTensor] = None,
  415. attention_mask: Optional[torch.Tensor] = None,
  416. position_ids: Optional[torch.LongTensor] = None,
  417. past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
  418. inputs_embeds: Optional[torch.FloatTensor] = None,
  419. labels: Optional[torch.LongTensor] = None,
  420. use_cache: Optional[bool] = None,
  421. output_attentions: Optional[bool] = None,
  422. output_hidden_states: Optional[bool] = None,
  423. cache_position: Optional[torch.LongTensor] = None,
  424. logits_to_keep: Union[int, torch.Tensor] = 0,
  425. **kwargs: Unpack[TransformersKwargs],
  426. ) -> CausalLMOutputWithPast:
  427. r"""
  428. Example:
  429. ```python
  430. >>> from transformers import AutoTokenizer, GraniteForCausalLM
  431. >>> model = GraniteForCausalLM.from_pretrained("meta-granite/Granite-2-7b-hf")
  432. >>> tokenizer = AutoTokenizer.from_pretrained("meta-granite/Granite-2-7b-hf")
  433. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  434. >>> inputs = tokenizer(prompt, return_tensors="pt")
  435. >>> # Generate
  436. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  437. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  438. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  439. ```"""
  440. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  441. output_hidden_states = (
  442. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  443. )
  444. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  445. outputs: BaseModelOutputWithPast = self.model(
  446. input_ids=input_ids,
  447. attention_mask=attention_mask,
  448. position_ids=position_ids,
  449. past_key_values=past_key_values,
  450. inputs_embeds=inputs_embeds,
  451. use_cache=use_cache,
  452. output_attentions=output_attentions,
  453. output_hidden_states=output_hidden_states,
  454. cache_position=cache_position,
  455. **kwargs,
  456. )
  457. hidden_states = outputs.last_hidden_state
  458. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  459. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  460. logits = self.lm_head(hidden_states[:, slice_indices, :])
  461. logits = logits / self.config.logits_scaling # main diff with Llama
  462. loss = None
  463. if labels is not None:
  464. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  465. return CausalLMOutputWithPast(
  466. loss=loss,
  467. logits=logits,
  468. past_key_values=outputs.past_key_values,
  469. hidden_states=outputs.hidden_states,
  470. attentions=outputs.attentions,
  471. )
  472. __all__ = ["GraniteForCausalLM", "GraniteModel", "GranitePreTrainedModel"]