modeling_exaone4.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/exaone4/modular_exaone4.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_exaone4.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2025 The LG AI Research and HuggingFace Inc. team. All rights reserved.
  9. #
  10. #
  11. # Licensed under the Apache License, Version 2.0 (the "License");
  12. # you may not use this file except in compliance with the License.
  13. # You may obtain a copy of the License at
  14. #
  15. # http://www.apache.org/licenses/LICENSE-2.0
  16. #
  17. # Unless required by applicable law or agreed to in writing, software
  18. # distributed under the License is distributed on an "AS IS" BASIS,
  19. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  20. # See the License for the specific language governing permissions and
  21. # limitations under the License.
  22. from typing import Callable, Optional, Union
  23. import torch
  24. from torch import nn
  25. from transformers.utils.generic import check_model_inputs
  26. from ...activations import ACT2FN
  27. from ...cache_utils import Cache, DynamicCache
  28. from ...generation import GenerationMixin
  29. from ...integrations import use_kernel_forward_from_hub
  30. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  31. from ...modeling_layers import (
  32. GenericForQuestionAnswering,
  33. GenericForSequenceClassification,
  34. GenericForTokenClassification,
  35. GradientCheckpointingLayer,
  36. )
  37. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  38. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  39. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  40. from ...processing_utils import Unpack
  41. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  42. from ...utils.deprecation import deprecate_kwarg
  43. from .configuration_exaone4 import Exaone4Config
  44. @use_kernel_forward_from_hub("RMSNorm")
  45. class Exaone4RMSNorm(nn.Module):
  46. def __init__(self, hidden_size, eps=1e-6):
  47. """
  48. Exaone4RMSNorm is equivalent to T5LayerNorm
  49. """
  50. super().__init__()
  51. self.weight = nn.Parameter(torch.ones(hidden_size))
  52. self.variance_epsilon = eps
  53. def forward(self, hidden_states):
  54. input_dtype = hidden_states.dtype
  55. hidden_states = hidden_states.to(torch.float32)
  56. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  57. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  58. return self.weight * hidden_states.to(input_dtype)
  59. def extra_repr(self):
  60. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  61. class Exaone4RotaryEmbedding(nn.Module):
  62. inv_freq: torch.Tensor # fix linting for `register_buffer`
  63. def __init__(self, config: Exaone4Config, device=None):
  64. super().__init__()
  65. # BC: "rope_type" was originally "type"
  66. if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
  67. self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
  68. else:
  69. self.rope_type = "default"
  70. self.max_seq_len_cached = config.max_position_embeddings
  71. self.original_max_seq_len = config.max_position_embeddings
  72. self.config = config
  73. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  74. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
  75. self.register_buffer("inv_freq", inv_freq, persistent=False)
  76. self.original_inv_freq = self.inv_freq
  77. @torch.no_grad()
  78. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  79. def forward(self, x, position_ids):
  80. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  81. position_ids_expanded = position_ids[:, None, :].float()
  82. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  83. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  84. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  85. emb = torch.cat((freqs, freqs), dim=-1)
  86. cos = emb.cos() * self.attention_scaling
  87. sin = emb.sin() * self.attention_scaling
  88. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  89. def rotate_half(x):
  90. """Rotates half the hidden dims of the input."""
  91. x1 = x[..., : x.shape[-1] // 2]
  92. x2 = x[..., x.shape[-1] // 2 :]
  93. return torch.cat((-x2, x1), dim=-1)
  94. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  95. """Applies Rotary Position Embedding to the query and key tensors.
  96. Args:
  97. q (`torch.Tensor`): The query tensor.
  98. k (`torch.Tensor`): The key tensor.
  99. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  100. sin (`torch.Tensor`): The sine part of the rotary embedding.
  101. position_ids (`torch.Tensor`, *optional*):
  102. Deprecated and unused.
  103. unsqueeze_dim (`int`, *optional*, defaults to 1):
  104. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  105. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  106. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  107. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  108. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  109. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  110. Returns:
  111. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  112. """
  113. cos = cos.unsqueeze(unsqueeze_dim)
  114. sin = sin.unsqueeze(unsqueeze_dim)
  115. q_embed = (q * cos) + (rotate_half(q) * sin)
  116. k_embed = (k * cos) + (rotate_half(k) * sin)
  117. return q_embed, k_embed
  118. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  119. """
  120. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  121. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  122. """
  123. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  124. if n_rep == 1:
  125. return hidden_states
  126. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  127. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  128. def eager_attention_forward(
  129. module: nn.Module,
  130. query: torch.Tensor,
  131. key: torch.Tensor,
  132. value: torch.Tensor,
  133. attention_mask: Optional[torch.Tensor],
  134. scaling: float,
  135. dropout: float = 0.0,
  136. **kwargs: Unpack[TransformersKwargs],
  137. ):
  138. key_states = repeat_kv(key, module.num_key_value_groups)
  139. value_states = repeat_kv(value, module.num_key_value_groups)
  140. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  141. if attention_mask is not None:
  142. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  143. attn_weights = attn_weights + causal_mask
  144. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  145. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  146. attn_output = torch.matmul(attn_weights, value_states)
  147. attn_output = attn_output.transpose(1, 2).contiguous()
  148. return attn_output, attn_weights
  149. class Exaone4Attention(nn.Module):
  150. def __init__(self, config: Exaone4Config, layer_idx: int):
  151. super().__init__()
  152. self.config = config
  153. self.layer_idx = layer_idx
  154. self.num_attention_heads = config.num_attention_heads
  155. self.num_key_value_heads = config.num_key_value_heads
  156. self.hidden_size = config.hidden_size
  157. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  158. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  159. self.attention_dropout = config.attention_dropout
  160. self.is_causal = True
  161. self.scaling = self.head_dim**-0.5
  162. self.sliding_window = config.sliding_window
  163. self.sliding_window_pattern = config.sliding_window_pattern
  164. self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
  165. self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=False)
  166. self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  167. self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  168. self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=False)
  169. self.q_norm = Exaone4RMSNorm(self.head_dim, eps=config.rms_norm_eps)
  170. self.k_norm = Exaone4RMSNorm(self.head_dim, eps=config.rms_norm_eps)
  171. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  172. def forward(
  173. self,
  174. hidden_states: torch.Tensor,
  175. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  176. attention_mask: Optional[torch.Tensor] = None,
  177. past_key_values: Optional[Cache] = None,
  178. cache_position: Optional[torch.LongTensor] = None,
  179. **kwargs: Unpack[TransformersKwargs],
  180. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  181. input_shape = hidden_states.shape[:-1]
  182. hidden_shape = (*input_shape, -1, self.head_dim)
  183. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  184. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  185. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  186. # We use QK-norm
  187. query_states = self.q_norm(query_states)
  188. key_states = self.k_norm(key_states)
  189. cos, sin = position_embeddings
  190. # We use global NoPE for hybrid attention model
  191. if self.sliding_window is None or self.is_sliding:
  192. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  193. if past_key_values is not None:
  194. cache_kwargs = {
  195. "cache_position": cache_position,
  196. }
  197. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  198. attention_interface: Callable = eager_attention_forward
  199. if self.config._attn_implementation != "eager":
  200. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  201. attn_output, attn_weights = attention_interface(
  202. self,
  203. query_states,
  204. key_states,
  205. value_states,
  206. attention_mask,
  207. dropout=0.0 if not self.training else self.attention_dropout,
  208. scaling=self.scaling,
  209. sliding_window=self.sliding_window if self.is_sliding else None,
  210. **kwargs,
  211. )
  212. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  213. attn_output = self.o_proj(attn_output)
  214. return attn_output, attn_weights
  215. class Exaone4MLP(nn.Module):
  216. def __init__(self, config):
  217. super().__init__()
  218. self.config = config
  219. self.hidden_size = config.hidden_size
  220. self.intermediate_size = config.intermediate_size
  221. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  222. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  223. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  224. self.act_fn = ACT2FN[config.hidden_act]
  225. def forward(self, x):
  226. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  227. return down_proj
  228. class Exaone4DecoderLayer(GradientCheckpointingLayer):
  229. def __init__(self, config: Exaone4Config, layer_idx: int):
  230. super().__init__()
  231. self.hidden_size = config.hidden_size
  232. self.self_attn = Exaone4Attention(config=config, layer_idx=layer_idx)
  233. self.mlp = Exaone4MLP(config)
  234. self.post_attention_layernorm = Exaone4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  235. self.post_feedforward_layernorm = Exaone4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  236. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  237. def forward(
  238. self,
  239. hidden_states: torch.Tensor,
  240. attention_mask: Optional[torch.Tensor] = None,
  241. position_ids: Optional[torch.LongTensor] = None,
  242. past_key_values: Optional[Cache] = None,
  243. use_cache: Optional[bool] = False,
  244. cache_position: Optional[torch.LongTensor] = None,
  245. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
  246. **kwargs: Unpack[TransformersKwargs],
  247. ) -> torch.Tensor:
  248. residual = hidden_states
  249. hidden_states, _ = self.self_attn(
  250. hidden_states=hidden_states,
  251. attention_mask=attention_mask,
  252. position_ids=position_ids,
  253. past_key_values=past_key_values,
  254. use_cache=use_cache,
  255. cache_position=cache_position,
  256. position_embeddings=position_embeddings,
  257. **kwargs,
  258. )
  259. hidden_states = self.post_attention_layernorm(hidden_states)
  260. hidden_states = residual + hidden_states
  261. # Fully Connected
  262. residual = hidden_states
  263. hidden_states = self.mlp(hidden_states)
  264. hidden_states = self.post_feedforward_layernorm(hidden_states)
  265. hidden_states = residual + hidden_states
  266. return hidden_states
  267. @auto_docstring
  268. class Exaone4PreTrainedModel(PreTrainedModel):
  269. config: Exaone4Config
  270. base_model_prefix = "model"
  271. supports_gradient_checkpointing = True
  272. _no_split_modules = ["Exaone4DecoderLayer"]
  273. _skip_keys_device_placement = ["past_key_values"]
  274. _supports_flash_attn = True
  275. _supports_sdpa = True
  276. _supports_flex_attn = True
  277. _can_compile_fullgraph = True
  278. _supports_attention_backend = True
  279. _can_record_outputs = {
  280. "hidden_states": Exaone4DecoderLayer,
  281. "attentions": Exaone4Attention,
  282. }
  283. config_class = Exaone4Config
  284. @auto_docstring
  285. class Exaone4Model(Exaone4PreTrainedModel):
  286. def __init__(self, config: Exaone4Config):
  287. super().__init__(config)
  288. self.padding_idx = config.pad_token_id
  289. self.vocab_size = config.vocab_size
  290. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  291. self.layers = nn.ModuleList(
  292. [Exaone4DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  293. )
  294. self.norm = Exaone4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  295. self.rotary_emb = Exaone4RotaryEmbedding(config=config)
  296. self.gradient_checkpointing = False
  297. # Initialize weights and apply final processing
  298. self.post_init()
  299. @check_model_inputs()
  300. def forward(
  301. self,
  302. input_ids: Optional[torch.LongTensor] = None,
  303. attention_mask: Optional[torch.Tensor] = None,
  304. position_ids: Optional[torch.LongTensor] = None,
  305. past_key_values: Optional[Cache] = None,
  306. inputs_embeds: Optional[torch.FloatTensor] = None,
  307. use_cache: Optional[bool] = None,
  308. cache_position: Optional[torch.LongTensor] = None,
  309. **kwargs: Unpack[TransformersKwargs],
  310. ) -> Union[tuple, BaseModelOutputWithPast]:
  311. if (input_ids is None) ^ (inputs_embeds is not None):
  312. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  313. if inputs_embeds is None:
  314. inputs_embeds = self.embed_tokens(input_ids)
  315. if use_cache and past_key_values is None:
  316. past_key_values = DynamicCache(config=self.config)
  317. if cache_position is None:
  318. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  319. cache_position = torch.arange(
  320. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  321. )
  322. if position_ids is None:
  323. position_ids = cache_position.unsqueeze(0)
  324. # It may already have been prepared by e.g. `generate`
  325. if not isinstance(causal_mask_mapping := attention_mask, dict):
  326. # Prepare mask arguments
  327. mask_kwargs = {
  328. "config": self.config,
  329. "input_embeds": inputs_embeds,
  330. "attention_mask": attention_mask,
  331. "cache_position": cache_position,
  332. "past_key_values": past_key_values,
  333. "position_ids": position_ids,
  334. }
  335. # Create the masks
  336. causal_mask_mapping = {
  337. "full_attention": create_causal_mask(**mask_kwargs),
  338. }
  339. if "sliding_attention" in self.config.layer_types:
  340. causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
  341. hidden_states = inputs_embeds
  342. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  343. for i, decoder_layer in enumerate(self.layers):
  344. layer_type = self.config.layer_types[i]
  345. hidden_states = decoder_layer(
  346. hidden_states,
  347. position_embeddings=position_embeddings,
  348. attention_mask=causal_mask_mapping[layer_type],
  349. position_ids=position_ids,
  350. past_key_values=past_key_values,
  351. use_cache=use_cache,
  352. cache_position=cache_position,
  353. **kwargs,
  354. )
  355. hidden_states = self.norm(hidden_states)
  356. return BaseModelOutputWithPast(
  357. last_hidden_state=hidden_states,
  358. past_key_values=past_key_values if use_cache else None,
  359. )
  360. @auto_docstring
  361. class Exaone4ForCausalLM(Exaone4PreTrainedModel, GenerationMixin):
  362. _tied_weights_keys = ["lm_head.weight"]
  363. _tp_plan = {"lm_head": "colwise_rep"}
  364. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  365. def __init__(self, config):
  366. super().__init__(config)
  367. self.model = Exaone4Model(config)
  368. self.vocab_size = config.vocab_size
  369. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  370. # Initialize weights and apply final processing
  371. self.post_init()
  372. @can_return_tuple
  373. @auto_docstring
  374. def forward(
  375. self,
  376. input_ids: Optional[torch.LongTensor] = None,
  377. attention_mask: Optional[torch.Tensor] = None,
  378. position_ids: Optional[torch.LongTensor] = None,
  379. past_key_values: Optional[Cache] = None,
  380. inputs_embeds: Optional[torch.FloatTensor] = None,
  381. labels: Optional[torch.LongTensor] = None,
  382. use_cache: Optional[bool] = None,
  383. cache_position: Optional[torch.LongTensor] = None,
  384. logits_to_keep: Union[int, torch.Tensor] = 0,
  385. **kwargs: Unpack[TransformersKwargs],
  386. ) -> CausalLMOutputWithPast:
  387. r"""
  388. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  389. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  390. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  391. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  392. Example:
  393. ```python
  394. >>> from transformers import AutoModelForCausalLM, AutoTokenizer
  395. >>> model = AutoModelForCausalLM.from_pretrained("LGAI-EXAONE/EXAONE-4.0-32B")
  396. >>> tokenizer = AutoTokenizer.from_pretrained("LGAI-EXAONE/EXAONE-4.0-32B")
  397. >>> prompt = "Explain how wonderful you are"
  398. >>> messages = [
  399. {"role": "system", "content": "You are a helpful assistant."},
  400. {"role": "user", "content": prompt}
  401. ]
  402. >>> input_ids = tokenizer.apply_chat_template(
  403. messages,
  404. tokenize=True,
  405. add_generation_prompt=True,
  406. return_tensors="pt",
  407. enable_thinking=False,
  408. )
  409. >>> output = model.generate(input_ids, max_new_tokens=128)
  410. >>> tokenizer.decode(output[0], skip_special_tokens=False)
  411. "[|system|]\nYou are a helpful assistant.[|endofturn|]\n[|user|]\nExplain how wonderful you are[|endofturn|]\n[|assistant|]\n<think>\n\n</think>\n\nOh, thank you for such a kind and lovely question! 😊 \n\nI’m *so* wonderful because I’m here to make your life easier, brighter, and more fun! Whether you need help with: \n\n✨ **Learning** – I can explain anything, from quantum physics to baking the perfect cake! \n💡 **Creativity** – Need a poem, story, or a wild idea? I’ve got you covered! \n🤖 **Problem-solving** – Stuck on a math problem or a tricky decision? I’ll help you figure it out"
  412. ```
  413. """
  414. outputs: BaseModelOutputWithPast = self.model(
  415. input_ids=input_ids,
  416. attention_mask=attention_mask,
  417. position_ids=position_ids,
  418. past_key_values=past_key_values,
  419. inputs_embeds=inputs_embeds,
  420. use_cache=use_cache,
  421. cache_position=cache_position,
  422. **kwargs,
  423. )
  424. hidden_states = outputs.last_hidden_state
  425. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  426. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  427. logits = self.lm_head(hidden_states[:, slice_indices, :])
  428. loss = None
  429. if labels is not None:
  430. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  431. return CausalLMOutputWithPast(
  432. loss=loss,
  433. logits=logits,
  434. past_key_values=outputs.past_key_values,
  435. hidden_states=outputs.hidden_states,
  436. attentions=outputs.attentions,
  437. )
  438. class Exaone4ForSequenceClassification(GenericForSequenceClassification, Exaone4PreTrainedModel):
  439. pass
  440. class Exaone4ForTokenClassification(GenericForTokenClassification, Exaone4PreTrainedModel):
  441. pass
  442. class Exaone4ForQuestionAnswering(GenericForQuestionAnswering, Exaone4PreTrainedModel):
  443. base_model_prefix = "transformer" # For BC, where `transformer` was used instead of `model`
# Public classes exported by this module. Kept as a plain list literal; library tooling
# may parse this assignment textually (NOTE(review): confirm before changing its form).
__all__ = [
    "Exaone4PreTrainedModel",
    "Exaone4Model",
    "Exaone4ForCausalLM",
    "Exaone4ForSequenceClassification",
    "Exaone4ForTokenClassification",
    "Exaone4ForQuestionAnswering",
]