modeling_gemma.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/gemma/modular_gemma.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_gemma.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
  9. #
  10. #
  11. # Licensed under the Apache License, Version 2.0 (the "License");
  12. # you may not use this file except in compliance with the License.
  13. # You may obtain a copy of the License at
  14. #
  15. # http://www.apache.org/licenses/LICENSE-2.0
  16. #
  17. # Unless required by applicable law or agreed to in writing, software
  18. # distributed under the License is distributed on an "AS IS" BASIS,
  19. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  20. # See the License for the specific language governing permissions and
  21. # limitations under the License.
  22. from typing import Callable, Optional, Union
  23. import torch
  24. from torch import nn
  25. from ...activations import ACT2FN
  26. from ...cache_utils import Cache, DynamicCache
  27. from ...generation import GenerationMixin
  28. from ...masking_utils import create_causal_mask
  29. from ...modeling_layers import (
  30. GenericForSequenceClassification,
  31. GenericForTokenClassification,
  32. GradientCheckpointingLayer,
  33. )
  34. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  35. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  36. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  37. from ...processing_utils import Unpack
  38. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  39. from ...utils.deprecation import deprecate_kwarg
  40. from ...utils.generic import check_model_inputs
  41. from .configuration_gemma import GemmaConfig
  42. class GemmaRMSNorm(nn.Module):
  43. def __init__(self, dim: int, eps: float = 1e-6):
  44. super().__init__()
  45. self.eps = eps
  46. self.weight = nn.Parameter(torch.zeros(dim))
  47. def _norm(self, x):
  48. return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
  49. def forward(self, x):
  50. output = self._norm(x.float())
  51. # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
  52. # See https://github.com/huggingface/transformers/pull/29402
  53. output = output * (1.0 + self.weight.float())
  54. return output.type_as(x)
  55. def extra_repr(self):
  56. return f"{tuple(self.weight.shape)}, eps={self.eps}"
  57. class GemmaMLP(nn.Module):
  58. def __init__(self, config):
  59. super().__init__()
  60. self.config = config
  61. self.hidden_size = config.hidden_size
  62. self.intermediate_size = config.intermediate_size
  63. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  64. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  65. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  66. self.act_fn = ACT2FN[config.hidden_act]
  67. def forward(self, x):
  68. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  69. return down_proj
  70. class GemmaRotaryEmbedding(nn.Module):
  71. inv_freq: torch.Tensor # fix linting for `register_buffer`
  72. def __init__(self, config: GemmaConfig, device=None):
  73. super().__init__()
  74. # BC: "rope_type" was originally "type"
  75. if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
  76. self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
  77. else:
  78. self.rope_type = "default"
  79. self.max_seq_len_cached = config.max_position_embeddings
  80. self.original_max_seq_len = config.max_position_embeddings
  81. self.config = config
  82. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  83. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
  84. self.register_buffer("inv_freq", inv_freq, persistent=False)
  85. self.original_inv_freq = self.inv_freq
  86. @torch.no_grad()
  87. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  88. def forward(self, x, position_ids):
  89. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  90. position_ids_expanded = position_ids[:, None, :].float()
  91. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  92. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  93. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  94. emb = torch.cat((freqs, freqs), dim=-1)
  95. cos = emb.cos() * self.attention_scaling
  96. sin = emb.sin() * self.attention_scaling
  97. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  98. def rotate_half(x):
  99. """Rotates half the hidden dims of the input."""
  100. x1 = x[..., : x.shape[-1] // 2]
  101. x2 = x[..., x.shape[-1] // 2 :]
  102. return torch.cat((-x2, x1), dim=-1)
  103. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  104. """Applies Rotary Position Embedding to the query and key tensors.
  105. Args:
  106. q (`torch.Tensor`): The query tensor.
  107. k (`torch.Tensor`): The key tensor.
  108. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  109. sin (`torch.Tensor`): The sine part of the rotary embedding.
  110. position_ids (`torch.Tensor`, *optional*):
  111. Deprecated and unused.
  112. unsqueeze_dim (`int`, *optional*, defaults to 1):
  113. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  114. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  115. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  116. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  117. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  118. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  119. Returns:
  120. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  121. """
  122. cos = cos.unsqueeze(unsqueeze_dim)
  123. sin = sin.unsqueeze(unsqueeze_dim)
  124. q_embed = (q * cos) + (rotate_half(q) * sin)
  125. k_embed = (k * cos) + (rotate_half(k) * sin)
  126. return q_embed, k_embed
  127. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  128. """
  129. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  130. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  131. """
  132. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  133. if n_rep == 1:
  134. return hidden_states
  135. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  136. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  137. def eager_attention_forward(
  138. module: nn.Module,
  139. query: torch.Tensor,
  140. key: torch.Tensor,
  141. value: torch.Tensor,
  142. attention_mask: Optional[torch.Tensor],
  143. scaling: float,
  144. dropout: float = 0.0,
  145. **kwargs: Unpack[TransformersKwargs],
  146. ):
  147. key_states = repeat_kv(key, module.num_key_value_groups)
  148. value_states = repeat_kv(value, module.num_key_value_groups)
  149. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  150. if attention_mask is not None:
  151. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  152. attn_weights = attn_weights + causal_mask
  153. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  154. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  155. attn_output = torch.matmul(attn_weights, value_states)
  156. attn_output = attn_output.transpose(1, 2).contiguous()
  157. return attn_output, attn_weights
  158. class GemmaAttention(nn.Module):
  159. """Multi-headed attention from 'Attention Is All You Need' paper"""
  160. def __init__(self, config: GemmaConfig, layer_idx: int):
  161. super().__init__()
  162. self.config = config
  163. self.layer_idx = layer_idx
  164. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  165. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  166. self.scaling = self.head_dim**-0.5
  167. self.attention_dropout = config.attention_dropout
  168. self.is_causal = True
  169. self.q_proj = nn.Linear(
  170. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  171. )
  172. self.k_proj = nn.Linear(
  173. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  174. )
  175. self.v_proj = nn.Linear(
  176. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  177. )
  178. self.o_proj = nn.Linear(
  179. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  180. )
  181. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  182. def forward(
  183. self,
  184. hidden_states: torch.Tensor,
  185. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  186. attention_mask: Optional[torch.Tensor],
  187. past_key_values: Optional[Cache] = None,
  188. cache_position: Optional[torch.LongTensor] = None,
  189. **kwargs: Unpack[TransformersKwargs],
  190. ) -> tuple[torch.Tensor, torch.Tensor]:
  191. input_shape = hidden_states.shape[:-1]
  192. hidden_shape = (*input_shape, -1, self.head_dim)
  193. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  194. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  195. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  196. cos, sin = position_embeddings
  197. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  198. if past_key_values is not None:
  199. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  200. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  201. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  202. attention_interface: Callable = eager_attention_forward
  203. if self.config._attn_implementation != "eager":
  204. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  205. attn_output, attn_weights = attention_interface(
  206. self,
  207. query_states,
  208. key_states,
  209. value_states,
  210. attention_mask,
  211. dropout=0.0 if not self.training else self.attention_dropout,
  212. scaling=self.scaling,
  213. **kwargs,
  214. )
  215. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  216. attn_output = self.o_proj(attn_output)
  217. return attn_output, attn_weights
  218. class GemmaDecoderLayer(GradientCheckpointingLayer):
  219. def __init__(self, config: GemmaConfig, layer_idx: int):
  220. super().__init__()
  221. self.hidden_size = config.hidden_size
  222. self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx)
  223. self.mlp = GemmaMLP(config)
  224. self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  225. self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  226. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  227. def forward(
  228. self,
  229. hidden_states: torch.Tensor,
  230. attention_mask: Optional[torch.Tensor] = None,
  231. position_ids: Optional[torch.LongTensor] = None,
  232. past_key_values: Optional[Cache] = None,
  233. use_cache: Optional[bool] = False,
  234. cache_position: Optional[torch.LongTensor] = None,
  235. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
  236. **kwargs: Unpack[TransformersKwargs],
  237. ) -> torch.Tensor:
  238. residual = hidden_states
  239. hidden_states = self.input_layernorm(hidden_states)
  240. # Self Attention
  241. hidden_states, _ = self.self_attn(
  242. hidden_states=hidden_states,
  243. attention_mask=attention_mask,
  244. position_ids=position_ids,
  245. past_key_values=past_key_values,
  246. use_cache=use_cache,
  247. cache_position=cache_position,
  248. position_embeddings=position_embeddings,
  249. **kwargs,
  250. )
  251. hidden_states = residual + hidden_states
  252. # Fully Connected
  253. residual = hidden_states
  254. hidden_states = self.post_attention_layernorm(hidden_states)
  255. hidden_states = self.mlp(hidden_states)
  256. hidden_states = residual + hidden_states
  257. return hidden_states
  258. @auto_docstring
  259. class GemmaPreTrainedModel(PreTrainedModel):
  260. config: GemmaConfig
  261. base_model_prefix = "model"
  262. supports_gradient_checkpointing = True
  263. _no_split_modules = ["GemmaDecoderLayer"]
  264. _skip_keys_device_placement = ["past_key_values"]
  265. _supports_flash_attn = True
  266. _supports_sdpa = True
  267. _supports_flex_attn = True
  268. _can_compile_fullgraph = True
  269. _supports_attention_backend = True
  270. _can_record_outputs = {
  271. "hidden_states": GemmaDecoderLayer,
  272. "attentions": GemmaAttention,
  273. }
  274. def _init_weights(self, module):
  275. super()._init_weights(module)
  276. # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
  277. if "RMSNorm" in module.__class__.__name__:
  278. module.weight.data.zero_()
  279. @auto_docstring
  280. class GemmaModel(GemmaPreTrainedModel):
  281. def __init__(self, config: GemmaConfig):
  282. super().__init__(config)
  283. self.padding_idx = config.pad_token_id
  284. self.vocab_size = config.vocab_size
  285. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  286. self.layers = nn.ModuleList(
  287. [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  288. )
  289. self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  290. self.rotary_emb = GemmaRotaryEmbedding(config=config)
  291. self.gradient_checkpointing = False
  292. # Initialize weights and apply final processing
  293. self.post_init()
  294. @check_model_inputs()
  295. @auto_docstring
  296. def forward(
  297. self,
  298. input_ids: Optional[torch.LongTensor] = None,
  299. attention_mask: Optional[torch.Tensor] = None,
  300. position_ids: Optional[torch.LongTensor] = None,
  301. past_key_values: Optional[Cache] = None,
  302. inputs_embeds: Optional[torch.FloatTensor] = None,
  303. use_cache: Optional[bool] = None,
  304. cache_position: Optional[torch.LongTensor] = None,
  305. **kwargs: Unpack[TransformersKwargs],
  306. ) -> BaseModelOutputWithPast:
  307. if (input_ids is None) ^ (inputs_embeds is not None):
  308. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  309. if inputs_embeds is None:
  310. inputs_embeds = self.embed_tokens(input_ids)
  311. if use_cache and past_key_values is None:
  312. past_key_values = DynamicCache(config=self.config)
  313. if cache_position is None:
  314. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  315. cache_position = torch.arange(
  316. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  317. )
  318. if position_ids is None:
  319. position_ids = cache_position.unsqueeze(0)
  320. causal_mask = create_causal_mask(
  321. config=self.config,
  322. input_embeds=inputs_embeds,
  323. attention_mask=attention_mask,
  324. cache_position=cache_position,
  325. past_key_values=past_key_values,
  326. position_ids=position_ids,
  327. )
  328. # embed positions
  329. hidden_states = inputs_embeds
  330. # create position embeddings to be shared across the decoder layers
  331. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  332. # normalized
  333. # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
  334. # See https://github.com/huggingface/transformers/pull/29402
  335. normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
  336. hidden_states = hidden_states * normalizer
  337. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  338. hidden_states = decoder_layer(
  339. hidden_states,
  340. attention_mask=causal_mask,
  341. position_ids=position_ids,
  342. past_key_values=past_key_values,
  343. use_cache=use_cache,
  344. cache_position=cache_position,
  345. position_embeddings=position_embeddings,
  346. **kwargs,
  347. )
  348. hidden_states = self.norm(hidden_states)
  349. return BaseModelOutputWithPast(
  350. last_hidden_state=hidden_states,
  351. past_key_values=past_key_values if use_cache else None,
  352. )
  353. @auto_docstring
  354. class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
  355. _tied_weights_keys = ["lm_head.weight"]
  356. _tp_plan = {"lm_head": "colwise_rep"}
  357. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  358. def __init__(self, config):
  359. super().__init__(config)
  360. self.model = GemmaModel(config)
  361. self.vocab_size = config.vocab_size
  362. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  363. # Initialize weights and apply final processing
  364. self.post_init()
  365. @can_return_tuple
  366. @auto_docstring
  367. def forward(
  368. self,
  369. input_ids: Optional[torch.LongTensor] = None,
  370. attention_mask: Optional[torch.Tensor] = None,
  371. position_ids: Optional[torch.LongTensor] = None,
  372. past_key_values: Optional[Cache] = None,
  373. inputs_embeds: Optional[torch.FloatTensor] = None,
  374. labels: Optional[torch.LongTensor] = None,
  375. use_cache: Optional[bool] = None,
  376. cache_position: Optional[torch.LongTensor] = None,
  377. logits_to_keep: Union[int, torch.Tensor] = 0,
  378. **kwargs: Unpack[TransformersKwargs],
  379. ) -> CausalLMOutputWithPast:
  380. r"""
  381. Example:
  382. ```python
  383. >>> from transformers import AutoTokenizer, GemmaForCausalLM
  384. >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
  385. >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
  386. >>> prompt = "What is your favorite condiment?"
  387. >>> inputs = tokenizer(prompt, return_tensors="pt")
  388. >>> # Generate
  389. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  390. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  391. "What is your favorite condiment?"
  392. ```"""
  393. outputs: BaseModelOutputWithPast = self.model(
  394. input_ids=input_ids,
  395. attention_mask=attention_mask,
  396. position_ids=position_ids,
  397. past_key_values=past_key_values,
  398. inputs_embeds=inputs_embeds,
  399. use_cache=use_cache,
  400. cache_position=cache_position,
  401. **kwargs,
  402. )
  403. hidden_states = outputs.last_hidden_state
  404. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  405. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  406. logits = self.lm_head(hidden_states[:, slice_indices, :])
  407. loss = None
  408. if labels is not None:
  409. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  410. return CausalLMOutputWithPast(
  411. loss=loss,
  412. logits=logits,
  413. past_key_values=outputs.past_key_values,
  414. hidden_states=outputs.hidden_states,
  415. attentions=outputs.attentions,
  416. )
  417. class GemmaForSequenceClassification(GenericForSequenceClassification, GemmaPreTrainedModel):
  418. pass
  419. class GemmaForTokenClassification(GenericForTokenClassification, GemmaPreTrainedModel):
  420. pass
  421. __all__ = [
  422. "GemmaModel",
  423. "GemmaForCausalLM",
  424. "GemmaForSequenceClassification",
  425. "GemmaForTokenClassification",
  426. "GemmaPreTrainedModel",
  427. ]