modular_glm4.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. # coding=utf-8
  2. # Copyright 2025 The GLM4 & ZhipuAI team and HuggingFace Inc. team. All rights reserved.
  3. #
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. from typing import Optional, Union
  17. import torch
  18. from ...cache_utils import Cache
  19. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  20. from ...modeling_layers import GradientCheckpointingLayer
  21. from ...modeling_outputs import CausalLMOutputWithPast
  22. from ...processing_utils import Unpack
  23. from ...utils import TransformersKwargs, logging
  24. from ...utils.deprecation import deprecate_kwarg
  25. from ..glm.modeling_glm import GlmAttention, GlmForCausalLM, GlmForSequenceClassification, GlmForTokenClassification
  26. from ..phi3.modeling_phi3 import Phi3MLP
  27. from .configuration_glm4 import Glm4Config
  28. from .modeling_glm4 import Glm4RMSNorm
  29. logger = logging.get_logger(__name__)
  30. _CHECKPOINT_FOR_DOC = "THUDM/GLM-4-9B-0414"
# Feed-forward block for GLM-4. It adds nothing of its own: the implementation
# is inherited unchanged from Phi3MLP (constructed with the model config in
# Glm4DecoderLayer).
class Glm4MLP(Phi3MLP):
    pass
  33. class Glm4DecoderLayer(GradientCheckpointingLayer):
  34. def __init__(self, config: Glm4Config, layer_idx: int):
  35. super().__init__()
  36. self.hidden_size = config.hidden_size
  37. self.self_attn = Glm4Attention(config=config, layer_idx=layer_idx)
  38. self.mlp = Glm4MLP(config)
  39. self.input_layernorm = Glm4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  40. self.post_attention_layernorm = Glm4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  41. self.post_self_attn_layernorm = Glm4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  42. self.post_mlp_layernorm = Glm4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  43. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  44. def forward(
  45. self,
  46. hidden_states: torch.Tensor,
  47. attention_mask: Optional[torch.Tensor] = None,
  48. position_ids: Optional[torch.LongTensor] = None,
  49. past_key_values: Optional[Cache] = None,
  50. use_cache: Optional[bool] = False,
  51. cache_position: Optional[torch.LongTensor] = None,
  52. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
  53. **kwargs: Unpack[FlashAttentionKwargs],
  54. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  55. residual = hidden_states
  56. hidden_states = self.input_layernorm(hidden_states)
  57. hidden_states, _ = self.self_attn(
  58. hidden_states=hidden_states,
  59. attention_mask=attention_mask,
  60. position_ids=position_ids,
  61. past_key_values=past_key_values,
  62. use_cache=use_cache,
  63. cache_position=cache_position,
  64. position_embeddings=position_embeddings,
  65. **kwargs,
  66. )
  67. hidden_states = self.post_self_attn_layernorm(hidden_states)
  68. hidden_states = residual + hidden_states
  69. residual = hidden_states
  70. hidden_states = self.post_attention_layernorm(hidden_states)
  71. hidden_states = self.mlp(hidden_states)
  72. hidden_states = self.post_mlp_layernorm(hidden_states)
  73. hidden_states = residual + hidden_states
  74. return hidden_states
# Self-attention module for GLM-4. It adds nothing of its own: the
# implementation is inherited unchanged from GlmAttention.
class Glm4Attention(GlmAttention):
    pass
  77. class Glm4ForCausalLM(GlmForCausalLM):
  78. def forward(
  79. self,
  80. **super_kwargs: Unpack[TransformersKwargs],
  81. ) -> Union[tuple, CausalLMOutputWithPast]:
  82. r"""
  83. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  84. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  85. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  86. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  87. Example:
  88. ```python
  89. >>> from transformers import AutoTokenizer, Glm4ForCausalLM
  90. >>> model = Glm4ForCausalLM.from_pretrained("THUDM/GLM-4-9B-0414")
  91. >>> tokenizer = AutoTokenizer.from_pretrained("THUDM/GLM-4-9B-0414")
  92. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  93. >>> inputs = tokenizer(prompt, return_tensors="pt")
  94. >>> # Generate
  95. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  96. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  97. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  98. ```"""
  99. return super().forward(**super_kwargs)
# Sequence-classification head for GLM-4. All behaviour is inherited unchanged
# from GlmForSequenceClassification.
class Glm4ForSequenceClassification(GlmForSequenceClassification):
    pass
# Token-classification head for GLM-4. All behaviour is inherited unchanged
# from GlmForTokenClassification.
class Glm4ForTokenClassification(GlmForTokenClassification):
    pass
# Public API of this module.
# The first two names are not defined in this file. That is why they carry
# `noqa: F822` (flake8: "undefined name in __all__"). Presumably they are
# produced when this modular file is converted into the full modeling file;
# confirm.
__all__ = [
    "Glm4PreTrainedModel",  # noqa: F822
    "Glm4Model",  # noqa: F822
    "Glm4ForCausalLM",
    "Glm4ForSequenceClassification",
    "Glm4ForTokenClassification",
]