modeling_olmo3.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/olmo3/modular_olmo3.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_olmo3.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2025 the HuggingFace Team. All rights reserved.
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. from typing import Callable, Optional, Union
  22. import torch
  23. import torch.nn as nn
  24. from transformers.utils.generic import TransformersKwargs
  25. from ...activations import ACT2FN
  26. from ...cache_utils import Cache, DynamicCache
  27. from ...generation import GenerationMixin
  28. from ...integrations import use_kernel_forward_from_hub
  29. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  30. from ...modeling_layers import GradientCheckpointingLayer
  31. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  32. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  33. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  34. from ...processing_utils import Unpack
  35. from ...utils import auto_docstring, can_return_tuple
  36. from ...utils.deprecation import deprecate_kwarg
  37. from ...utils.generic import check_model_inputs
  38. from .configuration_olmo3 import Olmo3Config
  39. @use_kernel_forward_from_hub("RMSNorm")
  40. class Olmo3RMSNorm(nn.Module):
  41. def __init__(self, hidden_size, eps=1e-6):
  42. """
  43. Olmo3RMSNorm is equivalent to T5LayerNorm
  44. """
  45. super().__init__()
  46. self.weight = nn.Parameter(torch.ones(hidden_size))
  47. self.variance_epsilon = eps
  48. def forward(self, hidden_states):
  49. input_dtype = hidden_states.dtype
  50. hidden_states = hidden_states.to(torch.float32)
  51. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  52. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  53. return (self.weight * hidden_states).to(input_dtype)
  54. def extra_repr(self):
  55. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  56. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  57. """
  58. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  59. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  60. """
  61. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  62. if n_rep == 1:
  63. return hidden_states
  64. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  65. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  66. def eager_attention_forward(
  67. module: nn.Module,
  68. query: torch.Tensor,
  69. key: torch.Tensor,
  70. value: torch.Tensor,
  71. attention_mask: Optional[torch.Tensor],
  72. scaling: float,
  73. dropout: float = 0.0,
  74. **kwargs: Unpack[TransformersKwargs],
  75. ):
  76. key_states = repeat_kv(key, module.num_key_value_groups)
  77. value_states = repeat_kv(value, module.num_key_value_groups)
  78. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  79. if attention_mask is not None:
  80. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  81. attn_weights = attn_weights + causal_mask
  82. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  83. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  84. attn_output = torch.matmul(attn_weights, value_states)
  85. attn_output = attn_output.transpose(1, 2).contiguous()
  86. return attn_output, attn_weights
  87. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  88. """Applies Rotary Position Embedding to the query and key tensors.
  89. Args:
  90. q (`torch.Tensor`): The query tensor.
  91. k (`torch.Tensor`): The key tensor.
  92. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  93. sin (`torch.Tensor`): The sine part of the rotary embedding.
  94. position_ids (`torch.Tensor`, *optional*):
  95. Deprecated and unused.
  96. unsqueeze_dim (`int`, *optional*, defaults to 1):
  97. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  98. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  99. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  100. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  101. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  102. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  103. Returns:
  104. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  105. """
  106. q_type, k_type = q.dtype, k.dtype
  107. cos = cos.unsqueeze(unsqueeze_dim)
  108. sin = sin.unsqueeze(unsqueeze_dim)
  109. q_embed = (q * cos) + (rotate_half(q) * sin)
  110. k_embed = (k * cos) + (rotate_half(k) * sin)
  111. return q_embed.to(q_type), k_embed.to(k_type)
  112. def rotate_half(x):
  113. """Rotates half the hidden dims of the input."""
  114. x1 = x[..., : x.shape[-1] // 2]
  115. x2 = x[..., x.shape[-1] // 2 :]
  116. return torch.cat((-x2, x1), dim=-1)
  117. class Olmo3Attention(nn.Module):
  118. """Multi-headed attention from 'Attention Is All You Need' paper"""
  119. def __init__(self, config: Olmo3Config, layer_idx: int):
  120. super().__init__()
  121. self.config = config
  122. self.layer_idx = layer_idx
  123. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  124. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  125. self.scaling = self.head_dim**-0.5
  126. self.attention_dropout = config.attention_dropout
  127. self.is_causal = True
  128. self.q_proj = nn.Linear(
  129. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  130. )
  131. self.k_proj = nn.Linear(
  132. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  133. )
  134. self.v_proj = nn.Linear(
  135. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  136. )
  137. self.o_proj = nn.Linear(
  138. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  139. )
  140. self.q_norm = Olmo3RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps)
  141. self.k_norm = Olmo3RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps)
  142. assert config.layer_types is not None
  143. self.attention_type = config.layer_types[layer_idx]
  144. self.sliding_window = config.sliding_window if self.attention_type == "sliding_attention" else None
  145. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  146. def forward(
  147. self,
  148. hidden_states: torch.Tensor,
  149. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  150. attention_mask: Optional[torch.Tensor],
  151. past_key_values: Optional[Cache] = None,
  152. cache_position: Optional[torch.LongTensor] = None,
  153. **kwargs: Unpack[TransformersKwargs],
  154. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  155. input_shape = hidden_states.shape[:-1]
  156. hidden_shape = (*input_shape, -1, self.head_dim)
  157. query_states = self.q_norm(self.q_proj(hidden_states))
  158. key_states = self.k_norm(self.k_proj(hidden_states))
  159. value_states = self.v_proj(hidden_states)
  160. query_states = query_states.view(hidden_shape).transpose(1, 2)
  161. key_states = key_states.view(hidden_shape).transpose(1, 2)
  162. value_states = value_states.view(hidden_shape).transpose(1, 2)
  163. cos, sin = position_embeddings
  164. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  165. if past_key_values is not None:
  166. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  167. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  168. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  169. attention_interface: Callable = eager_attention_forward
  170. if self.config._attn_implementation != "eager":
  171. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  172. attn_output, attn_weights = attention_interface(
  173. self,
  174. query_states,
  175. key_states,
  176. value_states,
  177. attention_mask,
  178. dropout=0.0 if not self.training else self.attention_dropout,
  179. scaling=self.scaling,
  180. sliding_window=self.sliding_window,
  181. **kwargs,
  182. )
  183. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  184. attn_output = self.o_proj(attn_output)
  185. return attn_output, attn_weights
  186. class Olmo3MLP(nn.Module):
  187. def __init__(self, config):
  188. super().__init__()
  189. self.config = config
  190. self.hidden_size = config.hidden_size
  191. self.intermediate_size = config.intermediate_size
  192. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  193. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  194. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  195. self.act_fn = ACT2FN[config.hidden_act]
  196. def forward(self, x):
  197. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  198. return down_proj
  199. class Olmo3DecoderLayer(GradientCheckpointingLayer):
  200. def __init__(self, config: Olmo3Config, layer_idx: int):
  201. super().__init__()
  202. self.hidden_size = config.hidden_size
  203. self.self_attn = Olmo3Attention(config=config, layer_idx=layer_idx)
  204. self.mlp = Olmo3MLP(config)
  205. self.post_attention_layernorm = Olmo3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  206. self.post_feedforward_layernorm = Olmo3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  207. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  208. def forward(
  209. self,
  210. hidden_states: torch.Tensor,
  211. attention_mask: Optional[torch.Tensor] = None,
  212. position_ids: Optional[torch.LongTensor] = None,
  213. past_key_values: Optional[Cache] = None,
  214. use_cache: Optional[bool] = False,
  215. cache_position: Optional[torch.LongTensor] = None,
  216. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
  217. **kwargs: Unpack[TransformersKwargs],
  218. ) -> torch.Tensor:
  219. residual = hidden_states
  220. hidden_states, _ = self.self_attn(
  221. hidden_states=hidden_states,
  222. attention_mask=attention_mask,
  223. position_ids=position_ids,
  224. past_key_values=past_key_values,
  225. use_cache=use_cache,
  226. cache_position=cache_position,
  227. position_embeddings=position_embeddings,
  228. **kwargs,
  229. )
  230. hidden_states = self.post_attention_layernorm(hidden_states)
  231. hidden_states = residual + hidden_states
  232. # Fully Connected
  233. residual = hidden_states
  234. hidden_states = self.mlp(hidden_states)
  235. hidden_states = self.post_feedforward_layernorm(hidden_states)
  236. hidden_states = residual + hidden_states
  237. return hidden_states
  238. class Olmo3RotaryEmbedding(nn.Module):
  239. inv_freq: torch.Tensor # fix linting for `register_buffer`
  240. def __init__(self, config: Olmo3Config, device=None, rope_type: Optional[str] = None):
  241. super().__init__()
  242. if rope_type is not None:
  243. self.rope_type = rope_type
  244. elif hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
  245. # BC: "rope_type" was originally "type"
  246. self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
  247. else:
  248. self.rope_type = "default"
  249. assert self.rope_type is not None
  250. self.max_seq_len_cached = config.max_position_embeddings
  251. self.original_max_seq_len = config.max_position_embeddings
  252. self.config = config
  253. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  254. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
  255. self.register_buffer("inv_freq", inv_freq, persistent=False)
  256. self.original_inv_freq = self.inv_freq
  257. @torch.no_grad()
  258. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  259. def forward(self, x, position_ids):
  260. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  261. position_ids_expanded = position_ids[:, None, :].float()
  262. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  263. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  264. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  265. emb = torch.cat((freqs, freqs), dim=-1)
  266. cos = emb.cos() * self.attention_scaling
  267. sin = emb.sin() * self.attention_scaling
  268. return cos, sin
@auto_docstring
class Olmo3PreTrainedModel(PreTrainedModel):
    # Shared base class for the Olmo3 models in this file. It only declares class-level settings;
    # how each one is interpreted is defined by `PreTrainedModel`, which lives outside this file.
    config: Olmo3Config
    # Matches the attribute under which `Olmo3ForCausalLM` stores its backbone (`self.model`).
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Names the decoder layer class defined in this file.
    _no_split_modules = ["Olmo3DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    # Capability flags -- presumably read by the attention/compile machinery; verify in `PreTrainedModel`.
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True

    _can_compile_fullgraph = True
    _supports_attention_backend = True
    # Maps an output name to the module class in this file associated with it.
    _can_record_outputs = {
        "hidden_states": Olmo3DecoderLayer,
        "attentions": Olmo3Attention,
    }
  285. @auto_docstring
  286. class Olmo3Model(Olmo3PreTrainedModel):
  287. def __init__(self, config: Olmo3Config):
  288. super().__init__(config)
  289. self.padding_idx = config.pad_token_id
  290. self.vocab_size = config.vocab_size
  291. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  292. self.layers = nn.ModuleList(
  293. [Olmo3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  294. )
  295. self.norm = Olmo3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  296. self.gradient_checkpointing = False
  297. self.rotary_embs = nn.ModuleDict(
  298. {
  299. "sliding_attention": Olmo3RotaryEmbedding(config=config, rope_type="default"),
  300. "full_attention": Olmo3RotaryEmbedding(config=config),
  301. }
  302. )
  303. # Initialize weights and apply final processing
  304. self.post_init()
  305. @check_model_inputs()
  306. @auto_docstring
  307. def forward(
  308. self,
  309. input_ids: Optional[torch.LongTensor] = None,
  310. attention_mask: Optional[torch.Tensor] = None,
  311. position_ids: Optional[torch.LongTensor] = None,
  312. past_key_values: Optional[Cache] = None,
  313. inputs_embeds: Optional[torch.FloatTensor] = None,
  314. cache_position: Optional[torch.LongTensor] = None,
  315. use_cache: Optional[bool] = None,
  316. **kwargs: Unpack[TransformersKwargs],
  317. ) -> BaseModelOutputWithPast:
  318. if (input_ids is None) ^ (inputs_embeds is not None):
  319. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  320. if inputs_embeds is None:
  321. inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
  322. if use_cache and past_key_values is None:
  323. past_key_values = DynamicCache(config=self.config)
  324. if cache_position is None:
  325. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  326. cache_position: torch.Tensor = torch.arange(
  327. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  328. )
  329. if position_ids is None:
  330. position_ids = cache_position.unsqueeze(0)
  331. # It may already have been prepared by e.g. `generate`
  332. if not isinstance(causal_mask_mapping := attention_mask, dict):
  333. # Prepare mask arguments
  334. mask_kwargs = {
  335. "config": self.config,
  336. "input_embeds": inputs_embeds,
  337. "attention_mask": attention_mask,
  338. "cache_position": cache_position,
  339. "past_key_values": past_key_values,
  340. "position_ids": position_ids,
  341. }
  342. # Create the masks
  343. causal_mask_mapping = {
  344. "full_attention": create_causal_mask(**mask_kwargs),
  345. "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
  346. }
  347. hidden_states = inputs_embeds
  348. position_embeddings_mapping = {
  349. "sliding_attention": self.rotary_embs["sliding_attention"](hidden_states, position_ids),
  350. "full_attention": self.rotary_embs["full_attention"](hidden_states, position_ids),
  351. }
  352. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  353. hidden_states = decoder_layer(
  354. hidden_states,
  355. attention_mask=causal_mask_mapping[decoder_layer.self_attn.attention_type],
  356. position_ids=position_ids,
  357. past_key_values=past_key_values,
  358. cache_position=cache_position,
  359. position_embeddings=position_embeddings_mapping[decoder_layer.self_attn.attention_type],
  360. **kwargs,
  361. )
  362. hidden_states = self.norm(hidden_states)
  363. return BaseModelOutputWithPast(
  364. last_hidden_state=hidden_states,
  365. past_key_values=past_key_values,
  366. )
  367. @auto_docstring
  368. class Olmo3ForCausalLM(Olmo3PreTrainedModel, GenerationMixin):
  369. _tied_weights_keys = ["lm_head.weight"]
  370. _tp_plan = {"lm_head": "colwise_rep"}
  371. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  372. def __init__(self, config):
  373. super().__init__(config)
  374. self.model = Olmo3Model(config)
  375. self.vocab_size = config.vocab_size
  376. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  377. # Initialize weights and apply final processing
  378. self.post_init()
  379. @can_return_tuple
  380. @auto_docstring
  381. def forward(
  382. self,
  383. input_ids: Optional[torch.LongTensor] = None,
  384. attention_mask: Optional[torch.Tensor] = None,
  385. position_ids: Optional[torch.LongTensor] = None,
  386. past_key_values: Optional[Cache] = None,
  387. inputs_embeds: Optional[torch.FloatTensor] = None,
  388. labels: Optional[torch.LongTensor] = None,
  389. use_cache: Optional[bool] = None,
  390. cache_position: Optional[torch.LongTensor] = None,
  391. logits_to_keep: Union[int, torch.Tensor] = 0,
  392. **kwargs: Unpack[TransformersKwargs],
  393. ) -> CausalLMOutputWithPast:
  394. r"""
  395. Example:
  396. ```python
  397. >>> from transformers import AutoTokenizer, Olmo3ForCausalLM
  398. >>> model = Olmo3ForCausalLM.from_pretrained("meta-olmo3/Olmo3-2-7b-hf")
  399. >>> tokenizer = AutoTokenizer.from_pretrained("meta-olmo3/Olmo3-2-7b-hf")
  400. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  401. >>> inputs = tokenizer(prompt, return_tensors="pt")
  402. >>> # Generate
  403. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  404. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  405. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  406. ```"""
  407. outputs: BaseModelOutputWithPast = self.model(
  408. input_ids=input_ids,
  409. attention_mask=attention_mask,
  410. position_ids=position_ids,
  411. past_key_values=past_key_values,
  412. inputs_embeds=inputs_embeds,
  413. use_cache=use_cache,
  414. cache_position=cache_position,
  415. **kwargs,
  416. )
  417. hidden_states = outputs.last_hidden_state
  418. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  419. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  420. logits = self.lm_head(hidden_states[:, slice_indices, :])
  421. loss = None
  422. if labels is not None:
  423. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  424. return CausalLMOutputWithPast(
  425. loss=loss,
  426. logits=logits,
  427. past_key_values=outputs.past_key_values,
  428. hidden_states=outputs.hidden_states,
  429. attentions=outputs.attentions,
  430. )
# Public names of this module; this is what `from <module> import *` exposes.
__all__ = ["Olmo3ForCausalLM", "Olmo3Model", "Olmo3PreTrainedModel"]