modeling_gemma3.py 58 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/gemma3/modular_gemma3.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_gemma3.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
  9. #
  10. #
  11. # Licensed under the Apache License, Version 2.0 (the "License");
  12. # you may not use this file except in compliance with the License.
  13. # You may obtain a copy of the License at
  14. #
  15. # http://www.apache.org/licenses/LICENSE-2.0
  16. #
  17. # Unless required by applicable law or agreed to in writing, software
  18. # distributed under the License is distributed on an "AS IS" BASIS,
  19. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  20. # See the License for the specific language governing permissions and
  21. # limitations under the License.
  22. import copy
  23. from collections.abc import Callable
  24. from dataclasses import dataclass
  25. from typing import Optional, Union
  26. import torch
  27. import torch.nn as nn
  28. from ...activations import ACT2FN
  29. from ...cache_utils import Cache, DynamicCache
  30. from ...configuration_utils import PretrainedConfig
  31. from ...generation import GenerationMixin
  32. from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
  33. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  34. from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
  35. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
  36. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  37. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  38. from ...processing_utils import Unpack
  39. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging
  40. from ...utils.deprecation import deprecate_kwarg
  41. from ...utils.generic import check_model_inputs
  42. from ..auto import AutoModel
  43. from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig
  44. logger = logging.get_logger(__name__)
  45. @dataclass
  46. @auto_docstring(
  47. custom_intro="""
  48. Base class for Gemma3 outputs, with hidden states and attentions.
  49. """
  50. )
  51. class Gemma3ModelOutputWithPast(BaseModelOutputWithPast):
  52. r"""
  53. image_hidden_states (`torch.FloatTensor`, *optional*):
  54. A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
  55. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
  56. """
  57. image_hidden_states: Optional[torch.FloatTensor] = None
  58. @dataclass
  59. @auto_docstring(
  60. custom_intro="""
  61. Base class for Gemma3 causal language model (or autoregressive) outputs.
  62. """
  63. )
  64. class Gemma3CausalLMOutputWithPast(ModelOutput):
  65. r"""
  66. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  67. Language modeling loss (for next-token prediction).
  68. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
  69. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  70. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  71. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  72. Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
  73. `past_key_values` input) to speed up sequential decoding.
  74. image_hidden_states (`torch.FloatTensor`, *optional*):
  75. A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
  76. image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
  77. """
  78. loss: Optional[torch.FloatTensor] = None
  79. logits: Optional[torch.FloatTensor] = None
  80. past_key_values: Optional[Cache] = None
  81. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  82. attentions: Optional[tuple[torch.FloatTensor]] = None
  83. image_hidden_states: Optional[torch.FloatTensor] = None
  84. class Gemma3TextScaledWordEmbedding(nn.Embedding):
  85. """
  86. This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
  87. """
  88. def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
  89. super().__init__(num_embeddings, embedding_dim, padding_idx)
  90. self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
  91. def forward(self, input_ids: torch.Tensor):
  92. return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
  93. class Gemma3MLP(nn.Module):
  94. def __init__(self, config: Gemma3TextConfig):
  95. super().__init__()
  96. self.config = config
  97. self.hidden_size = config.hidden_size
  98. self.intermediate_size = config.intermediate_size
  99. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  100. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  101. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  102. self.act_fn = ACT2FN[config.hidden_activation]
  103. def forward(self, x):
  104. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  105. return down_proj
  106. class Gemma3RMSNorm(nn.Module):
  107. def __init__(self, dim: int, eps: float = 1e-6):
  108. super().__init__()
  109. self.eps = eps
  110. self.weight = nn.Parameter(torch.zeros(dim))
  111. def _norm(self, x):
  112. return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
  113. def forward(self, x):
  114. output = self._norm(x.float())
  115. # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
  116. # See https://github.com/huggingface/transformers/pull/29402
  117. output = output * (1.0 + self.weight.float())
  118. return output.type_as(x)
  119. def extra_repr(self):
  120. return f"{tuple(self.weight.shape)}, eps={self.eps}"
  121. class Gemma3RotaryEmbedding(nn.Module):
  122. inv_freq: torch.Tensor # fix linting for `register_buffer`
  123. def __init__(self, config: Gemma3TextConfig, device=None):
  124. super().__init__()
  125. # BC: "rope_type" was originally "type"
  126. if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
  127. self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
  128. else:
  129. self.rope_type = "default"
  130. self.max_seq_len_cached = config.max_position_embeddings
  131. self.original_max_seq_len = config.max_position_embeddings
  132. self.config = config
  133. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  134. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
  135. self.register_buffer("inv_freq", inv_freq, persistent=False)
  136. self.original_inv_freq = self.inv_freq
  137. @torch.no_grad()
  138. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  139. def forward(self, x, position_ids):
  140. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  141. position_ids_expanded = position_ids[:, None, :].float()
  142. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  143. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  144. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  145. emb = torch.cat((freqs, freqs), dim=-1)
  146. cos = emb.cos() * self.attention_scaling
  147. sin = emb.sin() * self.attention_scaling
  148. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  149. def rotate_half(x):
  150. """Rotates half the hidden dims of the input."""
  151. x1 = x[..., : x.shape[-1] // 2]
  152. x2 = x[..., x.shape[-1] // 2 :]
  153. return torch.cat((-x2, x1), dim=-1)
  154. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  155. """Applies Rotary Position Embedding to the query and key tensors.
  156. Args:
  157. q (`torch.Tensor`): The query tensor.
  158. k (`torch.Tensor`): The key tensor.
  159. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  160. sin (`torch.Tensor`): The sine part of the rotary embedding.
  161. position_ids (`torch.Tensor`, *optional*):
  162. Deprecated and unused.
  163. unsqueeze_dim (`int`, *optional*, defaults to 1):
  164. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  165. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  166. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  167. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  168. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  169. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  170. Returns:
  171. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  172. """
  173. cos = cos.unsqueeze(unsqueeze_dim)
  174. sin = sin.unsqueeze(unsqueeze_dim)
  175. q_embed = (q * cos) + (rotate_half(q) * sin)
  176. k_embed = (k * cos) + (rotate_half(k) * sin)
  177. return q_embed, k_embed
  178. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  179. """
  180. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  181. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  182. """
  183. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  184. if n_rep == 1:
  185. return hidden_states
  186. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  187. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  188. def eager_attention_forward(
  189. module: nn.Module,
  190. query: torch.Tensor,
  191. key: torch.Tensor,
  192. value: torch.Tensor,
  193. attention_mask: Optional[torch.Tensor],
  194. dropout: float = 0.0,
  195. scaling: Optional[float] = None,
  196. softcap: Optional[float] = None,
  197. **kwargs,
  198. ) -> tuple[torch.Tensor, torch.Tensor]:
  199. if scaling is None:
  200. scaling = module.head_dim**-0.5
  201. key_states = repeat_kv(key, module.num_key_value_groups)
  202. value_states = repeat_kv(value, module.num_key_value_groups)
  203. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  204. if softcap is not None:
  205. attn_weights = attn_weights / softcap
  206. attn_weights = torch.tanh(attn_weights)
  207. attn_weights = attn_weights * softcap
  208. if attention_mask is not None: # no matter the length, we just slice it
  209. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  210. attn_weights = attn_weights + causal_mask
  211. # upcast attention to fp32
  212. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  213. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  214. attn_output = torch.matmul(attn_weights, value_states)
  215. attn_output = attn_output.transpose(1, 2).contiguous()
  216. return attn_output, attn_weights
  217. class Gemma3Attention(nn.Module):
  218. """Multi-headed attention from 'Attention Is All You Need' paper"""
  219. def __init__(self, config: Gemma3TextConfig, layer_idx: int):
  220. super().__init__()
  221. self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
  222. self.config = config
  223. self.layer_idx = layer_idx
  224. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  225. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  226. self.scaling = config.query_pre_attn_scalar**-0.5
  227. self.attention_dropout = self.config.attention_dropout
  228. self.is_causal = not self.config.use_bidirectional_attention
  229. self.q_proj = nn.Linear(
  230. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  231. )
  232. self.k_proj = nn.Linear(
  233. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  234. )
  235. self.v_proj = nn.Linear(
  236. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  237. )
  238. self.o_proj = nn.Linear(
  239. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  240. )
  241. self.attn_logit_softcapping = self.config.attn_logit_softcapping
  242. self.sliding_window = config.sliding_window if self.is_sliding else None
  243. self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
  244. self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
  245. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  246. def forward(
  247. self,
  248. hidden_states: torch.Tensor,
  249. position_embeddings: torch.Tensor,
  250. attention_mask: Optional[torch.Tensor],
  251. past_key_values: Optional[Cache] = None,
  252. cache_position: Optional[torch.LongTensor] = None,
  253. **kwargs: Unpack[FlashAttentionKwargs],
  254. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  255. input_shape = hidden_states.shape[:-1]
  256. hidden_shape = (*input_shape, -1, self.head_dim)
  257. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  258. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  259. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  260. query_states = self.q_norm(query_states)
  261. key_states = self.k_norm(key_states)
  262. cos, sin = position_embeddings
  263. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  264. if past_key_values is not None:
  265. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  266. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  267. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  268. attention_interface: Callable = eager_attention_forward
  269. if self.config._attn_implementation != "eager":
  270. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  271. attn_output, attn_weights = attention_interface(
  272. self,
  273. query_states,
  274. key_states,
  275. value_states,
  276. attention_mask,
  277. dropout=self.attention_dropout if self.training else 0.0,
  278. scaling=self.scaling,
  279. sliding_window=self.sliding_window,
  280. **kwargs,
  281. )
  282. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  283. attn_output = self.o_proj(attn_output)
  284. return attn_output, attn_weights
  285. class Gemma3DecoderLayer(GradientCheckpointingLayer):
  286. def __init__(self, config: Gemma3TextConfig, layer_idx: int):
  287. super().__init__()
  288. self.config = config
  289. self.hidden_size = config.hidden_size
  290. self.layer_idx = layer_idx
  291. self.attention_type = config.layer_types[layer_idx]
  292. self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx)
  293. self.mlp = Gemma3MLP(config)
  294. self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  295. self.post_attention_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  296. self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  297. self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  298. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  299. def forward(
  300. self,
  301. hidden_states: torch.Tensor,
  302. position_embeddings_global: torch.Tensor,
  303. position_embeddings_local: torch.Tensor,
  304. attention_mask: Optional[torch.Tensor] = None,
  305. position_ids: Optional[torch.LongTensor] = None,
  306. past_key_values: Optional[Cache] = None,
  307. output_attentions: Optional[bool] = False,
  308. use_cache: Optional[bool] = False,
  309. cache_position: Optional[torch.LongTensor] = None,
  310. **kwargs,
  311. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  312. residual = hidden_states
  313. hidden_states = self.input_layernorm(hidden_states)
  314. # apply global RoPE to non-sliding layer only
  315. if self.self_attn.is_sliding:
  316. position_embeddings = position_embeddings_local
  317. else:
  318. position_embeddings = position_embeddings_global
  319. hidden_states, self_attn_weights = self.self_attn(
  320. hidden_states=hidden_states,
  321. position_embeddings=position_embeddings,
  322. attention_mask=attention_mask,
  323. position_ids=position_ids,
  324. past_key_values=past_key_values,
  325. output_attentions=output_attentions,
  326. use_cache=use_cache,
  327. cache_position=cache_position,
  328. **kwargs,
  329. )
  330. hidden_states = self.post_attention_layernorm(hidden_states)
  331. hidden_states = residual + hidden_states
  332. residual = hidden_states
  333. hidden_states = self.pre_feedforward_layernorm(hidden_states)
  334. hidden_states = self.mlp(hidden_states)
  335. hidden_states = self.post_feedforward_layernorm(hidden_states)
  336. hidden_states = residual + hidden_states
  337. outputs = (hidden_states,)
  338. if output_attentions:
  339. outputs += (self_attn_weights,)
  340. return outputs
  341. @auto_docstring
  342. class Gemma3PreTrainedModel(PreTrainedModel):
  343. config: Gemma3Config
  344. base_model_prefix = ""
  345. supports_gradient_checkpointing = True
  346. _no_split_modules = [
  347. "Gemma3DecoderLayer",
  348. "SiglipVisionEmbeddings",
  349. "SiglipEncoderLayer",
  350. "SiglipMultiheadAttentionPoolingHead",
  351. ]
  352. _skip_keys_device_placement = ["past_key_values"]
  353. _supports_flash_attn = True
  354. _supports_sdpa = True
  355. _supports_flex_attn = True
  356. _can_compile_fullgraph = True
  357. _supports_attention_backend = True
  358. _can_record_outputs = {
  359. "hidden_states": Gemma3DecoderLayer,
  360. "attentions": Gemma3Attention,
  361. }
  362. def _init_weights(self, module):
  363. super()._init_weights(module)
  364. if isinstance(module, Gemma3MultiModalProjector):
  365. module.mm_input_projection_weight.data.zero_()
  366. # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
  367. elif "RMSNorm" in module.__class__.__name__:
  368. module.weight.data.zero_()
  369. def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, int, int], bool]:
  370. """
  371. Enables a bidirectional mask within the sliding window.
  372. """
  373. def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
  374. """A token can attend to any other token if their absolute distance is within
  375. the (exclusive) sliding window size (distance < sliding_window)."""
  376. return abs(q_idx - kv_idx) < sliding_window
  377. return inner_mask
  378. @auto_docstring
  379. class Gemma3TextModel(Gemma3PreTrainedModel):
  380. config: Gemma3TextConfig
  381. def __init__(self, config: Gemma3TextConfig):
  382. super().__init__(config)
  383. self.padding_idx = config.pad_token_id
  384. self.vocab_size = config.vocab_size
  385. # Gemma3 downcasts the below to bfloat16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402
  386. self.embed_tokens = Gemma3TextScaledWordEmbedding(
  387. config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5
  388. )
  389. self.layers = nn.ModuleList(
  390. [Gemma3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  391. )
  392. self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  393. self.rotary_emb = Gemma3RotaryEmbedding(config=config)
  394. self.gradient_checkpointing = False
  395. # TODO: raushan fix this after RoPE refactor. For now we hack it by reassigning thetas
  396. # when we want to create a local RoPE layer. Config defaults should hold values for global RoPE
  397. config = copy.deepcopy(config)
  398. config.rope_theta = config.rope_local_base_freq
  399. config.rope_scaling = {"rope_type": "default"}
  400. self.rotary_emb_local = Gemma3RotaryEmbedding(config=config)
  401. # Initialize weights and apply final processing
  402. self.post_init()
  403. @check_model_inputs()
  404. @auto_docstring
  405. def forward(
  406. self,
  407. input_ids: Optional[torch.LongTensor] = None,
  408. attention_mask: Optional[torch.Tensor] = None,
  409. position_ids: Optional[torch.LongTensor] = None,
  410. past_key_values: Optional[Cache] = None,
  411. inputs_embeds: Optional[torch.FloatTensor] = None,
  412. use_cache: Optional[bool] = None,
  413. output_attentions: Optional[bool] = None,
  414. output_hidden_states: Optional[bool] = None,
  415. cache_position: Optional[torch.LongTensor] = None,
  416. **kwargs: Unpack[TransformersKwargs],
  417. ) -> BaseModelOutputWithPast:
  418. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  419. output_hidden_states = (
  420. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  421. )
  422. use_cache = use_cache if use_cache is not None else self.config.use_cache
  423. if (input_ids is None) ^ (inputs_embeds is not None):
  424. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  425. if self.gradient_checkpointing and self.training and use_cache:
  426. logger.warning_once(
  427. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
  428. )
  429. use_cache = False
  430. if inputs_embeds is None:
  431. inputs_embeds = self.embed_tokens(input_ids)
  432. if use_cache and past_key_values is None and not self.training:
  433. past_key_values = DynamicCache(config=self.config)
  434. if cache_position is None:
  435. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  436. cache_position = torch.arange(
  437. past_seen_tokens,
  438. past_seen_tokens + inputs_embeds.shape[1],
  439. device=inputs_embeds.device,
  440. )
  441. if position_ids is None:
  442. position_ids = cache_position.unsqueeze(0)
  443. # It may already have been prepared by e.g. `generate`
  444. if not isinstance(causal_mask_mapping := attention_mask, dict):
  445. # Prepare mask arguments
  446. mask_kwargs = {
  447. "config": self.config,
  448. "input_embeds": inputs_embeds,
  449. "attention_mask": attention_mask,
  450. "cache_position": cache_position,
  451. "past_key_values": past_key_values,
  452. "position_ids": position_ids,
  453. }
  454. sliding_mask_kwargs = mask_kwargs.copy()
  455. if self.config.use_bidirectional_attention:
  456. mask_kwargs["or_mask_function"] = lambda *args: torch.tensor(True, dtype=torch.bool)
  457. sliding_mask_kwargs["or_mask_function"] = _bidirectional_window_overlay(self.config.sliding_window)
  458. # Create the masks
  459. causal_mask_mapping = {
  460. "full_attention": create_causal_mask(**mask_kwargs),
  461. "sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs),
  462. }
  463. # embed positions
  464. hidden_states = inputs_embeds
  465. # create position embeddings to be shared across the decoder layers
  466. position_embeddings_global = self.rotary_emb(hidden_states, position_ids)
  467. position_embeddings_local = self.rotary_emb_local(hidden_states, position_ids)
  468. # decoder layers
  469. all_hidden_states = () if output_hidden_states else None
  470. all_self_attns = () if output_attentions else None
  471. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  472. if output_hidden_states:
  473. all_hidden_states += (hidden_states,)
  474. layer_outputs = decoder_layer(
  475. hidden_states,
  476. position_embeddings_global=position_embeddings_global,
  477. position_embeddings_local=position_embeddings_local,
  478. attention_mask=causal_mask_mapping[decoder_layer.attention_type],
  479. position_ids=position_ids,
  480. past_key_values=past_key_values,
  481. output_attentions=output_attentions,
  482. use_cache=use_cache,
  483. cache_position=cache_position,
  484. **kwargs,
  485. )
  486. hidden_states = layer_outputs[0]
  487. if output_attentions:
  488. all_self_attns += (layer_outputs[1],)
  489. hidden_states = self.norm(hidden_states)
  490. if output_hidden_states:
  491. all_hidden_states += (hidden_states,)
  492. return BaseModelOutputWithPast(
  493. last_hidden_state=hidden_states,
  494. past_key_values=past_key_values,
  495. hidden_states=all_hidden_states,
  496. attentions=all_self_attns,
  497. )
  498. @auto_docstring
  499. class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
  500. _tied_weights_keys = ["lm_head.weight"]
  501. _tp_plan = {"lm_head": "colwise_rep"}
  502. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  503. config: Gemma3TextConfig
  504. base_model_prefix = "language_model"
  505. def __init__(self, config: Gemma3TextConfig):
  506. super().__init__(config)
  507. self.model = Gemma3TextModel(config)
  508. self.vocab_size = config.vocab_size
  509. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  510. # Initialize weights and apply final processing
  511. self.post_init()
  512. @can_return_tuple
  513. @auto_docstring
  514. def forward(
  515. self,
  516. input_ids: Optional[torch.LongTensor] = None,
  517. attention_mask: Optional[torch.Tensor] = None,
  518. position_ids: Optional[torch.LongTensor] = None,
  519. past_key_values: Optional[Cache] = None,
  520. inputs_embeds: Optional[torch.FloatTensor] = None,
  521. labels: Optional[torch.LongTensor] = None,
  522. use_cache: Optional[bool] = None,
  523. output_attentions: Optional[bool] = None,
  524. output_hidden_states: Optional[bool] = None,
  525. cache_position: Optional[torch.LongTensor] = None,
  526. logits_to_keep: Union[int, torch.Tensor] = 0,
  527. **kwargs,
  528. ) -> CausalLMOutputWithPast:
  529. r"""
  530. Example:
  531. ```python
  532. >>> from transformers import AutoTokenizer, Gemma3ForCausalLM
  533. >>> model = Gemma3ForCausalLM.from_pretrained("google/gemma-2-9b")
  534. >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
  535. >>> prompt = "What is your favorite condiment?"
  536. >>> inputs = tokenizer(prompt, return_tensors="pt")
  537. >>> # Generate
  538. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  539. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  540. "What is your favorite condiment?"
  541. ```"""
  542. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  543. output_hidden_states = (
  544. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  545. )
  546. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  547. outputs: BaseModelOutputWithPast = self.model(
  548. input_ids=input_ids,
  549. attention_mask=attention_mask,
  550. position_ids=position_ids,
  551. past_key_values=past_key_values,
  552. inputs_embeds=inputs_embeds,
  553. use_cache=use_cache,
  554. output_attentions=output_attentions,
  555. output_hidden_states=output_hidden_states,
  556. cache_position=cache_position,
  557. **kwargs,
  558. )
  559. hidden_states = outputs.last_hidden_state
  560. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  561. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  562. logits = self.lm_head(hidden_states[:, slice_indices, :])
  563. if self.config.final_logit_softcapping is not None:
  564. logits = logits / self.config.final_logit_softcapping
  565. logits = torch.tanh(logits)
  566. logits = logits * self.config.final_logit_softcapping
  567. loss = None
  568. if labels is not None:
  569. loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
  570. return CausalLMOutputWithPast(
  571. loss=loss,
  572. logits=logits,
  573. past_key_values=outputs.past_key_values,
  574. hidden_states=outputs.hidden_states,
  575. attentions=outputs.attentions,
  576. )
  577. class Gemma3MultiModalProjector(nn.Module):
  578. def __init__(self, config: Gemma3Config):
  579. super().__init__()
  580. self.mm_input_projection_weight = nn.Parameter(
  581. torch.zeros(config.vision_config.hidden_size, config.text_config.hidden_size)
  582. )
  583. self.mm_soft_emb_norm = Gemma3RMSNorm(
  584. config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps
  585. )
  586. self.patches_per_image = int(config.vision_config.image_size // config.vision_config.patch_size)
  587. self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
  588. self.kernel_size = self.patches_per_image // self.tokens_per_side
  589. self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size)
  590. def forward(self, vision_outputs: torch.Tensor):
  591. batch_size, _, seq_length = vision_outputs.shape
  592. reshaped_vision_outputs = vision_outputs.transpose(1, 2)
  593. reshaped_vision_outputs = reshaped_vision_outputs.reshape(
  594. batch_size, seq_length, self.patches_per_image, self.patches_per_image
  595. )
  596. reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
  597. pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
  598. pooled_vision_outputs = pooled_vision_outputs.flatten(2)
  599. pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)
  600. normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)
  601. projected_vision_outputs = torch.matmul(normed_vision_outputs, self.mm_input_projection_weight)
  602. return projected_vision_outputs.type_as(vision_outputs)
  603. def token_type_ids_mask_function(
  604. token_type_ids: Optional[torch.Tensor],
  605. image_group_ids: Optional[torch.Tensor],
  606. tokens_per_image: int,
  607. ) -> Optional[Callable]:
  608. """
  609. This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
  610. not start and end indices.
  611. """
  612. # Do not return an additional mask in this case
  613. if token_type_ids is None:
  614. return None
  615. def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
  616. # If it's 1 for both query and key/value, we are in an image block
  617. # NOTE: static cache shape goes beyond input seq length, while token_type_ids.shape[1] == input seq length
  618. # Since vmap doesn't support `if statement` we workaround it with `torch.where`
  619. safe_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)
  620. token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_idx]
  621. token_type_ids_at_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0)
  622. image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_idx]
  623. image_group_ids_at_kv_idx = torch.where(kv_idx < image_group_ids.shape[1], image_group_ids_at_kv_idx, -1)
  624. is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids_at_kv_idx == 1)
  625. same_image_block = image_group_ids[batch_idx, q_idx] == image_group_ids_at_kv_idx
  626. # This is bidirectional attention whenever we are dealing with image tokens
  627. return is_image_block & same_image_block
  628. return inner_mask
@auto_docstring(
    custom_intro="""
    The Base Gemma3 model which consists of a vision backbone and a language model without language modeling head.,
    """
)
class Gemma3Model(Gemma3PreTrainedModel):
    # Maps the checkpoint key prefix `language_model.model` onto this class' `language_model` attribute.
    _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
    # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
    accepts_loss_kwargs = False

    def __init__(self, config: Gemma3Config):
        super().__init__(config)
        # Vision backbone and text decoder are both instantiated from their sub-configs via `AutoModel`.
        self.vision_tower = AutoModel.from_config(config=config.vision_config)
        self.multi_modal_projector = Gemma3MultiModalProjector(config)
        self.vocab_size = config.text_config.vocab_size

        language_model = AutoModel.from_config(config=config.text_config)
        self.language_model = language_model

        # -1 stands in when the config defines no pad token.
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
        self.post_init()

    def get_input_embeddings(self):
        """Return the input embedding module of the wrapped language model."""
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the input embedding module of the wrapped language model."""
        self.language_model.set_input_embeddings(value)

    def set_decoder(self, decoder):
        """Swap the wrapped language model for `decoder`."""
        self.language_model = decoder

    def get_decoder(self):
        """Return the wrapped language model."""
        return self.language_model

    def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Projects the last hidden state from the vision model into language model space.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
                The tensors corresponding to the input images.

        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
        """
        vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
        image_features = self.multi_modal_projector(vision_outputs)
        return image_features

    def get_placeholder_mask(
        self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
    ):
        """
        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
        equal to the length of multimodal features. If the lengths are different, an error is raised.

        Returns:
            A boolean mask expanded to the shape of `inputs_embeds`, true at every embedding element that belongs
            to an image placeholder token.

        Raises:
            ValueError: if the number of masked embedding elements differs from `image_features.numel()`.
        """
        if input_ids is None:
            # Without ids, a position counts as a placeholder when its whole embedding vector equals the
            # embedding of `image_token_id`.
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)
        else:
            special_image_mask = input_ids == self.config.image_token_id

        # Token/feature counts are only used for the error message below.
        n_image_tokens = special_image_mask.sum()
        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        n_image_features = image_features.shape[0] * image_features.shape[1]
        if inputs_embeds[special_image_mask].numel() != image_features.numel():
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
            )
        return special_image_mask

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **lm_kwargs,
    ) -> Union[tuple, Gemma3ModelOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration

        >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma32-3b-mix-224")
        >>> processor = AutoProcessor.from_pretrained("google/gemma32-3b-mix-224")

        >>> prompt = "Where is the cat standing?"
        >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs,)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Where is the cat standing?\nsnow"
        ```"""
        # Exactly one of `input_ids` / `inputs_embeds` must be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # NOTE(review): `return_dict` is resolved here but not read again in this body, and `labels` is
        # accepted but never referenced.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Replace the image token id with 0 when it is out of the embedding vocabulary, to avoid index errors
        # in the embedding lookup. The caller's `input_ids` is left untouched (a clone is modified).
        if input_ids is not None and self.config.image_token_id >= self.vocab_size:
            special_image_mask = input_ids == self.config.image_token_id
            llm_input_ids = input_ids.clone()
            llm_input_ids[special_image_mask] = 0
        else:
            llm_input_ids = input_ids

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(llm_input_ids)

        # Default cache positions continue right after whatever the cache already holds.
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        # Merge text and images
        if pixel_values is not None:
            image_features = self.get_image_features(pixel_values)
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            # The mask is computed from the ORIGINAL `input_ids` (which still contain the image token id).
            special_image_mask = self.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, image_features=image_features
            )
            # Overwrite the placeholder embeddings with the projected image features.
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        # It may already have been prepared by e.g. `generate`
        # (a dict `attention_mask` is passed through unchanged as the mask mapping).
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config.get_text_config(),
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            # NOTE: this `is_prefill` logic is not flawless, it fails when we're using a cache eagerly initialized
            # (e.g. compiled prefill) AND `pixel_values` are not provided. Determining prefill in that case requires
            # checking data values, which is not compile-compatible.
            is_prefill = (
                not use_cache
                or past_key_values is None
                or not past_key_values.is_initialized
                or pixel_values is not None
            )
            if token_type_ids is not None and is_prefill:
                # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`

                # First find where a new image block starts: 1 if image and previous not image
                # The images cannot attend to future images, but can attend to all prev images and to itself
                # bidirectionally
                is_image = (token_type_ids == 1).to(cache_position.device)
                # Shift `is_image` right by one (left-padded with 0) to compare each token with its predecessor.
                new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
                # Running count of block starts, minus one -> zero-based index of the image block per token.
                image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
                # Non-image tokens get group id -1.
                image_group_ids = torch.where(
                    is_image, image_group_ids, torch.full_like(token_type_ids, -1, device=is_image.device)
                )
                mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
                    token_type_ids.to(cache_position.device), image_group_ids, self.config.mm_tokens_per_image
                )

            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
            }

        # The language model always runs with `return_dict=True` so its outputs can be read by attribute.
        outputs = self.language_model(
            attention_mask=causal_mask_mapping,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **lm_kwargs,
        )

        return Gemma3ModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            # The cache is only exposed when the caller asked for caching.
            past_key_values=outputs.past_key_values if use_cache else None,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            # `image_features` is only bound when `pixel_values` was given, hence the guard.
            image_hidden_states=image_features if pixel_values is not None else None,
        )
  816. @auto_docstring(
  817. custom_intro="""
  818. The Base Gemma3 model which consists of a vision backbone and a language model without language modeling head.,
  819. """
  820. )
  821. class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
  822. _checkpoint_conversion_mapping = {
  823. "^language_model.model": "model.language_model",
  824. "^vision_tower": "model.vision_tower",
  825. "^multi_modal_projector": "model.multi_modal_projector",
  826. "^language_model.lm_head": "lm_head",
  827. }
  828. _tied_weights_keys = ["lm_head.weight"]
  829. # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
  830. # Fix: https://github.com/huggingface/transformers/issues/40564
  831. accepts_loss_kwargs = False
  832. def __init__(self, config: Gemma3Config):
  833. super().__init__(config)
  834. self.model = Gemma3Model(config)
  835. self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
  836. self.post_init()
  837. def get_input_embeddings(self):
  838. return self.model.get_input_embeddings()
  839. def set_input_embeddings(self, value):
  840. self.model.set_input_embeddings(value)
  841. def set_decoder(self, decoder):
  842. self.model.set_decoder(decoder)
  843. def get_decoder(self):
  844. return self.model.get_decoder()
  845. def get_image_features(self, pixel_values):
  846. return self.model.get_image_features(pixel_values)
  847. # Make modules available through conditional class for BC
  848. @property
  849. def language_model(self):
  850. return self.model.language_model
  851. @property
  852. def vision_tower(self):
  853. return self.model.vision_tower
  854. @property
  855. def multi_modal_projector(self):
  856. return self.model.multi_modal_projector
  857. @auto_docstring
  858. def forward(
  859. self,
  860. input_ids: Optional[torch.LongTensor] = None,
  861. pixel_values: Optional[torch.FloatTensor] = None,
  862. attention_mask: Optional[torch.Tensor] = None,
  863. position_ids: Optional[torch.LongTensor] = None,
  864. past_key_values: Optional[Cache] = None,
  865. token_type_ids: Optional[torch.LongTensor] = None,
  866. cache_position: Optional[torch.LongTensor] = None,
  867. inputs_embeds: Optional[torch.FloatTensor] = None,
  868. labels: Optional[torch.LongTensor] = None,
  869. use_cache: Optional[bool] = None,
  870. output_attentions: Optional[bool] = None,
  871. output_hidden_states: Optional[bool] = None,
  872. return_dict: Optional[bool] = None,
  873. logits_to_keep: Union[int, torch.Tensor] = 0,
  874. **lm_kwargs,
  875. ) -> Union[tuple, Gemma3CausalLMOutputWithPast]:
  876. r"""
  877. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  878. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  879. config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  880. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
  881. Example:
  882. ```python
  883. >>> from PIL import Image
  884. >>> import requests
  885. >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
  886. >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
  887. >>> processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
  888. >>> messages = [
  889. ... {
  890. ... "role": "system",
  891. ... "content": [
  892. ... {"type": "text", "text": "You are a helpful assistant."}
  893. ... ]
  894. ... },
  895. ... {
  896. ... "role": "user", "content": [
  897. ... {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
  898. ... {"type": "text", "text": "Where is the cat standing?"},
  899. ... ]
  900. ... },
  901. ... ]
  902. >>> inputs = processor.apply_chat_template(
  903. ... messages,
  904. ... tokenize=True,
  905. ... return_dict=True,
  906. ... return_tensors="pt",
  907. ... add_generation_prompt=True
  908. ... )
  909. >>> # Generate
  910. >>> generate_ids = model.generate(**inputs)
  911. >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  912. "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
  913. ```
  914. """
  915. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  916. output_hidden_states = (
  917. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  918. )
  919. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  920. outputs = self.model(
  921. input_ids=input_ids,
  922. pixel_values=pixel_values,
  923. token_type_ids=token_type_ids,
  924. attention_mask=attention_mask,
  925. position_ids=position_ids,
  926. past_key_values=past_key_values,
  927. inputs_embeds=inputs_embeds,
  928. use_cache=use_cache,
  929. labels=labels,
  930. output_attentions=output_attentions,
  931. output_hidden_states=output_hidden_states,
  932. return_dict=return_dict,
  933. cache_position=cache_position,
  934. **lm_kwargs,
  935. )
  936. hidden_states = outputs[0]
  937. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  938. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  939. logits = self.lm_head(hidden_states[:, slice_indices, :])
  940. loss = None
  941. if labels is not None:
  942. # Upcast to float if we need to compute the loss to avoid potential precision issues
  943. logits = logits.float()
  944. shift_logits = logits[..., :-1, :]
  945. shift_labels = labels[..., 1:]
  946. if attention_mask is not None:
  947. # we use the input attention mask to shift the logits and labels, because it is 2D.
  948. # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
  949. shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
  950. shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
  951. shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
  952. else:
  953. shift_logits = shift_logits.contiguous()
  954. shift_labels = shift_labels.contiguous()
  955. # Flatten the tokens
  956. loss_fct = nn.CrossEntropyLoss()
  957. flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
  958. flat_labels = shift_labels.view(-1).to(shift_logits.device)
  959. loss = loss_fct(flat_logits, flat_labels)
  960. if not return_dict:
  961. output = (logits,) + outputs[1:]
  962. return (loss,) + output if loss is not None else output
  963. return Gemma3CausalLMOutputWithPast(
  964. loss=loss,
  965. logits=logits,
  966. past_key_values=outputs.past_key_values,
  967. hidden_states=outputs.hidden_states,
  968. attentions=outputs.attentions,
  969. image_hidden_states=outputs.image_hidden_states,
  970. )
  971. def prepare_inputs_for_generation(
  972. self,
  973. input_ids,
  974. past_key_values=None,
  975. inputs_embeds=None,
  976. cache_position=None,
  977. position_ids=None,
  978. pixel_values=None,
  979. attention_mask=None,
  980. token_type_ids=None,
  981. use_cache=True,
  982. logits_to_keep=None,
  983. labels=None,
  984. **kwargs,
  985. ):
  986. # Overwritten -- custom `position_ids` and `pixel_values` handling
  987. model_inputs = super().prepare_inputs_for_generation(
  988. input_ids,
  989. past_key_values=past_key_values,
  990. inputs_embeds=inputs_embeds,
  991. attention_mask=attention_mask,
  992. position_ids=position_ids,
  993. cache_position=cache_position,
  994. use_cache=use_cache,
  995. logits_to_keep=logits_to_keep,
  996. token_type_ids=token_type_ids,
  997. **kwargs,
  998. )
  999. # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
  1000. # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
  1001. if cache_position[0] == 0:
  1002. model_inputs["pixel_values"] = pixel_values
  1003. return model_inputs
  1004. @staticmethod
  1005. def create_masks_for_generate(
  1006. config: PretrainedConfig,
  1007. input_embeds: torch.Tensor,
  1008. attention_mask: Optional[torch.Tensor],
  1009. cache_position: torch.Tensor,
  1010. past_key_values: Optional[Cache],
  1011. position_ids: Optional[torch.Tensor],
  1012. token_type_ids: Optional[torch.Tensor] = None,
  1013. **kwargs,
  1014. ) -> dict:
  1015. # Prepare mask arguments
  1016. mask_kwargs = {
  1017. "config": config.get_text_config(),
  1018. "input_embeds": input_embeds,
  1019. "attention_mask": attention_mask,
  1020. "cache_position": cache_position,
  1021. "past_key_values": past_key_values,
  1022. "position_ids": position_ids,
  1023. }
  1024. # Add the token type ids mask for generate as well
  1025. if token_type_ids is not None and input_embeds.shape[1] != 1:
  1026. # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
  1027. # First find where a new image block starts: 1 if image and previous not image
  1028. # The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
  1029. is_image = (token_type_ids == 1).to(cache_position.device)
  1030. new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
  1031. image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
  1032. image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
  1033. mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
  1034. token_type_ids.to(cache_position.device), image_group_ids, config.mm_tokens_per_image
  1035. )
  1036. return create_masks_for_generate(**mask_kwargs)
  1037. class Gemma3ForSequenceClassification(Gemma3PreTrainedModel):
  1038. _checkpoint_conversion_mapping = {
  1039. "^language_model.model": "model.language_model",
  1040. "^vision_tower": "model.vision_tower",
  1041. "^multi_modal_projector": "model.multi_modal_projector",
  1042. }
  1043. def __init__(self, config):
  1044. super().__init__(config)
  1045. self.num_labels = config.num_labels
  1046. self.model = Gemma3Model(config)
  1047. self.score = nn.Linear(config.text_config.hidden_size, self.num_labels, bias=False)
  1048. # Initialize weights and apply final processing
  1049. self.post_init()
  1050. def get_input_embeddings(self):
  1051. return self.model.get_input_embeddings()
  1052. def set_input_embeddings(self, value):
  1053. self.model.set_input_embeddings(value)
  1054. @can_return_tuple
  1055. @auto_docstring
  1056. def forward(
  1057. self,
  1058. input_ids: Optional[torch.LongTensor] = None,
  1059. pixel_values: Optional[torch.FloatTensor] = None,
  1060. attention_mask: Optional[torch.Tensor] = None,
  1061. position_ids: Optional[torch.LongTensor] = None,
  1062. past_key_values: Optional[Cache] = None,
  1063. inputs_embeds: Optional[torch.FloatTensor] = None,
  1064. token_type_ids: Optional[torch.LongTensor] = None,
  1065. labels: Optional[torch.LongTensor] = None,
  1066. use_cache: Optional[bool] = None,
  1067. **kwargs: Unpack[TransformersKwargs],
  1068. ) -> SequenceClassifierOutputWithPast:
  1069. r"""
  1070. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1071. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1072. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1073. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1074. """
  1075. transformer_outputs = self.model(
  1076. input_ids,
  1077. attention_mask=attention_mask,
  1078. pixel_values=pixel_values,
  1079. position_ids=position_ids,
  1080. past_key_values=past_key_values,
  1081. inputs_embeds=inputs_embeds,
  1082. token_type_ids=token_type_ids,
  1083. use_cache=use_cache,
  1084. **kwargs,
  1085. )
  1086. hidden_states = transformer_outputs.last_hidden_state
  1087. logits = self.score(hidden_states)
  1088. if input_ids is not None:
  1089. batch_size = input_ids.shape[0]
  1090. else:
  1091. batch_size = inputs_embeds.shape[0]
  1092. if self.config.text_config.pad_token_id is None and batch_size != 1:
  1093. raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
  1094. if self.config.text_config.pad_token_id is None:
  1095. last_non_pad_token = -1
  1096. elif input_ids is not None:
  1097. # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
  1098. non_pad_mask = (input_ids != self.config.text_config.pad_token_id).to(logits.device, torch.int32)
  1099. token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
  1100. last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
  1101. else:
  1102. last_non_pad_token = -1
  1103. logger.warning_once(
  1104. f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
  1105. "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
  1106. )
  1107. pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
  1108. loss = None
  1109. if labels is not None:
  1110. loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
  1111. return SequenceClassifierOutputWithPast(
  1112. loss=loss,
  1113. logits=pooled_logits,
  1114. past_key_values=transformer_outputs.past_key_values,
  1115. hidden_states=transformer_outputs.hidden_states,
  1116. attentions=transformer_outputs.attentions,
  1117. )
class Gemma3TextForSequenceClassification(GenericForSequenceClassification, Gemma3PreTrainedModel):
    """
    Text-only sequence classification model configured by `Gemma3TextConfig`.

    All behavior is inherited from `GenericForSequenceClassification` and `Gemma3PreTrainedModel`; this class
    only pins the config type.
    """

    config: Gemma3TextConfig
# Public names exported by this module (what `from <module> import *` exposes).
__all__ = [
    "Gemma3PreTrainedModel",
    "Gemma3TextModel",
    "Gemma3ForCausalLM",
    "Gemma3ForConditionalGeneration",
    "Gemma3Model",
    "Gemma3ForSequenceClassification",
    "Gemma3TextForSequenceClassification",
]