| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587 |
- # coding=utf-8
- # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
- #
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from typing import Callable, Optional, Union
- import torch
- import torch.nn as nn
- from ...activations import ACT2FN
- from ...cache_utils import Cache, DynamicCache
- from ...configuration_utils import PretrainedConfig, layer_type_validation
- from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
- from ...modeling_flash_attention_utils import FlashAttentionKwargs
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
- from ...processing_utils import Unpack
- from ...utils import TransformersKwargs, logging
- from ...utils.deprecation import deprecate_kwarg
- from ..gemma.modeling_gemma import (
- GemmaAttention,
- GemmaForCausalLM,
- GemmaForSequenceClassification,
- GemmaForTokenClassification,
- GemmaMLP,
- GemmaModel,
- GemmaPreTrainedModel,
- GemmaRMSNorm,
- GemmaRotaryEmbedding,
- apply_rotary_pos_emb,
- repeat_kv,
- )
# Module-level logger; used below for one-time warnings (e.g. when `use_cache` is turned off under gradient checkpointing).
logger = logging.get_logger(__name__)
class Gemma2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Gemma2Model`]. It is used to instantiate a Gemma2
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of Gemma2-2B,
    e.g. [google/gemma-2-2b](https://huggingface.co/google/gemma-2-2b).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 256000):
            Vocabulary size of the Gemma2 model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`Gemma2Model`]
        hidden_size (`int`, *optional*, defaults to 2304):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 9216):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 26):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 4):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://huggingface.co/papers/2305.13245). If set to `None`, it falls back to
            `num_attention_heads`.
        head_dim (`int`, *optional*, defaults to 256):
            The attention head dimension.
        hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
            if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
        max_position_embeddings (`int`, *optional*, defaults to 8192):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id.
        eos_token_id (`int`, *optional*, defaults to 1):
            End of stream token id.
        bos_token_id (`int`, *optional*, defaults to 2):
            Beginning of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
            Whether to tie weight embeddings.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        query_pre_attn_scalar (`float`, *optional*, defaults to 256):
            Scaling factor used on the attention scores: scores are multiplied by `query_pre_attn_scalar**-0.5`.
        sliding_window (`int`, *optional*, defaults to 4096):
            Size of the sliding window used by layers whose type is `"sliding_attention"`.
        layer_types (`list`, *optional*):
            Attention pattern for each layer, one entry per layer (e.g. `"sliding_attention"` or `"full_attention"`).
            When not given, layers alternate: even-indexed layers use `"sliding_attention"` and odd-indexed layers
            use `"full_attention"`.
        final_logit_softcapping (`float`, *optional*, defaults to 30.0):
            Cap used for tanh softcapping of the output logits (`tanh(logits / cap) * cap`). Set to `None` to disable.
        attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
            Cap used for tanh softcapping of the attention scores.

    ```python
    >>> from transformers import Gemma2Model, Gemma2Config
    >>> # Initializing a Gemma2 gemma-2-2b style configuration
    >>> configuration = Gemma2Config()
    >>> # Initializing a model from the gemma-2-2b style configuration
    >>> model = Gemma2Model(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "gemma2"
    keys_to_ignore_at_inference = ["past_key_values"]
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=256000,
        hidden_size=2304,
        intermediate_size=9216,
        num_hidden_layers=26,
        num_attention_heads=8,
        num_key_value_heads=4,
        head_dim=256,
        hidden_activation="gelu_pytorch_tanh",
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=2,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        attention_bias=False,
        attention_dropout=0.0,
        query_pre_attn_scalar=256,
        sliding_window=4096,
        layer_types=None,
        final_logit_softcapping=30.0,
        attn_logit_softcapping=50.0,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        # Documented fallback: an explicit `None` means plain multi-head attention (one KV head per query head)
        # instead of storing `None` and failing later when the attention layers are built.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.num_key_value_heads = num_key_value_heads
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.hidden_activation = hidden_activation
        self.query_pre_attn_scalar = query_pre_attn_scalar
        self.sliding_window = sliding_window
        self.final_logit_softcapping = final_logit_softcapping
        self.attn_logit_softcapping = attn_logit_softcapping
        self.layer_types = layer_types

        if self.layer_types is None:
            # Default alternating pattern: layer 0 slides, layer 1 is global, layer 2 slides, ...
            self.layer_types = [
                "sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers)
            ]
        layer_type_validation(self.layer_types, self.num_hidden_layers)
class Gemma2RMSNorm(GemmaRMSNorm):
    """Gemma2 normalization layer; reuses `GemmaRMSNorm` without any override."""

    pass
class Gemma2MLP(GemmaMLP):
    """Gemma MLP whose activation function is looked up from `config.hidden_activation`."""

    def __init__(self, config):
        super().__init__(config)
        # Assigned after the parent constructor has run, so this activation is the one the module ends up using.
        activation_key = config.hidden_activation
        self.act_fn = ACT2FN[activation_key]
class Gemma2RotaryEmbedding(GemmaRotaryEmbedding):
    """Gemma2 rotary position embedding; reuses `GemmaRotaryEmbedding` without any override."""

    pass
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    softcap: Optional[float] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Plain-PyTorch attention with optional tanh softcapping of the scores.

    Args:
        module: the calling attention layer. Its `num_key_value_groups` and `training` attributes are read, and
            `head_dim` is read only when `scaling` is None.
        query, key, value: projected states. `key`/`value` are expanded with `repeat_kv` by
            `module.num_key_value_groups` before use (assumed layout (batch, heads, seq, head_dim) — TODO confirm).
        attention_mask: optional additive mask; its last dimension is sliced to the key length before being added.
        dropout: dropout probability on the attention probabilities (effective only while `module.training`).
        scaling: multiplier for the raw scores; defaults to `module.head_dim ** -0.5`.
        softcap: when given, scores become `tanh(scores / softcap) * softcap`, applied before the mask.
        **kwargs: extra keyword arguments are accepted and ignored.

    Returns:
        `(attn_output, attn_weights)`: the weighted values with dims 1 and 2 swapped (made contiguous), and the
        softmax probabilities cast back to `query.dtype`.
    """
    scale = module.head_dim**-0.5 if scaling is None else scaling
    n_groups = module.num_key_value_groups
    keys = repeat_kv(key, n_groups)
    values = repeat_kv(value, n_groups)

    scores = torch.matmul(query, keys.transpose(2, 3)) * scale
    if softcap is not None:
        scores = torch.tanh(scores / softcap) * softcap
    if attention_mask is not None:
        # The mask may be longer than the current key length; only the matching prefix is used.
        kv_len = keys.shape[-2]
        scores = scores + attention_mask[:, :, :, :kv_len]

    # Softmax runs in float32 for numerical stability, then returns to the working dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)
    context = torch.matmul(probs, values).transpose(1, 2).contiguous()
    return context, probs
class Gemma2Attention(GemmaAttention):
    """Gemma attention specialised for Gemma2: custom score scaling, per-layer sliding window, logit softcapping."""

    def __init__(self, config: Gemma2Config, layer_idx: int):
        super().__init__(config, layer_idx)
        self.attn_logit_softcapping = self.config.attn_logit_softcapping
        self.attention_dropout = self.config.attention_dropout
        self.is_causal = True
        # Attention scores are scaled by query_pre_attn_scalar**-0.5 (not derived from head_dim).
        self.scaling = config.query_pre_attn_scalar**-0.5
        # Only layers tagged "sliding_attention" get a window size; every other layer gets None.
        self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Project, rotate, (optionally) cache, and attend.

        Args:
            hidden_states: layer input; its last dimension is projected to per-head query/key/value states.
            position_embeddings: `(cos, sin)` pair passed to `apply_rotary_pos_emb`.
            attention_mask: mask forwarded unchanged to the selected attention implementation.
            past_key_values: optional cache; when given, the new key/value states are written to it and the
                returned (full) key/value states are used for attention.
            cache_position: forwarded to the cache update.
            **kwargs: forwarded to the attention implementation.

        Returns:
            `(attn_output, attn_weights)` — the output-projected attention result and whatever weights the attention
            implementation returned (possibly None, depending on the backend).
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            softcap=self.attn_logit_softcapping,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
class Gemma2DecoderLayer(GradientCheckpointingLayer):
    """One Gemma2 block: normed self-attention and normed MLP, each wrapped in pre- and post-norms plus a residual."""

    def __init__(self, config: Gemma2Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.config = config
        # Key into the model's mask mapping ("full_attention" / "sliding_attention").
        self.attention_type = config.layer_types[layer_idx]
        self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx)
        self.mlp = Gemma2MLP(config)
        norm_eps = config.rms_norm_eps
        self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=norm_eps)
        self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=norm_eps)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Apply the block.

        Returns:
            `(hidden_states,)`, or `(hidden_states, attn_weights)` when `output_attentions` is truthy.
        """
        # --- attention sub-block: x + post_norm(attn(pre_norm(x))) ---
        attn_input = self.input_layernorm(hidden_states)
        attn_out, self_attn_weights = self.self_attn(
            hidden_states=attn_input,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = hidden_states + self.post_attention_layernorm(attn_out)

        # --- feed-forward sub-block: x + post_norm(mlp(pre_norm(x))) ---
        ffn_out = self.post_feedforward_layernorm(self.mlp(self.pre_feedforward_layernorm(hidden_states)))
        hidden_states = hidden_states + ffn_out

        if output_attentions:
            return (hidden_states, self_attn_weights)
        return (hidden_states,)
class Gemma2PreTrainedModel(GemmaPreTrainedModel):
    """Gemma2 pretrained-model base class; reuses `GemmaPreTrainedModel` without any override."""

    pass
class Gemma2Model(GemmaModel):
    """Gemma2 decoder stack made of `Gemma2DecoderLayer`s.

    This class reads `embed_tokens`, `rotary_emb`, `norm` and `gradient_checkpointing` without defining them —
    presumably they are set up by `GemmaModel.__init__` (NOTE(review): confirm in the parent).
    """

    def __init__(self, config: Gemma2Config):
        super().__init__(config)
        # Replace the parent's decoder layers with Gemma2 layers (each knows its own attention type).
        self.layers = nn.ModuleList(
            [Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        """Run the decoder stack.

        Exactly one of `input_ids` / `inputs_embeds` must be given, otherwise `ValueError` is raised. `attention_mask`
        is either a tensor (or None), from which one full-attention and one sliding-window mask are built, or an
        already prepared dict keyed by layer type. Unset `use_cache` / `output_attentions` / `output_hidden_states`
        fall back to the config. A new `DynamicCache` is created only when caching is on, no cache was passed, and the
        module is not in training mode.

        Returns:
            `BaseModelOutputWithPast` holding the final normed hidden states, the cache (possibly None), and — when
            requested — the input of every layer followed by the final normed output, plus per-layer attention weights.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # A fresh cache is only created outside of training.
        if use_cache and past_key_values is None and not self.training:
            past_key_values = DynamicCache(config=self.config)

        # Default positions continue right after the tokens already stored in the cache.
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            # Create one mask per layer type; each decoder layer picks its own via `attention_type`.
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
            }

        # The token embeddings are the initial hidden states.
        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # Scale the embeddings by sqrt(hidden_size). The normalizer is built in the hidden states' dtype, so it is
        # rounded like the activations: e.g. in bfloat16, sqrt(3072)=55.4256 becomes 55.5.
        # See https://github.com/huggingface/transformers/pull/29402
        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
        hidden_states = hidden_states * normalizer

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                position_ids=position_ids,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
class Gemma2ForCausalLM(GemmaForCausalLM):
    """Gemma2 decoder stack topped with the language-modeling head, with optional final-logit softcapping."""

    def __init__(self, config):
        super().__init__(config)
        self.model = Gemma2Model(config)
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        r"""
        Run the decoder, project to vocabulary logits and, when `labels` are given, compute the loss.

        `logits_to_keep` selects which sequence positions get logits: an int keeps the last `logits_to_keep`
        positions (0 keeps all of them); a tensor is used directly as the index along the sequence dimension.
        When `config.final_logit_softcapping` is not None, logits become `tanh(logits / cap) * cap`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Gemma2ForCausalLM

        >>> model = Gemma2ForCausalLM.from_pretrained("google/gemma-2-9b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")

        >>> prompt = "What is your favorite condiment?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "What is your favorite condiment?"
        ```"""
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        base_out: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        # Only the requested positions go through the LM head; logits are not upcast here.
        if isinstance(logits_to_keep, int):
            keep = slice(-logits_to_keep, None)
        else:
            keep = logits_to_keep
        logits = self.lm_head(base_out.last_hidden_state[:, keep, :])

        cap = self.config.final_logit_softcapping
        if cap is not None:
            logits = torch.tanh(logits / cap) * cap

        loss = None if labels is None else self.loss_function(logits, labels, self.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=base_out.past_key_values,
            hidden_states=base_out.hidden_states,
            attentions=base_out.attentions,
        )
class Gemma2ForSequenceClassification(GemmaForSequenceClassification):
    """Gemma2 sequence-classification model; reuses `GemmaForSequenceClassification` without any override."""

    pass
class Gemma2ForTokenClassification(GemmaForTokenClassification):
    """Gemma2 token-classification model; reuses `GemmaForTokenClassification` without any override."""

    pass
# Public API of this module.
__all__ = [
    "Gemma2Config",
    "Gemma2ForCausalLM",
    "Gemma2Model",
    "Gemma2PreTrainedModel",
    "Gemma2ForSequenceClassification",
    "Gemma2ForTokenClassification",
]
|