modeling_glm4.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/glm4/modular_glm4.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_glm4.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2025 The GLM4 & ZhipuAI team and HuggingFace Inc. team. All rights reserved.
  9. #
  10. #
  11. # Licensed under the Apache License, Version 2.0 (the "License");
  12. # you may not use this file except in compliance with the License.
  13. # You may obtain a copy of the License at
  14. #
  15. # http://www.apache.org/licenses/LICENSE-2.0
  16. #
  17. # Unless required by applicable law or agreed to in writing, software
  18. # distributed under the License is distributed on an "AS IS" BASIS,
  19. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  20. # See the License for the specific language governing permissions and
  21. # limitations under the License.
  22. from typing import Callable, Optional, Union
  23. import torch
  24. import torch.nn as nn
  25. from ...activations import ACT2FN
  26. from ...cache_utils import Cache, DynamicCache
  27. from ...generation import GenerationMixin
  28. from ...integrations import use_kernel_forward_from_hub
  29. from ...masking_utils import create_causal_mask
  30. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  31. from ...modeling_layers import (
  32. GenericForSequenceClassification,
  33. GenericForTokenClassification,
  34. GradientCheckpointingLayer,
  35. )
  36. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  37. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  38. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  39. from ...processing_utils import Unpack
  40. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  41. from ...utils.deprecation import deprecate_kwarg
  42. from ...utils.generic import check_model_inputs
  43. from .configuration_glm4 import Glm4Config
  44. class Glm4MLP(nn.Module):
  45. def __init__(self, config):
  46. super().__init__()
  47. self.config = config
  48. self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
  49. self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
  50. self.activation_fn = ACT2FN[config.hidden_act]
  51. def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
  52. up_states = self.gate_up_proj(hidden_states)
  53. gate, up_states = up_states.chunk(2, dim=-1)
  54. up_states = up_states * self.activation_fn(gate)
  55. return self.down_proj(up_states)
  56. class Glm4DecoderLayer(GradientCheckpointingLayer):
  57. def __init__(self, config: Glm4Config, layer_idx: int):
  58. super().__init__()
  59. self.hidden_size = config.hidden_size
  60. self.self_attn = Glm4Attention(config=config, layer_idx=layer_idx)
  61. self.mlp = Glm4MLP(config)
  62. self.input_layernorm = Glm4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  63. self.post_attention_layernorm = Glm4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  64. self.post_self_attn_layernorm = Glm4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  65. self.post_mlp_layernorm = Glm4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  66. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  67. def forward(
  68. self,
  69. hidden_states: torch.Tensor,
  70. attention_mask: Optional[torch.Tensor] = None,
  71. position_ids: Optional[torch.LongTensor] = None,
  72. past_key_values: Optional[Cache] = None,
  73. use_cache: Optional[bool] = False,
  74. cache_position: Optional[torch.LongTensor] = None,
  75. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
  76. **kwargs: Unpack[FlashAttentionKwargs],
  77. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  78. residual = hidden_states
  79. hidden_states = self.input_layernorm(hidden_states)
  80. hidden_states, _ = self.self_attn(
  81. hidden_states=hidden_states,
  82. attention_mask=attention_mask,
  83. position_ids=position_ids,
  84. past_key_values=past_key_values,
  85. use_cache=use_cache,
  86. cache_position=cache_position,
  87. position_embeddings=position_embeddings,
  88. **kwargs,
  89. )
  90. hidden_states = self.post_self_attn_layernorm(hidden_states)
  91. hidden_states = residual + hidden_states
  92. residual = hidden_states
  93. hidden_states = self.post_attention_layernorm(hidden_states)
  94. hidden_states = self.mlp(hidden_states)
  95. hidden_states = self.post_mlp_layernorm(hidden_states)
  96. hidden_states = residual + hidden_states
  97. return hidden_states
  98. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  99. """
  100. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  101. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  102. """
  103. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  104. if n_rep == 1:
  105. return hidden_states
  106. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  107. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  108. def eager_attention_forward(
  109. module: nn.Module,
  110. query: torch.Tensor,
  111. key: torch.Tensor,
  112. value: torch.Tensor,
  113. attention_mask: Optional[torch.Tensor],
  114. scaling: float,
  115. dropout: float = 0.0,
  116. **kwargs: Unpack[TransformersKwargs],
  117. ):
  118. key_states = repeat_kv(key, module.num_key_value_groups)
  119. value_states = repeat_kv(value, module.num_key_value_groups)
  120. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  121. if attention_mask is not None:
  122. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  123. attn_weights = attn_weights + causal_mask
  124. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  125. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  126. attn_output = torch.matmul(attn_weights, value_states)
  127. attn_output = attn_output.transpose(1, 2).contiguous()
  128. return attn_output, attn_weights
  129. def rotate_half(x):
  130. """Rotates half the hidden dims of the input."""
  131. x1 = x[..., 0::2]
  132. x2 = x[..., 1::2]
  133. return torch.stack((-x2, x1), dim=-1).flatten(-2)
  134. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  135. """Applies Rotary Position Embedding to the query and key tensors.
  136. Args:
  137. q (`torch.Tensor`): The query tensor.
  138. k (`torch.Tensor`): The key tensor.
  139. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  140. sin (`torch.Tensor`): The sine part of the rotary embedding.
  141. position_ids (`torch.Tensor`, *optional*):
  142. Deprecated and unused.
  143. unsqueeze_dim (`int`, *optional*, defaults to 1):
  144. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  145. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  146. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  147. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  148. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  149. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  150. Returns:
  151. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  152. """
  153. cos = cos.unsqueeze(unsqueeze_dim)
  154. sin = sin.unsqueeze(unsqueeze_dim)
  155. # Interleave them instead of usual shape
  156. cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1)
  157. sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1)
  158. # Keep half or full tensor for later concatenation
  159. rotary_dim = cos.shape[-1]
  160. q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
  161. k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
  162. # Apply rotary embeddings on the first half or full tensor
  163. q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
  164. k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
  165. # Concatenate back to full shape
  166. q_embed = torch.cat([q_embed, q_pass], dim=-1)
  167. k_embed = torch.cat([k_embed, k_pass], dim=-1)
  168. return q_embed, k_embed
  169. class Glm4Attention(nn.Module):
  170. """Multi-headed attention from 'Attention Is All You Need' paper"""
  171. def __init__(self, config: Glm4Config, layer_idx: Optional[int] = None):
  172. super().__init__()
  173. self.config = config
  174. self.layer_idx = layer_idx
  175. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  176. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  177. self.scaling = self.head_dim**-0.5
  178. self.attention_dropout = config.attention_dropout
  179. self.is_causal = True
  180. self.q_proj = nn.Linear(
  181. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  182. )
  183. self.k_proj = nn.Linear(
  184. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  185. )
  186. self.v_proj = nn.Linear(
  187. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  188. )
  189. self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
  190. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  191. def forward(
  192. self,
  193. hidden_states: torch.Tensor,
  194. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  195. attention_mask: Optional[torch.Tensor],
  196. past_key_values: Optional[Cache] = None,
  197. cache_position: Optional[torch.LongTensor] = None,
  198. **kwargs: Unpack[TransformersKwargs],
  199. ) -> tuple[torch.Tensor, torch.Tensor]:
  200. input_shape = hidden_states.shape[:-1]
  201. hidden_shape = (*input_shape, -1, self.head_dim)
  202. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  203. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  204. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  205. cos, sin = position_embeddings
  206. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  207. if past_key_values is not None:
  208. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  209. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  210. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  211. attention_interface: Callable = eager_attention_forward
  212. if self.config._attn_implementation != "eager":
  213. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  214. attn_output, attn_weights = attention_interface(
  215. self,
  216. query_states,
  217. key_states,
  218. value_states,
  219. attention_mask,
  220. dropout=0.0 if not self.training else self.attention_dropout,
  221. scaling=self.scaling,
  222. **kwargs,
  223. )
  224. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  225. attn_output = self.o_proj(attn_output)
  226. return attn_output, attn_weights
  227. @use_kernel_forward_from_hub("RMSNorm")
  228. class Glm4RMSNorm(nn.Module):
  229. def __init__(self, hidden_size, eps=1e-6):
  230. """
  231. Glm4RMSNorm is equivalent to T5LayerNorm
  232. """
  233. super().__init__()
  234. self.weight = nn.Parameter(torch.ones(hidden_size))
  235. self.variance_epsilon = eps
  236. def forward(self, hidden_states):
  237. input_dtype = hidden_states.dtype
  238. hidden_states = hidden_states.to(torch.float32)
  239. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  240. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  241. return self.weight * hidden_states.to(input_dtype)
  242. def extra_repr(self):
  243. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  244. class Glm4RotaryEmbedding(nn.Module):
  245. inv_freq: torch.Tensor # fix linting for `register_buffer`
  246. def __init__(self, config: Glm4Config, device=None):
  247. super().__init__()
  248. # BC: "rope_type" was originally "type"
  249. if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
  250. self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
  251. else:
  252. self.rope_type = "default"
  253. self.max_seq_len_cached = config.max_position_embeddings
  254. self.original_max_seq_len = config.max_position_embeddings
  255. self.config = config
  256. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  257. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
  258. self.register_buffer("inv_freq", inv_freq, persistent=False)
  259. self.original_inv_freq = self.inv_freq
  260. @torch.no_grad()
  261. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  262. def forward(self, x, position_ids):
  263. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  264. position_ids_expanded = position_ids[:, None, :].float()
  265. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  266. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  267. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  268. emb = torch.cat((freqs, freqs), dim=-1)
  269. cos = emb.cos() * self.attention_scaling
  270. sin = emb.sin() * self.attention_scaling
  271. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  272. @auto_docstring
  273. class Glm4PreTrainedModel(PreTrainedModel):
  274. config: Glm4Config
  275. base_model_prefix = "model"
  276. supports_gradient_checkpointing = True
  277. _no_split_modules = ["Glm4DecoderLayer"]
  278. _skip_keys_device_placement = ["past_key_values"]
  279. _supports_flash_attn = True
  280. _supports_sdpa = True
  281. _supports_flex_attn = True
  282. _can_compile_fullgraph = True
  283. _supports_attention_backend = True
  284. _can_record_outputs = {
  285. "hidden_states": Glm4DecoderLayer,
  286. "attentions": Glm4Attention,
  287. }
  288. @auto_docstring
  289. class Glm4Model(Glm4PreTrainedModel):
  290. def __init__(self, config: Glm4Config):
  291. super().__init__(config)
  292. self.padding_idx = config.pad_token_id
  293. self.vocab_size = config.vocab_size
  294. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  295. self.layers = nn.ModuleList(
  296. [Glm4DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  297. )
  298. self.norm = Glm4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  299. self.rotary_emb = Glm4RotaryEmbedding(config=config)
  300. self.gradient_checkpointing = False
  301. # Initialize weights and apply final processing
  302. self.post_init()
  303. @check_model_inputs()
  304. @auto_docstring
  305. def forward(
  306. self,
  307. input_ids: Optional[torch.LongTensor] = None,
  308. attention_mask: Optional[torch.Tensor] = None,
  309. position_ids: Optional[torch.LongTensor] = None,
  310. past_key_values: Optional[Cache] = None,
  311. inputs_embeds: Optional[torch.FloatTensor] = None,
  312. cache_position: Optional[torch.LongTensor] = None,
  313. use_cache: Optional[bool] = None,
  314. **kwargs: Unpack[TransformersKwargs],
  315. ) -> BaseModelOutputWithPast:
  316. if (input_ids is None) ^ (inputs_embeds is not None):
  317. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  318. if inputs_embeds is None:
  319. inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
  320. if use_cache and past_key_values is None:
  321. past_key_values = DynamicCache(config=self.config)
  322. if cache_position is None:
  323. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  324. cache_position: torch.Tensor = torch.arange(
  325. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  326. )
  327. if position_ids is None:
  328. position_ids = cache_position.unsqueeze(0)
  329. causal_mask = create_causal_mask(
  330. config=self.config,
  331. input_embeds=inputs_embeds,
  332. attention_mask=attention_mask,
  333. cache_position=cache_position,
  334. past_key_values=past_key_values,
  335. position_ids=position_ids,
  336. )
  337. hidden_states = inputs_embeds
  338. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  339. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  340. hidden_states = decoder_layer(
  341. hidden_states,
  342. attention_mask=causal_mask,
  343. position_ids=position_ids,
  344. past_key_values=past_key_values,
  345. cache_position=cache_position,
  346. position_embeddings=position_embeddings,
  347. **kwargs,
  348. )
  349. hidden_states = self.norm(hidden_states)
  350. return BaseModelOutputWithPast(
  351. last_hidden_state=hidden_states,
  352. past_key_values=past_key_values,
  353. )
  354. @auto_docstring
  355. class Glm4ForCausalLM(Glm4PreTrainedModel, GenerationMixin):
  356. _tied_weights_keys = ["lm_head.weight"]
  357. _tp_plan = {"lm_head": "colwise_rep"}
  358. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  359. def __init__(self, config):
  360. super().__init__(config)
  361. self.model = Glm4Model(config)
  362. self.vocab_size = config.vocab_size
  363. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  364. # Initialize weights and apply final processing
  365. self.post_init()
  366. @can_return_tuple
  367. @auto_docstring
  368. def forward(
  369. self,
  370. input_ids: Optional[torch.LongTensor] = None,
  371. attention_mask: Optional[torch.Tensor] = None,
  372. position_ids: Optional[torch.LongTensor] = None,
  373. past_key_values: Optional[Cache] = None,
  374. inputs_embeds: Optional[torch.FloatTensor] = None,
  375. labels: Optional[torch.LongTensor] = None,
  376. use_cache: Optional[bool] = None,
  377. cache_position: Optional[torch.LongTensor] = None,
  378. logits_to_keep: Union[int, torch.Tensor] = 0,
  379. **kwargs: Unpack[TransformersKwargs],
  380. ) -> Union[tuple, CausalLMOutputWithPast]:
  381. r"""
  382. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  383. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  384. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  385. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  386. Example:
  387. ```python
  388. >>> from transformers import AutoTokenizer, Glm4ForCausalLM
  389. >>> model = Glm4ForCausalLM.from_pretrained("THUDM/GLM-4-9B-0414")
  390. >>> tokenizer = AutoTokenizer.from_pretrained("THUDM/GLM-4-9B-0414")
  391. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  392. >>> inputs = tokenizer(prompt, return_tensors="pt")
  393. >>> # Generate
  394. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  395. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  396. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  397. ```"""
  398. outputs: BaseModelOutputWithPast = self.model(
  399. input_ids=input_ids,
  400. attention_mask=attention_mask,
  401. position_ids=position_ids,
  402. past_key_values=past_key_values,
  403. inputs_embeds=inputs_embeds,
  404. use_cache=use_cache,
  405. cache_position=cache_position,
  406. **kwargs,
  407. )
  408. hidden_states = outputs.last_hidden_state
  409. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  410. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  411. logits = self.lm_head(hidden_states[:, slice_indices, :])
  412. loss = None
  413. if labels is not None:
  414. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  415. return CausalLMOutputWithPast(
  416. loss=loss,
  417. logits=logits,
  418. past_key_values=outputs.past_key_values,
  419. hidden_states=outputs.hidden_states,
  420. attentions=outputs.attentions,
  421. )
  422. class Glm4ForSequenceClassification(GenericForSequenceClassification, Glm4PreTrainedModel):
  423. pass
  424. class Glm4ForTokenClassification(GenericForTokenClassification, Glm4PreTrainedModel):
  425. pass
# Public model classes exported by this module (the names picked up by `from ... import *`).
__all__ = [
    "Glm4PreTrainedModel",
    "Glm4Model",
    "Glm4ForCausalLM",
    "Glm4ForSequenceClassification",
    "Glm4ForTokenClassification",
]