modular_gemma2.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587
  1. # coding=utf-8
  2. # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
  3. #
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. from typing import Callable, Optional, Union
  17. import torch
  18. import torch.nn as nn
  19. from ...activations import ACT2FN
  20. from ...cache_utils import Cache, DynamicCache
  21. from ...configuration_utils import PretrainedConfig, layer_type_validation
  22. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  23. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  24. from ...modeling_layers import GradientCheckpointingLayer
  25. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  26. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  27. from ...processing_utils import Unpack
  28. from ...utils import TransformersKwargs, logging
  29. from ...utils.deprecation import deprecate_kwarg
  30. from ..gemma.modeling_gemma import (
  31. GemmaAttention,
  32. GemmaForCausalLM,
  33. GemmaForSequenceClassification,
  34. GemmaForTokenClassification,
  35. GemmaMLP,
  36. GemmaModel,
  37. GemmaPreTrainedModel,
  38. GemmaRMSNorm,
  39. GemmaRotaryEmbedding,
  40. apply_rotary_pos_emb,
  41. repeat_kv,
  42. )
# Module-level logger, obtained from the library's logging helper and keyed by this module's name.
logger = logging.get_logger(__name__)
  44. class Gemma2Config(PretrainedConfig):
  45. r"""
  46. This is the configuration class to store the configuration of a [`Gemma2Model`]. It is used to instantiate an Gemma2
  47. model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
  48. defaults will yield a similar configuration to that of the Gemma2-7B.
  49. e.g. [google/gemma2-7b](https://huggingface.co/google/gemma2-7b)
  50. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  51. documentation from [`PretrainedConfig`] for more information.
  52. Args:
  53. vocab_size (`int`, *optional*, defaults to 256000):
  54. Vocabulary size of the Gemma2 model. Defines the number of different tokens that can be represented by the
  55. `inputs_ids` passed when calling [`Gemma2Model`]
  56. hidden_size (`int`, *optional*, defaults to 2304):
  57. Dimension of the hidden representations.
  58. intermediate_size (`int`, *optional*, defaults to 9216):
  59. Dimension of the MLP representations.
  60. num_hidden_layers (`int`, *optional*, defaults to 26):
  61. Number of hidden layers in the Transformer decoder.
  62. num_attention_heads (`int`, *optional*, defaults to 8):
  63. Number of attention heads for each attention layer in the Transformer decoder.
  64. num_key_value_heads (`int`, *optional*, defaults to 4):
  65. This is the number of key_value heads that should be used to implement Grouped Query Attention. If
  66. `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
  67. `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
  68. converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
  69. by meanpooling all the original heads within that group. For more details, check out [this
  70. paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
  71. `num_attention_heads`.
  72. head_dim (`int`, *optional*, defaults to 256):
  73. The attention head dimension.
  74. hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
  75. The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
  76. if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
  77. max_position_embeddings (`int`, *optional*, defaults to 8192):
  78. The maximum sequence length that this model might ever be used with.
  79. initializer_range (`float`, *optional*, defaults to 0.02):
  80. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  81. rms_norm_eps (`float`, *optional*, defaults to 1e-06):
  82. The epsilon used by the rms normalization layers.
  83. use_cache (`bool`, *optional*, defaults to `True`):
  84. Whether or not the model should return the last key/values attentions (not used by all models). Only
  85. relevant if `config.is_decoder=True`.
  86. pad_token_id (`int`, *optional*, defaults to 0):
  87. Padding token id.
  88. eos_token_id (`int`, *optional*, defaults to 1):
  89. End of stream token id.
  90. bos_token_id (`int`, *optional*, defaults to 2):
  91. Beginning of stream token id.
  92. tie_word_embeddings (`bool`, *optional*, defaults to `True`):
  93. Whether to tie weight embeddings
  94. rope_theta (`float`, *optional*, defaults to 10000.0):
  95. The base period of the RoPE embeddings.
  96. attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
  97. Whether to use a bias in the query, key, value and output projection layers during self-attention.
  98. attention_dropout (`float`, *optional*, defaults to 0.0):
  99. The dropout ratio for the attention probabilities.
  100. query_pre_attn_scalar (`float`, *optional*, defaults to 256):
  101. scaling factor used on the attention scores
  102. sliding_window (`int`, *optional*, defaults to 4096):
  103. in Gemma2, every other layer uses sliding window attention. This is the size of the sliding window.
  104. layer_types (`list`, *optional*):
  105. Attention pattern for each layer.
  106. final_logit_softcapping (`float`, *optional*, defaults to 30.0):
  107. scaling factor when applying tanh softcapping on the logits.
  108. attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
  109. scaling factor when applying tanh softcapping on the attention scores.
  110. ```python
  111. >>> from transformers import Gemma2Model, Gemma2Config
  112. >>> # Initializing a Gemma2 gemma2-7b style configuration
  113. >>> configuration = Gemma2Config()
  114. >>> # Initializing a model from the gemma2-7b style configuration
  115. >>> model = Gemma2Model(configuration)
  116. >>> # Accessing the model configuration
  117. >>> configuration = model.config
  118. ```"""
  119. model_type = "gemma2"
  120. keys_to_ignore_at_inference = ["past_key_values"]
  121. base_model_tp_plan = {
  122. "layers.*.self_attn.q_proj": "colwise",
  123. "layers.*.self_attn.k_proj": "colwise",
  124. "layers.*.self_attn.v_proj": "colwise",
  125. "layers.*.self_attn.o_proj": "rowwise",
  126. "layers.*.mlp.gate_proj": "colwise",
  127. "layers.*.mlp.up_proj": "colwise",
  128. "layers.*.mlp.down_proj": "rowwise",
  129. }
  130. base_model_pp_plan = {
  131. "embed_tokens": (["input_ids"], ["inputs_embeds"]),
  132. "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
  133. "norm": (["hidden_states"], ["hidden_states"]),
  134. }
  135. def __init__(
  136. self,
  137. vocab_size=256000,
  138. hidden_size=2304,
  139. intermediate_size=9216,
  140. num_hidden_layers=26,
  141. num_attention_heads=8,
  142. num_key_value_heads=4,
  143. head_dim=256,
  144. hidden_activation="gelu_pytorch_tanh",
  145. max_position_embeddings=8192,
  146. initializer_range=0.02,
  147. rms_norm_eps=1e-6,
  148. use_cache=True,
  149. pad_token_id=0,
  150. eos_token_id=1,
  151. bos_token_id=2,
  152. tie_word_embeddings=True,
  153. rope_theta=10000.0,
  154. attention_bias=False,
  155. attention_dropout=0.0,
  156. query_pre_attn_scalar=256,
  157. sliding_window=4096,
  158. layer_types=None,
  159. final_logit_softcapping=30.0,
  160. attn_logit_softcapping=50.0,
  161. **kwargs,
  162. ):
  163. super().__init__(
  164. pad_token_id=pad_token_id,
  165. bos_token_id=bos_token_id,
  166. eos_token_id=eos_token_id,
  167. tie_word_embeddings=tie_word_embeddings,
  168. **kwargs,
  169. )
  170. self.vocab_size = vocab_size
  171. self.max_position_embeddings = max_position_embeddings
  172. self.hidden_size = hidden_size
  173. self.intermediate_size = intermediate_size
  174. self.num_hidden_layers = num_hidden_layers
  175. self.num_attention_heads = num_attention_heads
  176. self.head_dim = head_dim
  177. self.num_key_value_heads = num_key_value_heads
  178. self.initializer_range = initializer_range
  179. self.rms_norm_eps = rms_norm_eps
  180. self.use_cache = use_cache
  181. self.rope_theta = rope_theta
  182. self.attention_bias = attention_bias
  183. self.attention_dropout = attention_dropout
  184. self.hidden_activation = hidden_activation
  185. self.query_pre_attn_scalar = query_pre_attn_scalar
  186. self.sliding_window = sliding_window
  187. self.final_logit_softcapping = final_logit_softcapping
  188. self.attn_logit_softcapping = attn_logit_softcapping
  189. self.layer_types = layer_types
  190. if self.layer_types is None:
  191. self.layer_types = [
  192. "sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers)
  193. ]
  194. layer_type_validation(self.layer_types, self.num_hidden_layers)
  195. class Gemma2RMSNorm(GemmaRMSNorm):
  196. pass
  197. class Gemma2MLP(GemmaMLP):
  198. def __init__(self, config):
  199. super().__init__(config)
  200. self.act_fn = ACT2FN[config.hidden_activation]
  201. class Gemma2RotaryEmbedding(GemmaRotaryEmbedding):
  202. pass
  203. def eager_attention_forward(
  204. module: nn.Module,
  205. query: torch.Tensor,
  206. key: torch.Tensor,
  207. value: torch.Tensor,
  208. attention_mask: Optional[torch.Tensor],
  209. dropout: float = 0.0,
  210. scaling: Optional[float] = None,
  211. softcap: Optional[float] = None,
  212. **kwargs,
  213. ) -> tuple[torch.Tensor, torch.Tensor]:
  214. if scaling is None:
  215. scaling = module.head_dim**-0.5
  216. key_states = repeat_kv(key, module.num_key_value_groups)
  217. value_states = repeat_kv(value, module.num_key_value_groups)
  218. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  219. if softcap is not None:
  220. attn_weights = attn_weights / softcap
  221. attn_weights = torch.tanh(attn_weights)
  222. attn_weights = attn_weights * softcap
  223. if attention_mask is not None: # no matter the length, we just slice it
  224. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  225. attn_weights = attn_weights + causal_mask
  226. # upcast attention to fp32
  227. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  228. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  229. attn_output = torch.matmul(attn_weights, value_states)
  230. attn_output = attn_output.transpose(1, 2).contiguous()
  231. return attn_output, attn_weights
  232. class Gemma2Attention(GemmaAttention):
  233. def __init__(self, config: Gemma2Config, layer_idx: int):
  234. super().__init__(config, layer_idx)
  235. self.attn_logit_softcapping = self.config.attn_logit_softcapping
  236. self.attention_dropout = self.config.attention_dropout
  237. self.is_causal = True
  238. self.scaling = config.query_pre_attn_scalar**-0.5
  239. self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
  240. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  241. def forward(
  242. self,
  243. hidden_states: torch.Tensor,
  244. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  245. attention_mask: Optional[torch.Tensor],
  246. past_key_values: Optional[Cache] = None,
  247. cache_position: Optional[torch.LongTensor] = None,
  248. **kwargs: Unpack[FlashAttentionKwargs],
  249. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  250. input_shape = hidden_states.shape[:-1]
  251. hidden_shape = (*input_shape, -1, self.head_dim)
  252. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  253. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  254. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  255. cos, sin = position_embeddings
  256. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  257. if past_key_values is not None:
  258. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  259. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  260. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  261. attention_interface: Callable = eager_attention_forward
  262. if self.config._attn_implementation != "eager":
  263. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  264. attn_output, attn_weights = attention_interface(
  265. self,
  266. query_states,
  267. key_states,
  268. value_states,
  269. attention_mask,
  270. dropout=self.attention_dropout if self.training else 0.0,
  271. scaling=self.scaling,
  272. sliding_window=self.sliding_window,
  273. softcap=self.attn_logit_softcapping,
  274. **kwargs,
  275. )
  276. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  277. attn_output = self.o_proj(attn_output)
  278. return attn_output, attn_weights
  279. class Gemma2DecoderLayer(GradientCheckpointingLayer):
  280. def __init__(self, config: Gemma2Config, layer_idx: int):
  281. super().__init__()
  282. self.hidden_size = config.hidden_size
  283. self.config = config
  284. self.attention_type = config.layer_types[layer_idx]
  285. self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx)
  286. self.mlp = Gemma2MLP(config)
  287. self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  288. self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  289. self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  290. self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  291. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  292. def forward(
  293. self,
  294. hidden_states: torch.Tensor,
  295. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  296. attention_mask: Optional[torch.Tensor] = None,
  297. position_ids: Optional[torch.LongTensor] = None,
  298. past_key_values: Optional[Cache] = None,
  299. output_attentions: Optional[bool] = False,
  300. use_cache: Optional[bool] = False,
  301. cache_position: Optional[torch.LongTensor] = None,
  302. **kwargs,
  303. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  304. residual = hidden_states
  305. hidden_states = self.input_layernorm(hidden_states)
  306. # Self Attention
  307. hidden_states, self_attn_weights = self.self_attn(
  308. hidden_states=hidden_states,
  309. position_embeddings=position_embeddings,
  310. attention_mask=attention_mask,
  311. position_ids=position_ids,
  312. past_key_values=past_key_values,
  313. output_attentions=output_attentions,
  314. use_cache=use_cache,
  315. cache_position=cache_position,
  316. **kwargs,
  317. )
  318. hidden_states = self.post_attention_layernorm(hidden_states)
  319. hidden_states = residual + hidden_states
  320. residual = hidden_states
  321. hidden_states = self.pre_feedforward_layernorm(hidden_states)
  322. hidden_states = self.mlp(hidden_states)
  323. hidden_states = self.post_feedforward_layernorm(hidden_states)
  324. hidden_states = residual + hidden_states
  325. outputs = (hidden_states,)
  326. if output_attentions:
  327. outputs += (self_attn_weights,)
  328. return outputs
  329. class Gemma2PreTrainedModel(GemmaPreTrainedModel):
  330. pass
  331. class Gemma2Model(GemmaModel):
  332. def __init__(self, config: Gemma2Config):
  333. super().__init__(config)
  334. self.layers = nn.ModuleList(
  335. [Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  336. )
  337. def forward(
  338. self,
  339. input_ids: Optional[torch.LongTensor] = None,
  340. attention_mask: Optional[torch.Tensor] = None,
  341. position_ids: Optional[torch.LongTensor] = None,
  342. past_key_values: Optional[Cache] = None,
  343. inputs_embeds: Optional[torch.FloatTensor] = None,
  344. use_cache: Optional[bool] = None,
  345. output_attentions: Optional[bool] = None,
  346. output_hidden_states: Optional[bool] = None,
  347. cache_position: Optional[torch.LongTensor] = None,
  348. **kwargs: Unpack[TransformersKwargs],
  349. ) -> BaseModelOutputWithPast:
  350. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  351. output_hidden_states = (
  352. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  353. )
  354. use_cache = use_cache if use_cache is not None else self.config.use_cache
  355. if (input_ids is None) ^ (inputs_embeds is not None):
  356. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  357. if self.gradient_checkpointing and self.training and use_cache:
  358. logger.warning_once(
  359. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
  360. )
  361. use_cache = False
  362. if inputs_embeds is None:
  363. inputs_embeds = self.embed_tokens(input_ids)
  364. if use_cache and past_key_values is None and not self.training:
  365. past_key_values = DynamicCache(config=self.config)
  366. if cache_position is None:
  367. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  368. cache_position = torch.arange(
  369. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  370. )
  371. if position_ids is None:
  372. position_ids = cache_position.unsqueeze(0)
  373. # It may already have been prepared by e.g. `generate`
  374. if not isinstance(causal_mask_mapping := attention_mask, dict):
  375. # Prepare mask arguments
  376. mask_kwargs = {
  377. "config": self.config,
  378. "input_embeds": inputs_embeds,
  379. "attention_mask": attention_mask,
  380. "cache_position": cache_position,
  381. "past_key_values": past_key_values,
  382. "position_ids": position_ids,
  383. }
  384. # Create the masks
  385. causal_mask_mapping = {
  386. "full_attention": create_causal_mask(**mask_kwargs),
  387. "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
  388. }
  389. # embed positions
  390. hidden_states = inputs_embeds
  391. # create position embeddings to be shared across the decoder layers
  392. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  393. # normalized
  394. # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
  395. # See https://github.com/huggingface/transformers/pull/29402
  396. normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
  397. hidden_states = hidden_states * normalizer
  398. # decoder layers
  399. all_hidden_states = () if output_hidden_states else None
  400. all_self_attns = () if output_attentions else None
  401. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  402. if output_hidden_states:
  403. all_hidden_states += (hidden_states,)
  404. layer_outputs = decoder_layer(
  405. hidden_states,
  406. position_embeddings=position_embeddings,
  407. attention_mask=causal_mask_mapping[decoder_layer.attention_type],
  408. position_ids=position_ids,
  409. past_key_values=past_key_values,
  410. output_attentions=output_attentions,
  411. use_cache=use_cache,
  412. cache_position=cache_position,
  413. **kwargs,
  414. )
  415. hidden_states = layer_outputs[0]
  416. if output_attentions:
  417. all_self_attns += (layer_outputs[1],)
  418. hidden_states = self.norm(hidden_states)
  419. if output_hidden_states:
  420. all_hidden_states += (hidden_states,)
  421. return BaseModelOutputWithPast(
  422. last_hidden_state=hidden_states,
  423. past_key_values=past_key_values,
  424. hidden_states=all_hidden_states,
  425. attentions=all_self_attns,
  426. )
  427. class Gemma2ForCausalLM(GemmaForCausalLM):
  428. def __init__(self, config):
  429. super().__init__(config)
  430. self.model = Gemma2Model(config)
  431. self.post_init()
  432. def forward(
  433. self,
  434. input_ids: Optional[torch.LongTensor] = None,
  435. attention_mask: Optional[torch.Tensor] = None,
  436. position_ids: Optional[torch.LongTensor] = None,
  437. past_key_values: Optional[Cache] = None,
  438. inputs_embeds: Optional[torch.FloatTensor] = None,
  439. labels: Optional[torch.LongTensor] = None,
  440. use_cache: Optional[bool] = None,
  441. output_attentions: Optional[bool] = None,
  442. output_hidden_states: Optional[bool] = None,
  443. cache_position: Optional[torch.LongTensor] = None,
  444. logits_to_keep: Union[int, torch.Tensor] = 0,
  445. **kwargs,
  446. ) -> CausalLMOutputWithPast:
  447. r"""
  448. Example:
  449. ```python
  450. >>> from transformers import AutoTokenizer, Gemma2ForCausalLM
  451. >>> model = Gemma2ForCausalLM.from_pretrained("google/gemma-2-9b")
  452. >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
  453. >>> prompt = "What is your favorite condiment?"
  454. >>> inputs = tokenizer(prompt, return_tensors="pt")
  455. >>> # Generate
  456. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  457. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  458. "What is your favorite condiment?"
  459. ```"""
  460. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  461. output_hidden_states = (
  462. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  463. )
  464. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  465. outputs: BaseModelOutputWithPast = self.model(
  466. input_ids=input_ids,
  467. attention_mask=attention_mask,
  468. position_ids=position_ids,
  469. past_key_values=past_key_values,
  470. inputs_embeds=inputs_embeds,
  471. use_cache=use_cache,
  472. output_attentions=output_attentions,
  473. output_hidden_states=output_hidden_states,
  474. cache_position=cache_position,
  475. **kwargs,
  476. )
  477. hidden_states = outputs.last_hidden_state
  478. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  479. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  480. logits = self.lm_head(hidden_states[:, slice_indices, :])
  481. if self.config.final_logit_softcapping is not None:
  482. logits = logits / self.config.final_logit_softcapping
  483. logits = torch.tanh(logits)
  484. logits = logits * self.config.final_logit_softcapping
  485. loss = None
  486. if labels is not None:
  487. loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
  488. return CausalLMOutputWithPast(
  489. loss=loss,
  490. logits=logits,
  491. past_key_values=outputs.past_key_values,
  492. hidden_states=outputs.hidden_states,
  493. attentions=outputs.attentions,
  494. )
  495. class Gemma2ForSequenceClassification(GemmaForSequenceClassification):
  496. pass
  497. class Gemma2ForTokenClassification(GemmaForTokenClassification):
  498. pass
# Public names of this module (what `from ... import *` exposes).
__all__ = [
    "Gemma2Config",
    "Gemma2ForCausalLM",
    "Gemma2Model",
    "Gemma2PreTrainedModel",
    "Gemma2ForSequenceClassification",
    "Gemma2ForTokenClassification",
]