modular_qwen2.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. from typing import Callable, Optional
  2. import torch
  3. from packaging import version
  4. from torch import nn
  5. from ...cache_utils import Cache, DynamicCache
  6. from ...integrations import use_kernel_forward_from_hub
  7. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  8. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  9. from ...modeling_outputs import (
  10. BaseModelOutputWithPast,
  11. )
  12. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  13. from ...processing_utils import Unpack
  14. from ...utils import TransformersKwargs, auto_docstring, logging
  15. from ...utils.deprecation import deprecate_kwarg
  16. from ...utils.generic import check_model_inputs
  17. from ...utils.import_utils import get_torch_version
  18. from ..llama.modeling_llama import (
  19. LlamaAttention,
  20. LlamaDecoderLayer,
  21. LlamaForCausalLM,
  22. LlamaForQuestionAnswering,
  23. LlamaForSequenceClassification,
  24. LlamaForTokenClassification,
  25. LlamaMLP,
  26. LlamaPreTrainedModel,
  27. apply_rotary_pos_emb,
  28. eager_attention_forward,
  29. )
  30. from ..mistral.modeling_mistral import MistralModel
  31. from .configuration_qwen2 import Qwen2Config
  # Module-level logger, obtained from the library's `logging` helper and named after this module.
  32. logger = logging.get_logger(__name__)
  class Qwen2MLP(LlamaMLP):
      """Qwen2 feed-forward block derived from `LlamaMLP`.

      After the parent constructor has run, `gate_proj`, `up_proj` and `down_proj`
      are assigned as bias-free `nn.Linear` layers.
      """

      def __init__(self, config):
          """Initialise the parent MLP from `config`, then set the three projections.

          The layer widths are read from `self.hidden_size` and
          `self.intermediate_size`, which are expected to be available once the
          parent constructor returns.
          """
          super().__init__(config)
          hidden, intermediate = self.hidden_size, self.intermediate_size
          # Keep the gate -> up -> down creation order so parameter initialisation
          # draws from the RNG in the same sequence as before.
          self.gate_proj = nn.Linear(hidden, intermediate, bias=False)
          self.up_proj = nn.Linear(hidden, intermediate, bias=False)
          self.down_proj = nn.Linear(intermediate, hidden, bias=False)
  class Qwen2Attention(LlamaAttention):
      """Qwen2 attention layer derived from `LlamaAttention`.

      The query/key/value projections carry a bias, the output projection does
      not, and a per-layer `sliding_window` value is forwarded to the attention
      kernel.
      """

      def __init__(self, config: Qwen2Config, layer_idx: int):
          """Set up the projections and the sliding-window setting for one layer.

          Args:
              config: Model configuration; `hidden_size`, `num_attention_heads`,
                  `num_key_value_heads`, `layer_types` and (for sliding layers)
                  `sliding_window` are read from it.
              layer_idx: Index of this layer, used to look up its entry in
                  `config.layer_types`.
          """
          super().__init__(config, layer_idx)
          model_width = config.hidden_size
          query_width = config.num_attention_heads * self.head_dim
          kv_width = config.num_key_value_heads * self.head_dim

          self.q_proj = nn.Linear(model_width, query_width, bias=True)
          self.k_proj = nn.Linear(model_width, kv_width, bias=True)
          self.v_proj = nn.Linear(model_width, kv_width, bias=True)
          self.o_proj = nn.Linear(query_width, model_width, bias=False)

          # Only layers tagged "sliding_attention" get a window; every other layer uses None.
          uses_sliding_window = config.layer_types[layer_idx] == "sliding_attention"
          self.sliding_window = config.sliding_window if uses_sliding_window else None

      @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
      def forward(
          self,
          hidden_states: torch.Tensor,
          position_embeddings: tuple[torch.Tensor, torch.Tensor],
          attention_mask: Optional[torch.Tensor],
          past_key_values: Optional[Cache] = None,
          cache_position: Optional[torch.LongTensor] = None,
          **kwargs: Unpack[FlashAttentionKwargs],
      ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
          """Run attention over `hidden_states`.

          Args:
              hidden_states: Input tensor; its last dimension is fed to the projections.
              position_embeddings: `(cos, sin)` pair applied to queries and keys.
              attention_mask: Mask handed unchanged to the attention implementation.
              past_key_values: Optional cache; when given, it is updated with this
                  layer's keys/values and its returned tensors are used instead.
              cache_position: Forwarded to the cache update.
              **kwargs: Extra arguments passed through to the attention implementation.

          Returns:
              The output-projected attention result and whatever the attention
              implementation returned as its second value (attention weights).
          """
          leading_dims = hidden_states.shape[:-1]
          # Split the projected features into (..., heads, head_dim), then swap dims 1 and 2.
          per_head_shape = (*leading_dims, -1, self.head_dim)

          queries = self.q_proj(hidden_states)
          queries = queries.view(per_head_shape).transpose(1, 2)
          keys = self.k_proj(hidden_states)
          keys = keys.view(per_head_shape).transpose(1, 2)
          values = self.v_proj(hidden_states)
          values = values.view(per_head_shape).transpose(1, 2)

          cos, sin = position_embeddings
          queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin)

          if past_key_values is not None:
              # sin and cos are specific to RoPE models; cache_position needed for the static cache
              keys, values = past_key_values.update(
                  keys,
                  values,
                  self.layer_idx,
                  {"sin": sin, "cos": cos, "cache_position": cache_position},
              )

          # Use the eager kernel unless the config names another registered implementation.
          implementation = self.config._attn_implementation
          attention_interface: Callable = (
              eager_attention_forward if implementation == "eager" else ALL_ATTENTION_FUNCTIONS[implementation]
          )

          attn_output, attn_weights = attention_interface(
              self,
              queries,
              keys,
              values,
              attention_mask,
              dropout=self.attention_dropout if self.training else 0.0,
              scaling=self.scaling,
              sliding_window=self.sliding_window,  # main diff with Llama
              **kwargs,
          )

          # Merge the head dimensions back into one feature axis before the output projection.
          merged = attn_output.reshape(*leading_dims, -1).contiguous()
          return self.o_proj(merged), attn_weights
  85. if version.parse(get_torch_version()) >= version.parse("2.3.0"):
  86. class Qwen2RMSNorm(nn.RMSNorm):
  87. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  88. super().__init__(normalized_shape=hidden_size, eps=eps, elementwise_affine=True)
  89. else:
  90. @use_kernel_forward_from_hub("RMSNorm")
  91. class Qwen2RMSNorm(nn.Module):
  92. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  93. """
  94. Qwen2RMSNorm is equivalent to T5LayerNorm
  95. """
  96. super().__init__()
  97. self.weight = nn.Parameter(torch.ones(hidden_size))
  98. self.variance_epsilon = eps
  99. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  100. input_dtype = hidden_states.dtype
  101. hidden_states = hidden_states.to(torch.float32)
  102. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  103. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  104. return self.weight * hidden_states.to(input_dtype)
  105. def extra_repr(self):
  106. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  class Qwen2DecoderLayer(LlamaDecoderLayer):
      """Qwen2 decoder layer: a `LlamaDecoderLayer` that also records its attention type."""

      def __init__(self, config: Qwen2Config, layer_idx: int):
          """Build the parent layer and store this layer's entry from `config.layer_types`.

          Args:
              config: Model configuration providing `layer_types`.
              layer_idx: Position of this layer; used as the index into `layer_types`.
          """
          super().__init__(config=config, layer_idx=layer_idx)
          layer_kinds = config.layer_types
          # Stored so callers can select a per-layer mask by this key.
          self.attention_type = layer_kinds[layer_idx]
  # Qwen2 pre-trained base class: subclasses `LlamaPreTrainedModel` and overrides nothing in this file.
  111. class Qwen2PreTrainedModel(LlamaPreTrainedModel):
  112. pass
  # Qwen2 decoder stack built on `MistralModel`; it records whether any layer uses sliding-window
  # attention and prepares one attention mask per attention type before running the layers.
  class Qwen2Model(MistralModel):
      def __init__(self, config: Qwen2Config):
          super().__init__(config)
          # True when at least one entry of `layer_types` is "sliding_attention".
          layer_kinds = self.config.layer_types
          self.has_sliding_layers = "sliding_attention" in layer_kinds

      @check_model_inputs()
      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          past_key_values: Optional[Cache] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          use_cache: Optional[bool] = None,
          cache_position: Optional[torch.LongTensor] = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> BaseModelOutputWithPast:
          # Exactly one of `input_ids` / `inputs_embeds` must be supplied: reject both and neither.
          if (input_ids is None) != (inputs_embeds is not None):
              raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

          if inputs_embeds is None:
              inputs_embeds = self.embed_tokens(input_ids)

          # Create a fresh cache only when caching is requested and none was passed in.
          if use_cache and past_key_values is None:
              past_key_values = DynamicCache(config=self.config)

          if cache_position is None:
              # Positions continue from the number of tokens already held by the cache (0 without one).
              start = 0 if past_key_values is None else past_key_values.get_seq_length()
              stop = start + inputs_embeds.shape[1]
              cache_position = torch.arange(start, stop, device=inputs_embeds.device)

          if position_ids is None:
              position_ids = cache_position.unsqueeze(0)

          if isinstance(attention_mask, dict):
              # It may already have been prepared by e.g. `generate`
              causal_mask_mapping = attention_mask
          else:
              # Prepare mask arguments
              mask_kwargs = dict(
                  config=self.config,
                  input_embeds=inputs_embeds,
                  attention_mask=attention_mask,
                  cache_position=cache_position,
                  past_key_values=past_key_values,
                  position_ids=position_ids,
              )
              # Create the masks
              causal_mask_mapping = {"full_attention": create_causal_mask(**mask_kwargs)}
              # The sliding window alternating layers are not always activated depending on the config
              if self.has_sliding_layers:
                  causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)

          hidden_states = inputs_embeds

          # create position embeddings to be shared across the decoder layers
          position_embeddings = self.rotary_emb(hidden_states, position_ids)

          active_layers = self.layers[: self.config.num_hidden_layers]
          for layer in active_layers:
              # Each layer receives the mask that matches its own attention type.
              layer_mask = causal_mask_mapping[layer.attention_type]
              hidden_states = layer(
                  hidden_states,
                  attention_mask=layer_mask,
                  position_ids=position_ids,
                  past_key_values=past_key_values,
                  use_cache=use_cache,
                  cache_position=cache_position,
                  position_embeddings=position_embeddings,
                  **kwargs,
              )

          hidden_states = self.norm(hidden_states)

          # The cache is only handed back when caching was requested.
          returned_cache = past_key_values if use_cache else None
          return BaseModelOutputWithPast(
              last_hidden_state=hidden_states,
              past_key_values=returned_cache,
          )
  # Qwen2 causal-LM class: subclasses `LlamaForCausalLM` and overrides nothing in this file.
  180. class Qwen2ForCausalLM(LlamaForCausalLM):
  181. pass
  # Qwen2 sequence-classification class: subclasses `LlamaForSequenceClassification` with no overrides here.
  182. class Qwen2ForSequenceClassification(LlamaForSequenceClassification):
  183. pass
  # Qwen2 token-classification class: subclasses `LlamaForTokenClassification` with no overrides here.
  184. class Qwen2ForTokenClassification(LlamaForTokenClassification):
  185. pass
  # Qwen2 question-answering class: subclasses `LlamaForQuestionAnswering` with no overrides here.
  186. class Qwen2ForQuestionAnswering(LlamaForQuestionAnswering):
  187. pass
  # Names exported by this module. Qwen2MLP, Qwen2Attention and Qwen2DecoderLayer are defined
  # above but are not listed here.
  188. __all__ = [
  189. "Qwen2PreTrainedModel",
  190. "Qwen2Model",
  191. "Qwen2ForCausalLM",
  192. "Qwen2RMSNorm",
  193. "Qwen2ForSequenceClassification",
  194. "Qwen2ForTokenClassification",
  195. "Qwen2ForQuestionAnswering",
  196. ]