modeling_glm.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/glm/modular_glm.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_glm.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2024 The GLM & ZhipuAI team and HuggingFace Inc. team. All rights reserved.
  9. #
  10. #
  11. # Licensed under the Apache License, Version 2.0 (the "License");
  12. # you may not use this file except in compliance with the License.
  13. # You may obtain a copy of the License at
  14. #
  15. # http://www.apache.org/licenses/LICENSE-2.0
  16. #
  17. # Unless required by applicable law or agreed to in writing, software
  18. # distributed under the License is distributed on an "AS IS" BASIS,
  19. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  20. # See the License for the specific language governing permissions and
  21. # limitations under the License.
  22. from typing import Callable, Optional, Union
  23. import torch
  24. import torch.nn as nn
  25. from ...activations import ACT2FN
  26. from ...cache_utils import Cache, DynamicCache
  27. from ...generation import GenerationMixin
  28. from ...integrations import use_kernel_forward_from_hub
  29. from ...masking_utils import create_causal_mask
  30. from ...modeling_layers import (
  31. GenericForSequenceClassification,
  32. GenericForTokenClassification,
  33. GradientCheckpointingLayer,
  34. )
  35. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  36. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  37. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  38. from ...processing_utils import Unpack
  39. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  40. from ...utils.deprecation import deprecate_kwarg
  41. from ...utils.generic import check_model_inputs
  42. from .configuration_glm import GlmConfig
  43. class GlmMLP(nn.Module):
  44. def __init__(self, config):
  45. super().__init__()
  46. self.config = config
  47. self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
  48. self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
  49. self.activation_fn = ACT2FN[config.hidden_act]
  50. def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
  51. up_states = self.gate_up_proj(hidden_states)
  52. gate, up_states = up_states.chunk(2, dim=-1)
  53. up_states = up_states * self.activation_fn(gate)
  54. return self.down_proj(up_states)
  55. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  56. """
  57. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  58. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  59. """
  60. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  61. if n_rep == 1:
  62. return hidden_states
  63. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  64. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  65. def eager_attention_forward(
  66. module: nn.Module,
  67. query: torch.Tensor,
  68. key: torch.Tensor,
  69. value: torch.Tensor,
  70. attention_mask: Optional[torch.Tensor],
  71. scaling: float,
  72. dropout: float = 0.0,
  73. **kwargs: Unpack[TransformersKwargs],
  74. ):
  75. key_states = repeat_kv(key, module.num_key_value_groups)
  76. value_states = repeat_kv(value, module.num_key_value_groups)
  77. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  78. if attention_mask is not None:
  79. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  80. attn_weights = attn_weights + causal_mask
  81. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  82. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  83. attn_output = torch.matmul(attn_weights, value_states)
  84. attn_output = attn_output.transpose(1, 2).contiguous()
  85. return attn_output, attn_weights
  86. def rotate_half(x):
  87. """Rotates half the hidden dims of the input."""
  88. x1 = x[..., 0::2]
  89. x2 = x[..., 1::2]
  90. return torch.stack((-x2, x1), dim=-1).flatten(-2)
  91. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  92. """Applies Rotary Position Embedding to the query and key tensors.
  93. Args:
  94. q (`torch.Tensor`): The query tensor.
  95. k (`torch.Tensor`): The key tensor.
  96. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  97. sin (`torch.Tensor`): The sine part of the rotary embedding.
  98. position_ids (`torch.Tensor`, *optional*):
  99. Deprecated and unused.
  100. unsqueeze_dim (`int`, *optional*, defaults to 1):
  101. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  102. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  103. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  104. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  105. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  106. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  107. Returns:
  108. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  109. """
  110. cos = cos.unsqueeze(unsqueeze_dim)
  111. sin = sin.unsqueeze(unsqueeze_dim)
  112. # Interleave them instead of usual shape
  113. cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1)
  114. sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1)
  115. # Keep half or full tensor for later concatenation
  116. rotary_dim = cos.shape[-1]
  117. q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
  118. k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
  119. # Apply rotary embeddings on the first half or full tensor
  120. q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
  121. k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
  122. # Concatenate back to full shape
  123. q_embed = torch.cat([q_embed, q_pass], dim=-1)
  124. k_embed = torch.cat([k_embed, k_pass], dim=-1)
  125. return q_embed, k_embed
  126. class GlmAttention(nn.Module):
  127. """Multi-headed attention from 'Attention Is All You Need' paper"""
  128. def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None):
  129. super().__init__()
  130. self.config = config
  131. self.layer_idx = layer_idx
  132. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  133. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  134. self.scaling = self.head_dim**-0.5
  135. self.attention_dropout = config.attention_dropout
  136. self.is_causal = True
  137. self.q_proj = nn.Linear(
  138. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  139. )
  140. self.k_proj = nn.Linear(
  141. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  142. )
  143. self.v_proj = nn.Linear(
  144. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  145. )
  146. self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
  147. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  148. def forward(
  149. self,
  150. hidden_states: torch.Tensor,
  151. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  152. attention_mask: Optional[torch.Tensor],
  153. past_key_values: Optional[Cache] = None,
  154. cache_position: Optional[torch.LongTensor] = None,
  155. **kwargs: Unpack[TransformersKwargs],
  156. ) -> tuple[torch.Tensor, torch.Tensor]:
  157. input_shape = hidden_states.shape[:-1]
  158. hidden_shape = (*input_shape, -1, self.head_dim)
  159. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  160. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  161. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  162. cos, sin = position_embeddings
  163. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  164. if past_key_values is not None:
  165. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  166. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  167. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  168. attention_interface: Callable = eager_attention_forward
  169. if self.config._attn_implementation != "eager":
  170. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  171. attn_output, attn_weights = attention_interface(
  172. self,
  173. query_states,
  174. key_states,
  175. value_states,
  176. attention_mask,
  177. dropout=0.0 if not self.training else self.attention_dropout,
  178. scaling=self.scaling,
  179. **kwargs,
  180. )
  181. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  182. attn_output = self.o_proj(attn_output)
  183. return attn_output, attn_weights
  184. @use_kernel_forward_from_hub("RMSNorm")
  185. class GlmRMSNorm(nn.Module):
  186. def __init__(self, hidden_size, eps=1e-6):
  187. """
  188. GlmRMSNorm is equivalent to T5LayerNorm
  189. """
  190. super().__init__()
  191. self.weight = nn.Parameter(torch.ones(hidden_size))
  192. self.variance_epsilon = eps
  193. def forward(self, hidden_states):
  194. input_dtype = hidden_states.dtype
  195. hidden_states = hidden_states.to(torch.float32)
  196. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  197. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  198. return self.weight * hidden_states.to(input_dtype)
  199. def extra_repr(self):
  200. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  201. class GlmRotaryEmbedding(nn.Module):
  202. inv_freq: torch.Tensor # fix linting for `register_buffer`
  203. def __init__(self, config: GlmConfig, device=None):
  204. super().__init__()
  205. # BC: "rope_type" was originally "type"
  206. if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
  207. self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
  208. else:
  209. self.rope_type = "default"
  210. self.max_seq_len_cached = config.max_position_embeddings
  211. self.original_max_seq_len = config.max_position_embeddings
  212. self.config = config
  213. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  214. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
  215. self.register_buffer("inv_freq", inv_freq, persistent=False)
  216. self.original_inv_freq = self.inv_freq
  217. @torch.no_grad()
  218. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  219. def forward(self, x, position_ids):
  220. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  221. position_ids_expanded = position_ids[:, None, :].float()
  222. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  223. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  224. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  225. emb = torch.cat((freqs, freqs), dim=-1)
  226. cos = emb.cos() * self.attention_scaling
  227. sin = emb.sin() * self.attention_scaling
  228. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  229. class GlmDecoderLayer(GradientCheckpointingLayer):
  230. def __init__(self, config: GlmConfig, layer_idx: int):
  231. super().__init__()
  232. self.hidden_size = config.hidden_size
  233. self.self_attn = GlmAttention(config=config, layer_idx=layer_idx)
  234. self.mlp = GlmMLP(config)
  235. self.input_layernorm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  236. self.post_attention_layernorm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  237. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  238. def forward(
  239. self,
  240. hidden_states: torch.Tensor,
  241. attention_mask: Optional[torch.Tensor] = None,
  242. position_ids: Optional[torch.LongTensor] = None,
  243. past_key_values: Optional[Cache] = None,
  244. use_cache: Optional[bool] = False,
  245. cache_position: Optional[torch.LongTensor] = None,
  246. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
  247. **kwargs: Unpack[TransformersKwargs],
  248. ) -> torch.Tensor:
  249. residual = hidden_states
  250. hidden_states = self.input_layernorm(hidden_states)
  251. # Self Attention
  252. hidden_states, _ = self.self_attn(
  253. hidden_states=hidden_states,
  254. attention_mask=attention_mask,
  255. position_ids=position_ids,
  256. past_key_values=past_key_values,
  257. use_cache=use_cache,
  258. cache_position=cache_position,
  259. position_embeddings=position_embeddings,
  260. **kwargs,
  261. )
  262. hidden_states = residual + hidden_states
  263. # Fully Connected
  264. residual = hidden_states
  265. hidden_states = self.post_attention_layernorm(hidden_states)
  266. hidden_states = self.mlp(hidden_states)
  267. hidden_states = residual + hidden_states
  268. return hidden_states
  269. @auto_docstring
  270. class GlmPreTrainedModel(PreTrainedModel):
  271. config: GlmConfig
  272. base_model_prefix = "model"
  273. supports_gradient_checkpointing = True
  274. _no_split_modules = ["GlmDecoderLayer"]
  275. _skip_keys_device_placement = ["past_key_values"]
  276. _supports_flash_attn = True
  277. _supports_sdpa = True
  278. _supports_flex_attn = True
  279. _can_compile_fullgraph = True
  280. _supports_attention_backend = True
  281. _can_record_outputs = {
  282. "hidden_states": GlmDecoderLayer,
  283. "attentions": GlmAttention,
  284. }
  285. @auto_docstring
  286. class GlmModel(GlmPreTrainedModel):
  287. def __init__(self, config: GlmConfig):
  288. super().__init__(config)
  289. self.padding_idx = config.pad_token_id
  290. self.vocab_size = config.vocab_size
  291. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  292. self.layers = nn.ModuleList(
  293. [GlmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  294. )
  295. self.norm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  296. self.rotary_emb = GlmRotaryEmbedding(config=config)
  297. self.gradient_checkpointing = False
  298. # Initialize weights and apply final processing
  299. self.post_init()
  300. @check_model_inputs()
  301. @auto_docstring
  302. def forward(
  303. self,
  304. input_ids: Optional[torch.LongTensor] = None,
  305. attention_mask: Optional[torch.Tensor] = None,
  306. position_ids: Optional[torch.LongTensor] = None,
  307. past_key_values: Optional[Cache] = None,
  308. inputs_embeds: Optional[torch.FloatTensor] = None,
  309. cache_position: Optional[torch.LongTensor] = None,
  310. use_cache: Optional[bool] = None,
  311. **kwargs: Unpack[TransformersKwargs],
  312. ) -> BaseModelOutputWithPast:
  313. if (input_ids is None) ^ (inputs_embeds is not None):
  314. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  315. if inputs_embeds is None:
  316. inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
  317. if use_cache and past_key_values is None:
  318. past_key_values = DynamicCache(config=self.config)
  319. if cache_position is None:
  320. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  321. cache_position: torch.Tensor = torch.arange(
  322. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  323. )
  324. if position_ids is None:
  325. position_ids = cache_position.unsqueeze(0)
  326. causal_mask = create_causal_mask(
  327. config=self.config,
  328. input_embeds=inputs_embeds,
  329. attention_mask=attention_mask,
  330. cache_position=cache_position,
  331. past_key_values=past_key_values,
  332. position_ids=position_ids,
  333. )
  334. hidden_states = inputs_embeds
  335. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  336. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  337. hidden_states = decoder_layer(
  338. hidden_states,
  339. attention_mask=causal_mask,
  340. position_ids=position_ids,
  341. past_key_values=past_key_values,
  342. cache_position=cache_position,
  343. position_embeddings=position_embeddings,
  344. **kwargs,
  345. )
  346. hidden_states = self.norm(hidden_states)
  347. return BaseModelOutputWithPast(
  348. last_hidden_state=hidden_states,
  349. past_key_values=past_key_values,
  350. )
  351. @auto_docstring
  352. class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin):
  353. _tied_weights_keys = ["lm_head.weight"]
  354. _tp_plan = {"lm_head": "colwise_rep"}
  355. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  356. def __init__(self, config):
  357. super().__init__(config)
  358. self.model = GlmModel(config)
  359. self.vocab_size = config.vocab_size
  360. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  361. # Initialize weights and apply final processing
  362. self.post_init()
  363. @can_return_tuple
  364. @auto_docstring
  365. def forward(
  366. self,
  367. input_ids: Optional[torch.LongTensor] = None,
  368. attention_mask: Optional[torch.Tensor] = None,
  369. position_ids: Optional[torch.LongTensor] = None,
  370. past_key_values: Optional[Cache] = None,
  371. inputs_embeds: Optional[torch.FloatTensor] = None,
  372. labels: Optional[torch.LongTensor] = None,
  373. use_cache: Optional[bool] = None,
  374. cache_position: Optional[torch.LongTensor] = None,
  375. logits_to_keep: Union[int, torch.Tensor] = 0,
  376. **kwargs: Unpack[TransformersKwargs],
  377. ) -> CausalLMOutputWithPast:
  378. r"""
  379. Example:
  380. ```python
  381. >>> from transformers import AutoTokenizer, GlmForCausalLM
  382. >>> model = GlmForCausalLM.from_pretrained("meta-glm/Glm-2-7b-hf")
  383. >>> tokenizer = AutoTokenizer.from_pretrained("meta-glm/Glm-2-7b-hf")
  384. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  385. >>> inputs = tokenizer(prompt, return_tensors="pt")
  386. >>> # Generate
  387. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  388. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  389. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  390. ```"""
  391. outputs: BaseModelOutputWithPast = self.model(
  392. input_ids=input_ids,
  393. attention_mask=attention_mask,
  394. position_ids=position_ids,
  395. past_key_values=past_key_values,
  396. inputs_embeds=inputs_embeds,
  397. use_cache=use_cache,
  398. cache_position=cache_position,
  399. **kwargs,
  400. )
  401. hidden_states = outputs.last_hidden_state
  402. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  403. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  404. logits = self.lm_head(hidden_states[:, slice_indices, :])
  405. loss = None
  406. if labels is not None:
  407. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  408. return CausalLMOutputWithPast(
  409. loss=loss,
  410. logits=logits,
  411. past_key_values=outputs.past_key_values,
  412. hidden_states=outputs.hidden_states,
  413. attentions=outputs.attentions,
  414. )
  415. class GlmForSequenceClassification(GenericForSequenceClassification, GlmPreTrainedModel):
  416. pass
  417. class GlmForTokenClassification(GenericForTokenClassification, GlmPreTrainedModel):
  418. pass
  419. __all__ = [
  420. "GlmPreTrainedModel",
  421. "GlmModel",
  422. "GlmForCausalLM",
  423. "GlmForSequenceClassification",
  424. "GlmForTokenClassification",
  425. ]