modular_gemma3.py 54 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221
  1. # coding=utf-8
  2. # Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
  3. #
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import copy
  17. from collections.abc import Callable
  18. from typing import Any, Optional, Union
  19. import torch
  20. import torch.nn as nn
  21. from ...cache_utils import Cache, DynamicCache
  22. from ...configuration_utils import PretrainedConfig, layer_type_validation
  23. from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
  24. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  25. from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
  26. from ...modeling_outputs import BaseModelOutputWithPast, SequenceClassifierOutputWithPast
  27. from ...modeling_rope_utils import rope_config_validation
  28. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  29. from ...processing_utils import Unpack
  30. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  31. from ...utils.deprecation import deprecate_kwarg
  32. from ..gemma2.configuration_gemma2 import Gemma2Config
  33. from ..gemma2.modeling_gemma2 import (
  34. Gemma2Attention,
  35. Gemma2ForCausalLM,
  36. Gemma2MLP,
  37. Gemma2Model,
  38. Gemma2PreTrainedModel,
  39. Gemma2RMSNorm,
  40. Gemma2RotaryEmbedding,
  41. apply_rotary_pos_emb,
  42. eager_attention_forward,
  43. )
  44. from ..paligemma.modeling_paligemma import (
  45. PaligemmaCausalLMOutputWithPast,
  46. PaliGemmaForConditionalGeneration,
  47. PaliGemmaModel,
  48. PaligemmaModelOutputWithPast,
  49. )
  50. from ..siglip import SiglipVisionConfig
  51. logger = logging.get_logger(__name__)
  52. class Gemma3TextConfig(Gemma2Config, PretrainedConfig):
  53. r"""
  54. This is the configuration class to store the configuration of a [`Gemma3TextModel`]. It is used to instantiate an Gemma3Text
  55. model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
  56. defaults will yield a similar configuration to that of the Gemma3Text-7B.
  57. e.g. [google/gemma3_text-7b](https://huggingface.co/google/gemma3_text-7b)
  58. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  59. documentation from [`PretrainedConfig`] for more information.
  60. Args:
  61. vocab_size (`int`, *optional*, defaults to 262208):
  62. Vocabulary size of the Gemma3Text model. Defines the number of different tokens that can be represented by the
  63. `inputs_ids` passed when calling [`Gemma3TextModel`]
  64. hidden_size (`int`, *optional*, defaults to 2304):
  65. Dimension of the hidden representations.
  66. intermediate_size (`int`, *optional*, defaults to 9216):
  67. Dimension of the MLP representations.
  68. num_hidden_layers (`int`, *optional*, defaults to 26):
  69. Number of hidden layers in the Transformer decoder.
  70. num_attention_heads (`int`, *optional*, defaults to 8):
  71. Number of attention heads for each attention layer in the Transformer decoder.
  72. num_key_value_heads (`int`, *optional*, defaults to 4):
  73. This is the number of key_value heads that should be used to implement Grouped Query Attention. If
  74. `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
  75. `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
  76. converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
  77. by meanpooling all the original heads within that group. For more details, check out [this
  78. paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
  79. `num_attention_heads`.
  80. head_dim (`int`, *optional*, defaults to 256):
  81. The attention head dimension.
  82. hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
  83. The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
  84. if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
  85. max_position_embeddings (`int`, *optional*, defaults to 131072):
  86. The maximum sequence length that this model might ever be used with.
  87. initializer_range (`float`, *optional*, defaults to 0.02):
  88. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  89. rms_norm_eps (`float`, *optional*, defaults to 1e-06):
  90. The epsilon used by the rms normalization layers.
  91. use_cache (`bool`, *optional*, defaults to `True`):
  92. Whether or not the model should return the last key/values attentions (not used by all models). Only
  93. relevant if `config.is_decoder=True`.
  94. pad_token_id (`int`, *optional*, defaults to 0):
  95. Padding token id.
  96. eos_token_id (`int`, *optional*, defaults to 1):
  97. End of stream token id.
  98. bos_token_id (`int`, *optional*, defaults to 2):
  99. Beginning of stream token id.
  100. tie_word_embeddings (`bool`, *optional*, defaults to `True`):
  101. Whether to tie weight embeddings
  102. rope_theta (`float`, *optional*, defaults to 1000000.0):
  103. The base period of the RoPE embeddings.
  104. attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
  105. Whether to use a bias in the query, key, value and output projection layers during self-attention.
  106. attention_dropout (`float`, *optional*, defaults to 0.0):
  107. The dropout ratio for the attention probabilities.
  108. query_pre_attn_scalar (`float`, *optional*, defaults to 256):
  109. Scaling factor used on the attention scores
  110. sliding_window (`int`, *optional*, defaults to 4096):
  111. In Gemma3Text, every other layer uses sliding window attention. This is the size of the sliding window.
  112. layer_types (`list`, *optional*):
  113. Attention pattern for each layer.
  114. final_logit_softcapping (`float`, *optional*):
  115. Scaling factor when applying tanh softcapping on the logits.
  116. attn_logit_softcapping (`float`, *optional*):
  117. Scaling factor when applying tanh softcapping on the attention scores.
  118. rope_scaling (`Dict`, *optional*):
  119. Dictionary containing the scaling configuration for the RoPE embeddings used in global attention. NOTE: if you apply new rope type
  120. and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
  121. accordingly.
  122. Expected contents:
  123. `rope_type` (`str`):
  124. The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
  125. 'llama3'], with 'default' being the original RoPE implementation.
  126. `factor` (`float`, *optional*):
  127. Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
  128. most scaling types, a `factor` of x will enable the model to handle sequences of length x *
  129. original maximum pre-trained length.
  130. `original_max_position_embeddings` (`int`, *optional*):
  131. Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
  132. pretraining.
  133. `attention_factor` (`float`, *optional*):
  134. Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
  135. computation. If unspecified, it defaults to value recommended by the implementation, using the
  136. `factor` field to infer the suggested value.
  137. `beta_fast` (`float`, *optional*):
  138. Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
  139. ramp function. If unspecified, it defaults to 32.
  140. `beta_slow` (`float`, *optional*):
  141. Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
  142. ramp function. If unspecified, it defaults to 1.
  143. `short_factor` (`list[float]`, *optional*):
  144. Only used with 'longrope'. The scaling factor to be applied to short contexts (<
  145. `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
  146. size divided by the number of attention heads divided by 2
  147. `long_factor` (`list[float]`, *optional*):
  148. Only used with 'longrope'. The scaling factor to be applied to long contexts (<
  149. `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
  150. size divided by the number of attention heads divided by 2
  151. `low_freq_factor` (`float`, *optional*):
  152. Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
  153. `high_freq_factor` (`float`, *optional*):
  154. Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
  155. rope_local_base_freq (float, *optional*, defaults to 10000.0):
  156. The base period of the RoPE embeddings for local attention.
  157. use_bidirectional_attention (`bool`, *optional*, defaults to `False`): If True, the model will attend to all
  158. text tokens instead of using a causal mask. This does not change behavior for vision tokens.
  159. ```python
  160. >>> from transformers import Gemma3TextModel, Gemma3TextConfig
  161. >>> # Initializing a Gemma3Text gemma3_text-7b style configuration
  162. >>> configuration = Gemma3TextConfig()
  163. >>> # Initializing a model from the gemma3_text-7b style configuration
  164. >>> model = Gemma3TextModel(configuration)
  165. >>> # Accessing the model configuration
  166. >>> configuration = model.config
  167. ```
  168. """
  169. model_type = "gemma3_text"
  170. def __init__(
  171. self,
  172. vocab_size=262_208,
  173. hidden_size=2304,
  174. intermediate_size=9216,
  175. num_hidden_layers=26,
  176. num_attention_heads=8,
  177. num_key_value_heads=4,
  178. head_dim=256,
  179. hidden_activation="gelu_pytorch_tanh",
  180. max_position_embeddings=131_072,
  181. initializer_range=0.02,
  182. rms_norm_eps=1e-6,
  183. use_cache=True,
  184. pad_token_id=0,
  185. eos_token_id=1,
  186. bos_token_id=2,
  187. tie_word_embeddings=True,
  188. rope_theta=1_000_000.0,
  189. attention_bias=False,
  190. attention_dropout=0.0,
  191. query_pre_attn_scalar=256,
  192. sliding_window=4096,
  193. layer_types=None,
  194. final_logit_softcapping=None,
  195. attn_logit_softcapping=None,
  196. rope_scaling=None,
  197. rope_local_base_freq=10_000.0,
  198. use_bidirectional_attention=False,
  199. **kwargs,
  200. ):
  201. PretrainedConfig.__init__(
  202. pad_token_id=pad_token_id,
  203. bos_token_id=bos_token_id,
  204. eos_token_id=eos_token_id,
  205. tie_word_embeddings=tie_word_embeddings,
  206. **kwargs,
  207. )
  208. self.vocab_size = vocab_size
  209. self.max_position_embeddings = max_position_embeddings
  210. self.hidden_size = hidden_size
  211. self.intermediate_size = intermediate_size
  212. self.num_hidden_layers = num_hidden_layers
  213. self.num_attention_heads = num_attention_heads
  214. self.head_dim = head_dim
  215. self.num_key_value_heads = num_key_value_heads
  216. self.initializer_range = initializer_range
  217. self.rms_norm_eps = rms_norm_eps
  218. self.use_cache = use_cache
  219. self.rope_theta = rope_theta
  220. self.attention_bias = attention_bias
  221. self.attention_dropout = attention_dropout
  222. self.hidden_activation = hidden_activation
  223. self.query_pre_attn_scalar = query_pre_attn_scalar
  224. self.sliding_window = sliding_window
  225. self.final_logit_softcapping = final_logit_softcapping
  226. self.attn_logit_softcapping = attn_logit_softcapping
  227. self.layer_types = layer_types
  228. self.use_bidirectional_attention = use_bidirectional_attention
  229. if use_bidirectional_attention:
  230. self.sliding_window = (self.sliding_window // 2) + 1 # due to fa we set exclusive bounds
  231. self.rope_local_base_freq = rope_local_base_freq
  232. self.rope_scaling = rope_scaling
  233. rope_config_validation(self)
  234. # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
  235. self._sliding_window_pattern = kwargs.get("sliding_window_pattern", 6)
  236. if self.layer_types is None:
  237. self.layer_types = [
  238. "sliding_attention" if bool((i + 1) % self._sliding_window_pattern) else "full_attention"
  239. for i in range(self.num_hidden_layers)
  240. ]
  241. layer_type_validation(self.layer_types, self.num_hidden_layers)
  242. class Gemma3Config(PretrainedConfig):
  243. r"""
  244. This is the configuration class to store the configuration of a [`Gemma3ForConditionalGeneration`]. It is used to instantiate an
  245. Gemma3ForConditionalGeneration according to the specified arguments, defining the model architecture. Instantiating a configuration
  246. with the defaults will yield a similar configuration to that of the PaliGemma-2B.
  247. e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b)
  248. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  249. documentation from [`PretrainedConfig`] for more information.
  250. Args:
  251. text_config (`Union[Gemma3TextConfig, dict]`, *optional*):
  252. The config object of the text backbone.
  253. vision_config (`Union[AutoConfig, dict]`, *optional*):
  254. Custom vision config or dict.
  255. mm_tokens_per_image (`int`, *optional*, defaults to 256):
  256. The number of tokens per image embedding.
  257. boi_token_index (`int`, *optional*, defaults to 255999):
  258. The begin-of-image token index to wrap the image prompt.
  259. eoi_token_index (`int`, *optional*, defaults to 256000):
  260. The end-of-image token index to wrap the image prompt.
  261. image_token_index (`int`, *optional*, defaults to 262144):
  262. The image token index to encode the image prompt.
  263. initializer_range (`float`, *optional*, defaults to 0.02):
  264. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  265. Example:
  266. ```python
  267. >>> from transformers import Gemma3ForConditionalGeneration, Gemma3Config, SiglipVisionConfig, Gemma3TextConfig
  268. >>> # Initializing a Siglip-like vision config
  269. >>> vision_config = SiglipVisionConfig()
  270. >>> # Initializing a Gemma3 Text config
  271. >>> text_config = Gemma3TextConfig()
  272. >>> # Initializing a Gemma3 gemma-3-4b style configuration
  273. >>> configuration = Gemma3Config(vision_config, text_config)
  274. >>> # Initializing a model from the gemma-3-4b style configuration
  275. >>> model = Gemma3TextConfig(configuration)
  276. >>> # Accessing the model configuration
  277. >>> configuration = model.config
  278. ```"""
  279. model_type = "gemma3"
  280. attribute_map = {
  281. "image_token_id": "image_token_index",
  282. "boi_token_id": "boi_token_index",
  283. "eoi_token_id": "eoi_token_index",
  284. }
  285. sub_configs = {
  286. "text_config": Gemma3TextConfig,
  287. "vision_config": SiglipVisionConfig,
  288. }
  289. def __init__(
  290. self,
  291. text_config: Optional[Union[Gemma3TextConfig, dict[str, Any]]] = None,
  292. vision_config: Optional[Union[SiglipVisionConfig, dict[str, Any]]] = None,
  293. mm_tokens_per_image: int = 256,
  294. boi_token_index: int = 255_999,
  295. eoi_token_index: int = 256_000,
  296. image_token_index: int = 262_144,
  297. initializer_range: float = 0.02,
  298. **kwargs,
  299. ):
  300. if text_config is None:
  301. text_config = Gemma3TextConfig()
  302. logger.info("text_config is None, using default Gemma3TextConfig text config.")
  303. elif isinstance(text_config, dict):
  304. text_config = Gemma3TextConfig(**text_config)
  305. if isinstance(vision_config, dict):
  306. vision_config = SiglipVisionConfig(**vision_config)
  307. elif vision_config is None:
  308. vision_config = SiglipVisionConfig()
  309. logger.info("vision_config is None, using default SiglipVisionConfig vision config.")
  310. self.text_config = text_config
  311. self.vision_config = vision_config
  312. self.mm_tokens_per_image = mm_tokens_per_image
  313. self.boi_token_index = boi_token_index
  314. self.eoi_token_index = eoi_token_index
  315. self.image_token_index = image_token_index
  316. self.initializer_range = initializer_range
  317. super().__init__(**kwargs)
  318. class Gemma3ModelOutputWithPast(PaligemmaModelOutputWithPast):
  319. pass
  320. class Gemma3CausalLMOutputWithPast(PaligemmaCausalLMOutputWithPast):
  321. pass
  322. class Gemma3TextScaledWordEmbedding(nn.Embedding):
  323. """
  324. This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
  325. """
  326. def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
  327. super().__init__(num_embeddings, embedding_dim, padding_idx)
  328. self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
  329. def forward(self, input_ids: torch.Tensor):
  330. return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
  331. class Gemma3MLP(Gemma2MLP):
  332. def __init__(self, config: Gemma3TextConfig):
  333. super().__init__(config)
  334. class Gemma3RMSNorm(Gemma2RMSNorm):
  335. def __init__(self, dim: int, eps: float = 1e-6):
  336. super().__init__(dim=dim, eps=eps)
  337. class Gemma3RotaryEmbedding(Gemma2RotaryEmbedding):
  338. def __init__(self, config: Gemma3TextConfig, device=None):
  339. super().__init__(config)
  340. # Weird way to inherit but otherwise the sliding window gets defined first and can't access `is_sliding`
  341. class Gemma3Attention(Gemma2Attention):
  342. def __init__(self, config: Gemma3TextConfig, layer_idx: int):
  343. self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
  344. super().__init__(config, layer_idx)
  345. self.sliding_window = config.sliding_window if self.is_sliding else None
  346. self.is_causal = not self.config.use_bidirectional_attention
  347. self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
  348. self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
  349. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  350. def forward(
  351. self,
  352. hidden_states: torch.Tensor,
  353. position_embeddings: torch.Tensor,
  354. attention_mask: Optional[torch.Tensor],
  355. past_key_values: Optional[Cache] = None,
  356. cache_position: Optional[torch.LongTensor] = None,
  357. **kwargs: Unpack[FlashAttentionKwargs],
  358. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  359. input_shape = hidden_states.shape[:-1]
  360. hidden_shape = (*input_shape, -1, self.head_dim)
  361. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  362. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  363. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  364. query_states = self.q_norm(query_states)
  365. key_states = self.k_norm(key_states)
  366. cos, sin = position_embeddings
  367. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  368. if past_key_values is not None:
  369. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  370. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  371. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  372. attention_interface: Callable = eager_attention_forward
  373. if self.config._attn_implementation != "eager":
  374. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  375. attn_output, attn_weights = attention_interface(
  376. self,
  377. query_states,
  378. key_states,
  379. value_states,
  380. attention_mask,
  381. dropout=self.attention_dropout if self.training else 0.0,
  382. scaling=self.scaling,
  383. sliding_window=self.sliding_window,
  384. **kwargs,
  385. )
  386. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  387. attn_output = self.o_proj(attn_output)
  388. return attn_output, attn_weights
  389. class Gemma3DecoderLayer(GradientCheckpointingLayer):
  390. def __init__(self, config: Gemma3TextConfig, layer_idx: int):
  391. super().__init__()
  392. self.config = config
  393. self.hidden_size = config.hidden_size
  394. self.layer_idx = layer_idx
  395. self.attention_type = config.layer_types[layer_idx]
  396. self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx)
  397. self.mlp = Gemma3MLP(config)
  398. self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  399. self.post_attention_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  400. self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  401. self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
  402. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  403. def forward(
  404. self,
  405. hidden_states: torch.Tensor,
  406. position_embeddings_global: torch.Tensor,
  407. position_embeddings_local: torch.Tensor,
  408. attention_mask: Optional[torch.Tensor] = None,
  409. position_ids: Optional[torch.LongTensor] = None,
  410. past_key_values: Optional[Cache] = None,
  411. output_attentions: Optional[bool] = False,
  412. use_cache: Optional[bool] = False,
  413. cache_position: Optional[torch.LongTensor] = None,
  414. **kwargs,
  415. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  416. residual = hidden_states
  417. hidden_states = self.input_layernorm(hidden_states)
  418. # apply global RoPE to non-sliding layer only
  419. if self.self_attn.is_sliding:
  420. position_embeddings = position_embeddings_local
  421. else:
  422. position_embeddings = position_embeddings_global
  423. hidden_states, self_attn_weights = self.self_attn(
  424. hidden_states=hidden_states,
  425. position_embeddings=position_embeddings,
  426. attention_mask=attention_mask,
  427. position_ids=position_ids,
  428. past_key_values=past_key_values,
  429. output_attentions=output_attentions,
  430. use_cache=use_cache,
  431. cache_position=cache_position,
  432. **kwargs,
  433. )
  434. hidden_states = self.post_attention_layernorm(hidden_states)
  435. hidden_states = residual + hidden_states
  436. residual = hidden_states
  437. hidden_states = self.pre_feedforward_layernorm(hidden_states)
  438. hidden_states = self.mlp(hidden_states)
  439. hidden_states = self.post_feedforward_layernorm(hidden_states)
  440. hidden_states = residual + hidden_states
  441. outputs = (hidden_states,)
  442. if output_attentions:
  443. outputs += (self_attn_weights,)
  444. return outputs
  445. GEMMA3_START_DOCSTRING = None
  446. class Gemma3PreTrainedModel(Gemma2PreTrainedModel):
  447. base_model_prefix = ""
  448. _no_split_modules = [
  449. "Gemma3DecoderLayer",
  450. "SiglipVisionEmbeddings",
  451. "SiglipEncoderLayer",
  452. "SiglipMultiheadAttentionPoolingHead",
  453. ]
  454. def _init_weights(self, module):
  455. PreTrainedModel._init_weights(self, module)
  456. if isinstance(module, Gemma3MultiModalProjector):
  457. module.mm_input_projection_weight.data.zero_()
  458. # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
  459. elif "RMSNorm" in module.__class__.__name__:
  460. module.weight.data.zero_()
  461. def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, int, int], bool]:
  462. """
  463. Enables a bidirectional mask within the sliding window.
  464. """
  465. def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
  466. """A token can attend to any other token if their absolute distance is within
  467. the (exclusive) sliding window size (distance < sliding_window)."""
  468. return abs(q_idx - kv_idx) < sliding_window
  469. return inner_mask
  470. class Gemma3TextModel(Gemma2Model):
  471. config: Gemma3TextConfig
  472. def __init__(self, config: Gemma3TextConfig):
  473. super().__init__(config)
  474. # Gemma3 downcasts the below to bfloat16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402
  475. self.embed_tokens = Gemma3TextScaledWordEmbedding(
  476. config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5
  477. )
  478. # TODO: raushan fix this after RoPE refactor. For now we hack it by reassigning thetas
  479. # when we want to create a local RoPE layer. Config defaults should hold values for global RoPE
  480. config = copy.deepcopy(config)
  481. config.rope_theta = config.rope_local_base_freq
  482. config.rope_scaling = {"rope_type": "default"}
  483. self.rotary_emb_local = Gemma3RotaryEmbedding(config=config)
  484. def forward(
  485. self,
  486. input_ids: Optional[torch.LongTensor] = None,
  487. attention_mask: Optional[torch.Tensor] = None,
  488. position_ids: Optional[torch.LongTensor] = None,
  489. past_key_values: Optional[Cache] = None,
  490. inputs_embeds: Optional[torch.FloatTensor] = None,
  491. use_cache: Optional[bool] = None,
  492. output_attentions: Optional[bool] = None,
  493. output_hidden_states: Optional[bool] = None,
  494. cache_position: Optional[torch.LongTensor] = None,
  495. **kwargs: Unpack[TransformersKwargs],
  496. ) -> BaseModelOutputWithPast:
  497. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  498. output_hidden_states = (
  499. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  500. )
  501. use_cache = use_cache if use_cache is not None else self.config.use_cache
  502. if (input_ids is None) ^ (inputs_embeds is not None):
  503. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  504. if self.gradient_checkpointing and self.training and use_cache:
  505. logger.warning_once(
  506. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
  507. )
  508. use_cache = False
  509. if inputs_embeds is None:
  510. inputs_embeds = self.embed_tokens(input_ids)
  511. if use_cache and past_key_values is None and not self.training:
  512. past_key_values = DynamicCache(config=self.config)
  513. if cache_position is None:
  514. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  515. cache_position = torch.arange(
  516. past_seen_tokens,
  517. past_seen_tokens + inputs_embeds.shape[1],
  518. device=inputs_embeds.device,
  519. )
  520. if position_ids is None:
  521. position_ids = cache_position.unsqueeze(0)
  522. # It may already have been prepared by e.g. `generate`
  523. if not isinstance(causal_mask_mapping := attention_mask, dict):
  524. # Prepare mask arguments
  525. mask_kwargs = {
  526. "config": self.config,
  527. "input_embeds": inputs_embeds,
  528. "attention_mask": attention_mask,
  529. "cache_position": cache_position,
  530. "past_key_values": past_key_values,
  531. "position_ids": position_ids,
  532. }
  533. sliding_mask_kwargs = mask_kwargs.copy()
  534. if self.config.use_bidirectional_attention:
  535. mask_kwargs["or_mask_function"] = lambda *args: torch.tensor(True, dtype=torch.bool)
  536. sliding_mask_kwargs["or_mask_function"] = _bidirectional_window_overlay(self.config.sliding_window)
  537. # Create the masks
  538. causal_mask_mapping = {
  539. "full_attention": create_causal_mask(**mask_kwargs),
  540. "sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs),
  541. }
  542. # embed positions
  543. hidden_states = inputs_embeds
  544. # create position embeddings to be shared across the decoder layers
  545. position_embeddings_global = self.rotary_emb(hidden_states, position_ids)
  546. position_embeddings_local = self.rotary_emb_local(hidden_states, position_ids)
  547. # decoder layers
  548. all_hidden_states = () if output_hidden_states else None
  549. all_self_attns = () if output_attentions else None
  550. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  551. if output_hidden_states:
  552. all_hidden_states += (hidden_states,)
  553. layer_outputs = decoder_layer(
  554. hidden_states,
  555. position_embeddings_global=position_embeddings_global,
  556. position_embeddings_local=position_embeddings_local,
  557. attention_mask=causal_mask_mapping[decoder_layer.attention_type],
  558. position_ids=position_ids,
  559. past_key_values=past_key_values,
  560. output_attentions=output_attentions,
  561. use_cache=use_cache,
  562. cache_position=cache_position,
  563. **kwargs,
  564. )
  565. hidden_states = layer_outputs[0]
  566. if output_attentions:
  567. all_self_attns += (layer_outputs[1],)
  568. hidden_states = self.norm(hidden_states)
  569. if output_hidden_states:
  570. all_hidden_states += (hidden_states,)
  571. return BaseModelOutputWithPast(
  572. last_hidden_state=hidden_states,
  573. past_key_values=past_key_values,
  574. hidden_states=all_hidden_states,
  575. attentions=all_self_attns,
  576. )
  577. class Gemma3ForCausalLM(Gemma2ForCausalLM):
  578. config: Gemma3TextConfig
  579. base_model_prefix = "language_model"
  580. def __init__(self, config: Gemma3TextConfig):
  581. super().__init__(config)
  582. self.model = Gemma3TextModel(config)
  583. class Gemma3MultiModalProjector(nn.Module):
  584. def __init__(self, config: Gemma3Config):
  585. super().__init__()
  586. self.mm_input_projection_weight = nn.Parameter(
  587. torch.zeros(config.vision_config.hidden_size, config.text_config.hidden_size)
  588. )
  589. self.mm_soft_emb_norm = Gemma3RMSNorm(
  590. config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps
  591. )
  592. self.patches_per_image = int(config.vision_config.image_size // config.vision_config.patch_size)
  593. self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
  594. self.kernel_size = self.patches_per_image // self.tokens_per_side
  595. self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size)
  596. def forward(self, vision_outputs: torch.Tensor):
  597. batch_size, _, seq_length = vision_outputs.shape
  598. reshaped_vision_outputs = vision_outputs.transpose(1, 2)
  599. reshaped_vision_outputs = reshaped_vision_outputs.reshape(
  600. batch_size, seq_length, self.patches_per_image, self.patches_per_image
  601. )
  602. reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
  603. pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
  604. pooled_vision_outputs = pooled_vision_outputs.flatten(2)
  605. pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)
  606. normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)
  607. projected_vision_outputs = torch.matmul(normed_vision_outputs, self.mm_input_projection_weight)
  608. return projected_vision_outputs.type_as(vision_outputs)
  609. def token_type_ids_mask_function(
  610. token_type_ids: Optional[torch.Tensor],
  611. image_group_ids: Optional[torch.Tensor],
  612. tokens_per_image: int,
  613. ) -> Optional[Callable]:
  614. """
  615. This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
  616. not start and end indices.
  617. """
  618. # Do not return an additional mask in this case
  619. if token_type_ids is None:
  620. return None
  621. def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
  622. # If it's 1 for both query and key/value, we are in an image block
  623. # NOTE: static cache shape goes beyond input seq length, while token_type_ids.shape[1] == input seq length
  624. # Since vmap doesn't support `if statement` we workaround it with `torch.where`
  625. safe_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)
  626. token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_idx]
  627. token_type_ids_at_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0)
  628. image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_idx]
  629. image_group_ids_at_kv_idx = torch.where(kv_idx < image_group_ids.shape[1], image_group_ids_at_kv_idx, -1)
  630. is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids_at_kv_idx == 1)
  631. same_image_block = image_group_ids[batch_idx, q_idx] == image_group_ids_at_kv_idx
  632. # This is bidirectional attention whenever we are dealing with image tokens
  633. return is_image_block & same_image_block
  634. return inner_mask
  635. class Gemma3Model(PaliGemmaModel):
  636. # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
  637. accepts_loss_kwargs = False
  638. def __init__(self, config: Gemma3Config):
  639. super().__init__(config)
  640. del self.text_config_dtype
  641. def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
  642. """
  643. Projects the last hidden state from the vision model into language model space.
  644. Args:
  645. pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
  646. The tensors corresponding to the input images.
  647. Returns:
  648. image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
  649. """
  650. vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
  651. image_features = self.multi_modal_projector(vision_outputs)
  652. return image_features
  653. def _update_causal_mask(self, **super_kwargs):
  654. raise AttributeError("We don't want to inherit it")
  655. @can_return_tuple
  656. @auto_docstring
  657. def forward(
  658. self,
  659. input_ids: Optional[torch.LongTensor] = None,
  660. pixel_values: Optional[torch.FloatTensor] = None,
  661. attention_mask: Optional[torch.Tensor] = None,
  662. position_ids: Optional[torch.LongTensor] = None,
  663. past_key_values: Optional[Cache] = None,
  664. token_type_ids: Optional[torch.LongTensor] = None,
  665. cache_position: Optional[torch.LongTensor] = None,
  666. inputs_embeds: Optional[torch.FloatTensor] = None,
  667. labels: Optional[torch.LongTensor] = None,
  668. use_cache: Optional[bool] = None,
  669. output_attentions: Optional[bool] = None,
  670. output_hidden_states: Optional[bool] = None,
  671. return_dict: Optional[bool] = None,
  672. **lm_kwargs,
  673. ) -> Union[tuple, Gemma3ModelOutputWithPast]:
  674. if (input_ids is None) ^ (inputs_embeds is not None):
  675. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  676. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  677. output_hidden_states = (
  678. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  679. )
  680. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  681. # Replace image id with PAD if the image token if OOV, to avoid index-errors
  682. if input_ids is not None and self.config.image_token_id >= self.vocab_size:
  683. special_image_mask = input_ids == self.config.image_token_id
  684. llm_input_ids = input_ids.clone()
  685. llm_input_ids[special_image_mask] = 0
  686. else:
  687. llm_input_ids = input_ids
  688. if inputs_embeds is None:
  689. inputs_embeds = self.get_input_embeddings()(llm_input_ids)
  690. if cache_position is None:
  691. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  692. cache_position = torch.arange(
  693. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  694. )
  695. # Merge text and images
  696. if pixel_values is not None:
  697. image_features = self.get_image_features(pixel_values)
  698. image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
  699. special_image_mask = self.get_placeholder_mask(
  700. input_ids, inputs_embeds=inputs_embeds, image_features=image_features
  701. )
  702. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
  703. # It may already have been prepared by e.g. `generate`
  704. if not isinstance(causal_mask_mapping := attention_mask, dict):
  705. # Prepare mask arguments
  706. mask_kwargs = {
  707. "config": self.config.get_text_config(),
  708. "input_embeds": inputs_embeds,
  709. "attention_mask": attention_mask,
  710. "cache_position": cache_position,
  711. "past_key_values": past_key_values,
  712. "position_ids": position_ids,
  713. }
  714. # NOTE: this `is_prefill` logic is not flawless, it fails when we're using a cache eagerly initialized
  715. # (e.g. compiled prefill) AND `pixel_values` are not provided. Determining prefill in that case requires
  716. # checking data values, which is not compile-compatible.
  717. is_prefill = (
  718. not use_cache
  719. or past_key_values is None
  720. or not past_key_values.is_initialized
  721. or pixel_values is not None
  722. )
  723. if token_type_ids is not None and is_prefill:
  724. # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
  725. # First find where a new image block starts: 1 if image and previous not image
  726. # The images cannot attend to future images, but can attend to all prev images and to itself
  727. # bidirectionally
  728. is_image = (token_type_ids == 1).to(cache_position.device)
  729. new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
  730. image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
  731. image_group_ids = torch.where(
  732. is_image, image_group_ids, torch.full_like(token_type_ids, -1, device=is_image.device)
  733. )
  734. mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
  735. token_type_ids.to(cache_position.device), image_group_ids, self.config.mm_tokens_per_image
  736. )
  737. # Create the masks
  738. causal_mask_mapping = {
  739. "full_attention": create_causal_mask(**mask_kwargs),
  740. "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
  741. }
  742. outputs = self.language_model(
  743. attention_mask=causal_mask_mapping,
  744. position_ids=position_ids,
  745. past_key_values=past_key_values,
  746. inputs_embeds=inputs_embeds,
  747. use_cache=use_cache,
  748. output_attentions=output_attentions,
  749. output_hidden_states=output_hidden_states,
  750. return_dict=True,
  751. cache_position=cache_position,
  752. **lm_kwargs,
  753. )
  754. return Gemma3ModelOutputWithPast(
  755. last_hidden_state=outputs.last_hidden_state,
  756. past_key_values=outputs.past_key_values if use_cache else None,
  757. hidden_states=outputs.hidden_states,
  758. attentions=outputs.attentions,
  759. image_hidden_states=image_features if pixel_values is not None else None,
  760. )
  761. class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
  762. # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
  763. # Fix: https://github.com/huggingface/transformers/issues/40564
  764. accepts_loss_kwargs = False
  765. @auto_docstring
  766. def forward(
  767. self,
  768. input_ids: Optional[torch.LongTensor] = None,
  769. pixel_values: Optional[torch.FloatTensor] = None,
  770. attention_mask: Optional[torch.Tensor] = None,
  771. position_ids: Optional[torch.LongTensor] = None,
  772. past_key_values: Optional[Cache] = None,
  773. token_type_ids: Optional[torch.LongTensor] = None,
  774. cache_position: Optional[torch.LongTensor] = None,
  775. inputs_embeds: Optional[torch.FloatTensor] = None,
  776. labels: Optional[torch.LongTensor] = None,
  777. use_cache: Optional[bool] = None,
  778. output_attentions: Optional[bool] = None,
  779. output_hidden_states: Optional[bool] = None,
  780. return_dict: Optional[bool] = None,
  781. logits_to_keep: Union[int, torch.Tensor] = 0,
  782. **lm_kwargs,
  783. ) -> Union[tuple, Gemma3CausalLMOutputWithPast]:
  784. r"""
  785. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  786. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  787. config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  788. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
  789. Example:
  790. ```python
  791. >>> from PIL import Image
  792. >>> import requests
  793. >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
  794. >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
  795. >>> processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
  796. >>> messages = [
  797. ... {
  798. ... "role": "system",
  799. ... "content": [
  800. ... {"type": "text", "text": "You are a helpful assistant."}
  801. ... ]
  802. ... },
  803. ... {
  804. ... "role": "user", "content": [
  805. ... {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
  806. ... {"type": "text", "text": "Where is the cat standing?"},
  807. ... ]
  808. ... },
  809. ... ]
  810. >>> inputs = processor.apply_chat_template(
  811. ... messages,
  812. ... tokenize=True,
  813. ... return_dict=True,
  814. ... return_tensors="pt",
  815. ... add_generation_prompt=True
  816. ... )
  817. >>> # Generate
  818. >>> generate_ids = model.generate(**inputs)
  819. >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  820. "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
  821. ```
  822. """
  823. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  824. output_hidden_states = (
  825. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  826. )
  827. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  828. outputs = self.model(
  829. input_ids=input_ids,
  830. pixel_values=pixel_values,
  831. token_type_ids=token_type_ids,
  832. attention_mask=attention_mask,
  833. position_ids=position_ids,
  834. past_key_values=past_key_values,
  835. inputs_embeds=inputs_embeds,
  836. use_cache=use_cache,
  837. labels=labels,
  838. output_attentions=output_attentions,
  839. output_hidden_states=output_hidden_states,
  840. return_dict=return_dict,
  841. cache_position=cache_position,
  842. **lm_kwargs,
  843. )
  844. hidden_states = outputs[0]
  845. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  846. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  847. logits = self.lm_head(hidden_states[:, slice_indices, :])
  848. loss = None
  849. if labels is not None:
  850. # Upcast to float if we need to compute the loss to avoid potential precision issues
  851. logits = logits.float()
  852. shift_logits = logits[..., :-1, :]
  853. shift_labels = labels[..., 1:]
  854. if attention_mask is not None:
  855. # we use the input attention mask to shift the logits and labels, because it is 2D.
  856. # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
  857. shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
  858. shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
  859. shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
  860. else:
  861. shift_logits = shift_logits.contiguous()
  862. shift_labels = shift_labels.contiguous()
  863. # Flatten the tokens
  864. loss_fct = nn.CrossEntropyLoss()
  865. flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
  866. flat_labels = shift_labels.view(-1).to(shift_logits.device)
  867. loss = loss_fct(flat_logits, flat_labels)
  868. if not return_dict:
  869. output = (logits,) + outputs[1:]
  870. return (loss,) + output if loss is not None else output
  871. return Gemma3CausalLMOutputWithPast(
  872. loss=loss,
  873. logits=logits,
  874. past_key_values=outputs.past_key_values,
  875. hidden_states=outputs.hidden_states,
  876. attentions=outputs.attentions,
  877. image_hidden_states=outputs.image_hidden_states,
  878. )
  879. def prepare_inputs_for_generation(
  880. self,
  881. input_ids,
  882. past_key_values=None,
  883. inputs_embeds=None,
  884. cache_position=None,
  885. position_ids=None,
  886. pixel_values=None,
  887. attention_mask=None,
  888. token_type_ids=None,
  889. use_cache=True,
  890. logits_to_keep=None,
  891. labels=None,
  892. **kwargs,
  893. ):
  894. # Overwritten -- custom `position_ids` and `pixel_values` handling
  895. model_inputs = super().prepare_inputs_for_generation(
  896. input_ids,
  897. past_key_values=past_key_values,
  898. inputs_embeds=inputs_embeds,
  899. attention_mask=attention_mask,
  900. position_ids=position_ids,
  901. cache_position=cache_position,
  902. use_cache=use_cache,
  903. logits_to_keep=logits_to_keep,
  904. token_type_ids=token_type_ids,
  905. **kwargs,
  906. )
  907. # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
  908. # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
  909. if cache_position[0] == 0:
  910. model_inputs["pixel_values"] = pixel_values
  911. return model_inputs
  912. def _prepare_4d_causal_attention_mask_with_cache_position(self, **super_kwargs):
  913. raise AttributeError("We don't want to inherit it")
  914. @staticmethod
  915. def create_masks_for_generate(
  916. config: PretrainedConfig,
  917. input_embeds: torch.Tensor,
  918. attention_mask: Optional[torch.Tensor],
  919. cache_position: torch.Tensor,
  920. past_key_values: Optional[Cache],
  921. position_ids: Optional[torch.Tensor],
  922. token_type_ids: Optional[torch.Tensor] = None,
  923. **kwargs,
  924. ) -> dict:
  925. # Prepare mask arguments
  926. mask_kwargs = {
  927. "config": config.get_text_config(),
  928. "input_embeds": input_embeds,
  929. "attention_mask": attention_mask,
  930. "cache_position": cache_position,
  931. "past_key_values": past_key_values,
  932. "position_ids": position_ids,
  933. }
  934. # Add the token type ids mask for generate as well
  935. if token_type_ids is not None and input_embeds.shape[1] != 1:
  936. # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
  937. # First find where a new image block starts: 1 if image and previous not image
  938. # The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
  939. is_image = (token_type_ids == 1).to(cache_position.device)
  940. new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
  941. image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
  942. image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
  943. mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
  944. token_type_ids.to(cache_position.device), image_group_ids, config.mm_tokens_per_image
  945. )
  946. return create_masks_for_generate(**mask_kwargs)
  947. class Gemma3ForSequenceClassification(Gemma3PreTrainedModel):
  948. _checkpoint_conversion_mapping = {
  949. "^language_model.model": "model.language_model",
  950. "^vision_tower": "model.vision_tower",
  951. "^multi_modal_projector": "model.multi_modal_projector",
  952. }
  953. def __init__(self, config):
  954. super().__init__(config)
  955. self.num_labels = config.num_labels
  956. self.model = Gemma3Model(config)
  957. self.score = nn.Linear(config.text_config.hidden_size, self.num_labels, bias=False)
  958. # Initialize weights and apply final processing
  959. self.post_init()
  960. def get_input_embeddings(self):
  961. return self.model.get_input_embeddings()
  962. def set_input_embeddings(self, value):
  963. self.model.set_input_embeddings(value)
  964. @can_return_tuple
  965. @auto_docstring
  966. def forward(
  967. self,
  968. input_ids: Optional[torch.LongTensor] = None,
  969. pixel_values: Optional[torch.FloatTensor] = None,
  970. attention_mask: Optional[torch.Tensor] = None,
  971. position_ids: Optional[torch.LongTensor] = None,
  972. past_key_values: Optional[Cache] = None,
  973. inputs_embeds: Optional[torch.FloatTensor] = None,
  974. token_type_ids: Optional[torch.LongTensor] = None,
  975. labels: Optional[torch.LongTensor] = None,
  976. use_cache: Optional[bool] = None,
  977. **kwargs: Unpack[TransformersKwargs],
  978. ) -> SequenceClassifierOutputWithPast:
  979. r"""
  980. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  981. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  982. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  983. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  984. """
  985. transformer_outputs = self.model(
  986. input_ids,
  987. attention_mask=attention_mask,
  988. pixel_values=pixel_values,
  989. position_ids=position_ids,
  990. past_key_values=past_key_values,
  991. inputs_embeds=inputs_embeds,
  992. token_type_ids=token_type_ids,
  993. use_cache=use_cache,
  994. **kwargs,
  995. )
  996. hidden_states = transformer_outputs.last_hidden_state
  997. logits = self.score(hidden_states)
  998. if input_ids is not None:
  999. batch_size = input_ids.shape[0]
  1000. else:
  1001. batch_size = inputs_embeds.shape[0]
  1002. if self.config.text_config.pad_token_id is None and batch_size != 1:
  1003. raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
  1004. if self.config.text_config.pad_token_id is None:
  1005. last_non_pad_token = -1
  1006. elif input_ids is not None:
  1007. # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
  1008. non_pad_mask = (input_ids != self.config.text_config.pad_token_id).to(logits.device, torch.int32)
  1009. token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
  1010. last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
  1011. else:
  1012. last_non_pad_token = -1
  1013. logger.warning_once(
  1014. f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
  1015. "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
  1016. )
  1017. pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
  1018. loss = None
  1019. if labels is not None:
  1020. loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
  1021. return SequenceClassifierOutputWithPast(
  1022. loss=loss,
  1023. logits=pooled_logits,
  1024. past_key_values=transformer_outputs.past_key_values,
  1025. hidden_states=transformer_outputs.hidden_states,
  1026. attentions=transformer_outputs.attentions,
  1027. )
  1028. class Gemma3TextForSequenceClassification(GenericForSequenceClassification, Gemma3PreTrainedModel):
  1029. """
  1030. Gemma3TextForSequenceClassification is a text-only sequence classification model that works with Gemma3TextConfig.
  1031. It uses the generic sequence classification implementation for efficiency and consistency.
  1032. """
  1033. config: Gemma3TextConfig
  1034. __all__ = [
  1035. "Gemma3Config",
  1036. "Gemma3TextConfig",
  1037. "Gemma3PreTrainedModel",
  1038. "Gemma3TextModel",
  1039. "Gemma3ForCausalLM",
  1040. "Gemma3ForConditionalGeneration",
  1041. "Gemma3Model",
  1042. "Gemma3ForSequenceClassification",
  1043. "Gemma3TextForSequenceClassification",
  1044. ]