modular_exaone4.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519
  1. # coding=utf-8
  2. # Copyright 2025 The LG AI Research and HuggingFace Inc. team. All rights reserved.
  3. #
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """LG AI Research EXAONE Lab"""
  17. from typing import Callable, Optional, Union
  18. import torch
  19. from torch import nn
  20. from transformers.utils.generic import check_model_inputs
  21. from ...cache_utils import Cache, DynamicCache
  22. from ...configuration_utils import PretrainedConfig, layer_type_validation
  23. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  24. from ...modeling_outputs import (
  25. BaseModelOutputWithPast,
  26. CausalLMOutputWithPast,
  27. )
  28. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  29. from ...processing_utils import Unpack
  30. from ...utils import (
  31. TransformersKwargs,
  32. logging,
  33. )
  34. from ...utils.deprecation import deprecate_kwarg
  35. from ..llama.modeling_llama import (
  36. LlamaForCausalLM,
  37. LlamaForQuestionAnswering,
  38. LlamaForSequenceClassification,
  39. LlamaForTokenClassification,
  40. LlamaModel,
  41. LlamaPreTrainedModel,
  42. LlamaRMSNorm,
  43. LlamaRotaryEmbedding,
  44. apply_rotary_pos_emb,
  45. eager_attention_forward,
  46. )
  47. from ..olmo2.modeling_olmo2 import Olmo2DecoderLayer, Olmo2MLP
  # Checkpoint / config names kept for documentation purposes; they are not referenced by the
  # visible code in this file (presumably consumed by doc tooling — verify before removing).
  _CHECKPOINT_FOR_DOC = "LGAI-EXAONE/EXAONE-4.0-32B"
  _CONFIG_FOR_DOC = "Exaone4Config"

  # Module-level logger.
  logger = logging.get_logger(__name__)
  class Exaone4Config(PretrainedConfig):
      r"""
      This is the configuration class to store the configuration of a [`Exaone4Model`]. It is used to
      instantiate an EXAONE 4.0 model according to the specified arguments, defining the model architecture. Instantiating a
      configuration with the defaults will yield a similar configuration to that of the EXAONE-4.0-32B [LGAI-EXAONE/EXAONE-4.0-32B](https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B)

      Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model
      outputs. Read the documentation from [`PretrainedConfig`] for more information.

      Args:
          vocab_size (`int`, *optional*, defaults to 102400):
              Vocabulary size of the EXAONE 4.0 model. Defines the number of different tokens that can be represented by the
              `inputs_ids` passed when calling [`Exaone4Model`].
          hidden_size (`int`, *optional*, defaults to 4096):
              Dimension of the hidden representations.
          intermediate_size (`int`, *optional*, defaults to 16384):
              Dimensionality of the MLP representations.
          num_hidden_layers (`int`, *optional*, defaults to 32):
              Number of hidden layers in the Transformer decoder.
          num_attention_heads (`int`, *optional*, defaults to 32):
              Number of attention heads for each attention layer in the Transformer decoder.
          num_key_value_heads (`int`, *optional*, defaults to 32):
              This is the number of key_value heads that should be used to implement Grouped Query Attention. If
              `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
              `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
              converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
              by meanpooling all the original heads within that group. For more details checkout [this
              paper](https://huggingface.co/papers/2305.13245). If set to `None`, it defaults to
              `num_attention_heads`.
          hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
              The non-linear activation function (function or string) in the decoder.
          max_position_embeddings (`int`, *optional*, defaults to 2048):
              The maximum sequence length that this model might ever be used with. Typically set this to something large
              just in case (e.g., 32768 for EXAONE 3.5).
          initializer_range (`float`, *optional*, defaults to 0.02):
              The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
          rms_norm_eps (`float`, *optional*, defaults to 1e-05):
              The epsilon used by the layer normalization layers.
          use_cache (`bool`, *optional*, defaults to `True`):
              Whether or not the model should return the last key/values attentions (not used by all models). Only
              relevant if ``config.is_decoder=True``.
          bos_token_id (`int`, *optional*, defaults to 0):
              Beginning of stream token id.
          eos_token_id (`int`, *optional*, defaults to 2):
              End of stream token id.
          tie_word_embeddings (`bool`, *optional*, defaults to `False`):
              Whether to tie weight embeddings
          rope_theta (`float`, *optional*, defaults to 10000.0):
              The base period of the RoPE embeddings.
          rope_scaling (`Dict`, *optional*):
              Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
              and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
              accordingly.
              Expected contents:
                  `rope_type` (`str`):
                      The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                      'llama3'], with 'default' being the original RoPE implementation.
                  `factor` (`float`, *optional*):
                      Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                      most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                      original maximum pre-trained length.
                  `original_max_position_embeddings` (`int`, *optional*):
                      Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                      pretraining.
                  `attention_factor` (`float`, *optional*):
                      Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                      computation. If unspecified, it defaults to value recommended by the implementation, using the
                      `factor` field to infer the suggested value.
                  `beta_fast` (`float`, *optional*):
                      Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                      ramp function. If unspecified, it defaults to 32.
                  `beta_slow` (`float`, *optional*):
                      Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                      ramp function. If unspecified, it defaults to 1.
                  `short_factor` (`List[float]`, *optional*):
                      Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                      `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                      size divided by the number of attention heads divided by 2
                  `long_factor` (`List[float]`, *optional*):
                      Only used with 'longrope'. The scaling factor to be applied to long contexts (>
                      `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                      size divided by the number of attention heads divided by 2
                  `low_freq_factor` (`float`, *optional*):
                      Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                  `high_freq_factor` (`float`, *optional*):
                      Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
          attention_dropout (`float`, *optional*, defaults to 0.0):
              The dropout ratio for the attention probabilities.
          sliding_window (`int`, *optional*, defaults to 4096):
              The size of the sliding window for the local (sliding window) attention layers. When `None` and
              `layer_types` is not given, every layer uses global attention.
          sliding_window_pattern (`int` or `str`, *optional*, defaults to 4):
              Pattern used to derive `layer_types` when `layer_types` is not given (ignored otherwise). Can be one of:
                  - `None` or `0`: No sliding window attention is used; every layer uses global attention.
                  - `int`: Every `sliding_window_pattern`-th layer (counting from 1) uses global attention, the other
                    layers use local attention.
                  - `str`: A sequence of "L" (local attention) and "G" (global attention) characters that defines the
                    attention pattern. The pattern starts from layer 0 and repeats every `len(sliding_window_pattern)`
                    layers. The final layer always uses global attention regardless of the pattern.
              For instance, when the number of layers is a multiple of 4, `sliding_window_pattern="LLLG"` yields the
              same layout as `sliding_window_pattern=4`:
                  - Layer 0, 1, 2: local attention,
                  - Layer 3: global attention,
                  ...(repeated)
          layer_types (`list`, *optional*):
              Attention pattern for each layer (`"sliding_attention"` or `"full_attention"`). Prioritized over
              `sliding_window_pattern`.

      Example:

      ```python
      >>> from transformers import Exaone4Model, Exaone4Config

      >>> # Initializing an EXAONE configuration
      >>> configuration = Exaone4Config()

      >>> # Initializing a model from configuration
      >>> model = Exaone4Model(configuration)

      >>> # Accessing the model configuration
      >>> configuration = model.config
      ```"""

      model_type = "exaone4"
      keys_to_ignore_at_inference = ["past_key_values"]
      # Default tensor parallel plan for base model `Exaone4Model`
      base_model_tp_plan = {
          "layers.*.self_attn.q_proj": "colwise",
          "layers.*.self_attn.k_proj": "colwise",
          "layers.*.self_attn.v_proj": "colwise",
          "layers.*.self_attn.o_proj": "rowwise",
          "layers.*.mlp.gate_proj": "colwise",
          "layers.*.mlp.up_proj": "colwise",
          "layers.*.mlp.down_proj": "rowwise",
      }
      base_model_pp_plan = {
          "embed_tokens": (["input_ids"], ["inputs_embeds"]),
          "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
          "norm": (["hidden_states"], ["hidden_states"]),
      }

      def __init__(
          self,
          vocab_size=102400,
          hidden_size=4096,
          intermediate_size=16384,
          num_hidden_layers=32,
          num_attention_heads=32,
          num_key_value_heads=32,
          hidden_act="silu",
          max_position_embeddings=2048,
          initializer_range=0.02,
          rms_norm_eps=1e-5,
          use_cache=True,
          bos_token_id=0,
          eos_token_id=2,
          tie_word_embeddings=False,
          rope_theta=10000.0,
          rope_scaling=None,
          attention_dropout=0.0,
          sliding_window=4096,
          sliding_window_pattern=4,
          layer_types=None,
          **kwargs,
      ):
          # Documented fallback: without an explicit KV-head count, use plain multi-head attention.
          if num_key_value_heads is None:
              num_key_value_heads = num_attention_heads

          self.vocab_size = vocab_size
          self.hidden_size = hidden_size
          self.num_hidden_layers = num_hidden_layers
          self.num_attention_heads = num_attention_heads
          self.num_key_value_heads = num_key_value_heads
          self.intermediate_size = intermediate_size
          self.hidden_act = hidden_act
          self.max_position_embeddings = max_position_embeddings
          self.initializer_range = initializer_range
          self.rms_norm_eps = rms_norm_eps
          self.use_cache = use_cache
          self.attention_dropout = attention_dropout
          self.rope_theta = rope_theta
          self.rope_scaling = rope_scaling
          self.sliding_window = sliding_window
          self.sliding_window_pattern = sliding_window_pattern
          self.layer_types = layer_types

          if self.layer_types is None:
              if self.sliding_window is None or not sliding_window_pattern:
                  # No sliding window, or the pattern is None/0: every layer attends globally. Handled
                  # explicitly so we never evaluate `(i + 1) % 0` (ZeroDivisionError) or `% None`.
                  self.layer_types = ["full_attention"] * self.num_hidden_layers
              elif isinstance(sliding_window_pattern, str):
                  # "L"/"G" string pattern tiled from layer 0; the final layer is always global.
                  pattern_len = len(sliding_window_pattern)
                  self.layer_types = [
                      "sliding_attention"
                      if i != self.num_hidden_layers - 1 and sliding_window_pattern[i % pattern_len] == "L"
                      else "full_attention"
                      for i in range(self.num_hidden_layers)
                  ]
              else:
                  # Integer pattern: every `sliding_window_pattern`-th layer (1-indexed) is global.
                  self.layer_types = [
                      "sliding_attention" if (i + 1) % sliding_window_pattern != 0 else "full_attention"
                      for i in range(self.num_hidden_layers)
                  ]

          # `layer_types` entries are "sliding_attention" / "full_attention"; checking for that value is what
          # actually detects a mixed local/global layout (a check for "sliding_window" could never match).
          if "sliding_attention" in self.layer_types:
              self.cache_implementation = "hybrid"
          layer_type_validation(self.layer_types, self.num_hidden_layers)

          super().__init__(
              bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
          )
  class Exaone4RMSNorm(LlamaRMSNorm):
      """RMS normalization layer for EXAONE 4.0; the implementation is inherited unchanged from `LlamaRMSNorm`."""
  class Exaone4RotaryEmbedding(LlamaRotaryEmbedding):
      """Rotary position embedding module for EXAONE 4.0; inherited unchanged from `LlamaRotaryEmbedding`."""
  class Exaone4Attention(nn.Module):
      """Grouped-query self-attention with per-head RMSNorm on queries and keys (QK-norm).

      Rotary position embeddings are applied on every layer when `config.sliding_window` is `None`.
      Otherwise only sliding-window (local) layers use RoPE, and global layers use no positional
      encoding (NoPE).
      """

      def __init__(self, config: Exaone4Config, layer_idx: int):
          super().__init__()
          self.config = config
          self.layer_idx = layer_idx
          self.num_attention_heads = config.num_attention_heads
          self.num_key_value_heads = config.num_key_value_heads
          self.hidden_size = config.hidden_size
          # `head_dim` is optional in the config; fall back to an even split of the hidden size.
          self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
          # Number of query heads sharing each key/value head.
          self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
          self.attention_dropout = config.attention_dropout
          self.is_causal = True
          self.scaling = self.head_dim**-0.5
          self.sliding_window = config.sliding_window
          self.sliding_window_pattern = config.sliding_window_pattern
          # Whether this particular layer uses local (sliding window) attention.
          self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"

          self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=False)
          self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
          self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
          self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=False)

          # QK-norm: RMSNorm applied over the last (head_dim) axis of the query/key states.
          self.q_norm = Exaone4RMSNorm(self.head_dim, eps=config.rms_norm_eps)
          self.k_norm = Exaone4RMSNorm(self.head_dim, eps=config.rms_norm_eps)

      @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
      def forward(
          self,
          hidden_states: torch.Tensor,
          position_embeddings: tuple[torch.Tensor, torch.Tensor],
          attention_mask: Optional[torch.Tensor] = None,
          past_key_values: Optional[Cache] = None,
          cache_position: Optional[torch.LongTensor] = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
          """Run self-attention over `hidden_states`.

          Args:
              hidden_states: Input activations. Projections are viewed as `(*input_shape, -1, head_dim)` and
                  transposed on dims 1 and 2, which assumes a `(batch, seq_len, hidden_size)` input.
              position_embeddings: `(cos, sin)` pair passed to `apply_rotary_pos_emb`.
              attention_mask: Mask forwarded unchanged to the attention backend.
              past_key_values: Optional cache; when given, the new key/value states are written via `update`
                  and the states it returns are the ones attended over.
              cache_position: Forwarded to the cache update.
              **kwargs: Extra arguments forwarded to the attention backend.

          Returns:
              `(attn_output, attn_weights)`: the output after `o_proj`, and whatever weights the selected
              attention backend returned.
          """
          input_shape = hidden_states.shape[:-1]
          hidden_shape = (*input_shape, -1, self.head_dim)

          query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
          key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
          value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

          # We use QK-norm
          query_states = self.q_norm(query_states)
          key_states = self.k_norm(key_states)

          cos, sin = position_embeddings
          # We use global NoPE for hybrid attention model: with a sliding window configured, only the local
          # layers receive rotary embeddings.
          if self.sliding_window is None or self.is_sliding:
              query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

          if past_key_values is not None:
              cache_kwargs = {
                  "cache_position": cache_position,
              }
              key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

          attention_interface: Callable = eager_attention_forward
          if self.config._attn_implementation != "eager":
              attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

          attn_output, attn_weights = attention_interface(
              self,
              query_states,
              key_states,
              value_states,
              attention_mask,
              dropout=0.0 if not self.training else self.attention_dropout,
              scaling=self.scaling,
              # Only local layers restrict attention to the window.
              sliding_window=self.sliding_window if self.is_sliding else None,
              **kwargs,
          )

          attn_output = attn_output.reshape(*input_shape, -1).contiguous()
          attn_output = self.o_proj(attn_output)
          return attn_output, attn_weights
  class Exaone4MLP(Olmo2MLP):
      """Feed-forward (MLP) block for EXAONE 4.0; the implementation is inherited from `Olmo2MLP`."""
  class Exaone4DecoderLayer(Olmo2DecoderLayer):
      """Transformer decoder layer for EXAONE 4.0; the implementation is inherited from `Olmo2DecoderLayer`."""
  class Exaone4PreTrainedModel(LlamaPreTrainedModel):
      """Shared base class for all EXAONE 4.0 models.

      Binds the configuration class to `Exaone4Config` and lists `Exaone4DecoderLayer` in
      `_no_split_modules`; everything else is inherited from `LlamaPreTrainedModel`.
      """

      _no_split_modules = ["Exaone4DecoderLayer"]
      config_class = Exaone4Config
  class Exaone4Model(Exaone4PreTrainedModel, LlamaModel):
      """EXAONE 4.0 decoder stack: token embeddings, `Exaone4DecoderLayer`s and a final RMSNorm.

      Full-attention and sliding-window layers receive separate causal masks, picked per layer via
      `config.layer_types`.
      """

      def __init__(self, config: Exaone4Config):
          super().__init__(config)
          # Replace the inherited layers / final norm with the EXAONE 4.0 variants.
          self.layers = nn.ModuleList(
              [Exaone4DecoderLayer(config, idx) for idx in range(config.num_hidden_layers)]
          )
          self.norm = Exaone4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
          # Initialize weights and apply final processing
          self.post_init()

      @check_model_inputs()
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          past_key_values: Optional[Cache] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          use_cache: Optional[bool] = None,
          cache_position: Optional[torch.LongTensor] = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> Union[tuple, BaseModelOutputWithPast]:
          # Exactly one of `input_ids` / `inputs_embeds` has to be supplied.
          if (input_ids is None) == (inputs_embeds is None):
              raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

          if inputs_embeds is None:
              inputs_embeds = self.embed_tokens(input_ids)

          if use_cache and past_key_values is None:
              past_key_values = DynamicCache(config=self.config)

          if cache_position is None:
              already_cached = 0 if past_key_values is None else past_key_values.get_seq_length()
              new_token_count = inputs_embeds.shape[1]
              cache_position = torch.arange(
                  already_cached, already_cached + new_token_count, device=inputs_embeds.device
              )

          if position_ids is None:
              position_ids = cache_position.unsqueeze(0)

          # `generate` may already pass a {layer_type: mask} mapping; only build the masks otherwise.
          causal_mask_mapping = attention_mask
          if not isinstance(causal_mask_mapping, dict):
              mask_kwargs = dict(
                  config=self.config,
                  input_embeds=inputs_embeds,
                  attention_mask=attention_mask,
                  cache_position=cache_position,
                  past_key_values=past_key_values,
                  position_ids=position_ids,
              )
              causal_mask_mapping = {"full_attention": create_causal_mask(**mask_kwargs)}
              if "sliding_attention" in self.config.layer_types:
                  causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)

          hidden_states = inputs_embeds
          # Rotary cos/sin are computed once and shared by every layer.
          position_embeddings = self.rotary_emb(hidden_states, position_ids)

          for layer_idx, layer in enumerate(self.layers):
              layer_mask = causal_mask_mapping[self.config.layer_types[layer_idx]]
              hidden_states = layer(
                  hidden_states,
                  position_embeddings=position_embeddings,
                  attention_mask=layer_mask,
                  position_ids=position_ids,
                  past_key_values=past_key_values,
                  use_cache=use_cache,
                  cache_position=cache_position,
                  **kwargs,
              )

          return BaseModelOutputWithPast(
              last_hidden_state=self.norm(hidden_states),
              past_key_values=past_key_values if use_cache else None,
          )
  class Exaone4ForCausalLM(LlamaForCausalLM):
      """EXAONE 4.0 model with a language-modeling head; the forward pass is delegated to `LlamaForCausalLM`."""

      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          past_key_values: Optional[Cache] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          labels: Optional[torch.LongTensor] = None,
          use_cache: Optional[bool] = None,
          cache_position: Optional[torch.LongTensor] = None,
          logits_to_keep: Union[int, torch.Tensor] = 0,
          **kwargs: Unpack[TransformersKwargs],
      ) -> CausalLMOutputWithPast:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
              config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
              (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

          Example:

          ```python
          >>> from transformers import AutoModelForCausalLM, AutoTokenizer

          >>> model = AutoModelForCausalLM.from_pretrained("LGAI-EXAONE/EXAONE-4.0-32B")
          >>> tokenizer = AutoTokenizer.from_pretrained("LGAI-EXAONE/EXAONE-4.0-32B")

          >>> prompt = "Explain how wonderful you are"
          >>> messages = [
          ...     {"role": "system", "content": "You are a helpful assistant."},
          ...     {"role": "user", "content": prompt},
          ... ]
          >>> input_ids = tokenizer.apply_chat_template(
          ...     messages,
          ...     tokenize=True,
          ...     add_generation_prompt=True,
          ...     return_tensors="pt",
          ...     enable_thinking=False,
          ... )

          >>> output = model.generate(input_ids, max_new_tokens=128)
          >>> tokenizer.decode(output[0], skip_special_tokens=False)
          "[|system|]\nYou are a helpful assistant.[|endofturn|]\n[|user|]\nExplain how wonderful you are[|endofturn|]\n[|assistant|]\n<think>\n\n</think>\n\nOh, thank you for such a kind and lovely question! 😊 \n\nI’m *so* wonderful because I’m here to make your life easier, brighter, and more fun! Whether you need help with: \n\n✨ **Learning** – I can explain anything, from quantum physics to baking the perfect cake! \n💡 **Creativity** – Need a poem, story, or a wild idea? I’ve got you covered! \n🤖 **Problem-solving** – Stuck on a math problem or a tricky decision? I’ll help you figure it out"
          ```
          """
          # Propagate the parent's output: without `return`, this override would always yield `None`.
          return super().forward(
              input_ids=input_ids,
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=past_key_values,
              inputs_embeds=inputs_embeds,
              labels=labels,
              use_cache=use_cache,
              cache_position=cache_position,
              logits_to_keep=logits_to_keep,
              **kwargs,
          )
  class Exaone4ForSequenceClassification(LlamaForSequenceClassification):
      """EXAONE 4.0 with a sequence-classification head; inherited from `LlamaForSequenceClassification`."""
  class Exaone4ForTokenClassification(LlamaForTokenClassification):
      """EXAONE 4.0 with a token-classification head; inherited from `LlamaForTokenClassification`."""
  class Exaone4ForQuestionAnswering(LlamaForQuestionAnswering):
      """EXAONE 4.0 with an extractive question-answering head; inherited from `LlamaForQuestionAnswering`."""
  # Public API of this module.
  __all__ = [
      # Configuration
      "Exaone4Config",
      # Base model classes
      "Exaone4PreTrainedModel",
      "Exaone4Model",
      # Task-specific heads
      "Exaone4ForCausalLM",
      "Exaone4ForSequenceClassification",
      "Exaone4ForTokenClassification",
      "Exaone4ForQuestionAnswering",
  ]