configuration_gemma3n.py 36 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/gemma3n/modular_gemma3n.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_gemma3n.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
  9. #
  10. #
  11. # Licensed under the Apache License, Version 2.0 (the "License");
  12. # you may not use this file except in compliance with the License.
  13. # You may obtain a copy of the License at
  14. #
  15. # http://www.apache.org/licenses/LICENSE-2.0
  16. #
  17. # Unless required by applicable law or agreed to in writing, software
  18. # distributed under the License is distributed on an "AS IS" BASIS,
  19. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  20. # See the License for the specific language governing permissions and
  21. # limitations under the License.
  22. from collections.abc import Sequence
  23. from typing import Any, Optional, Union
  24. from ...configuration_utils import PretrainedConfig, layer_type_validation
  25. from ...modeling_rope_utils import rope_config_validation
  26. from ...utils import is_timm_available, logging, requires_backends
  27. if is_timm_available():
  28. from timm.data import ImageNetInfo, infer_imagenet_subset
  29. logger = logging.get_logger(__name__)
  30. class Gemma3nTextConfig(PretrainedConfig):
  31. r"""
  32. This is the configuration class to store the configuration of a [`Gemma3nTextModel`]. It is used to instantiate an
  33. Gemma3nTextModel model according to the specified arguments, defining the model architecture. Instantiating a
  34. configuration with the defaults will yield a similar configuration to that of the Gemma 3n E4B, e.g.
  35. [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).
  36. Configuration objects that inherit from [`Gemma3nTextConfig`] and can be used to control the model outputs. Read
  37. the documentation from [`Gemma3nTextConfig`] for more information.
  38. Args:
  39. vocab_size (`int`, *optional*, defaults to 262400):
  40. Vocabulary size of the Gemma3nText model. Defines the number of different tokens that can be represented by
  41. the `inputs_ids` passed when calling [`Gemma3nTextModel`]
  42. vocab_size_per_layer_input (`int`, *optional*, defaults to 262144):
  43. Vocabulary size of the per-layer text embeddings that augment the standard embeddings.
  44. hidden_size (`int`, *optional*, defaults to 2048):
  45. Dimension of the hidden representations.
  46. hidden_size_per_layer_input (`int`, *optional*, defaults to 256):
  47. Dimension of the hidden representations for per-layer emebeddings.
  48. intermediate_size (`int` or `Sequence[int]`, *optional*, defaults to 16384):
  49. Dimension of the MLP representations. MatFormer configurations may wish to provide a sequence of integers
  50. to account for variable intermediate_size values across layers. In such cases,
  51. `len(intermediate_size) == num_hidden_layers`.
  52. num_hidden_layers (`int`, *optional*, defaults to 35):
  53. Number of hidden layers in the Transformer decoder.
  54. num_attention_heads (`int`, *optional*, defaults to 8):
  55. Number of attention heads for each attention layer in the Transformer decoder.
  56. num_key_value_heads (`int`, *optional*, defaults to 2):
  57. This is the number of key_value heads that should be used to implement Grouped Query Attention. If
  58. `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
  59. `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
  60. converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
  61. by meanpooling all the original heads within that group. For more details checkout this
  62. [paper](https://huggingface.co/papers/2305.13245). If not specified, will default to `num_attention_heads`.
  63. head_dim (`int`, *optional*, defaults to 256):
  64. The attention head dimension.
  65. hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
  66. The non-linear activation function (function or string) in the decoder. Will default to
  67. `"gelu_pytorch_tanh"` if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"`
  68. activation function.
  69. max_position_embeddings (`int`, *optional*, defaults to 32768):
  70. The maximum sequence length that this model might ever be used with.
  71. initializer_range (`float`, *optional*, defaults to 0.02):
  72. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  73. rms_norm_eps (`float`, *optional*, defaults to 1e-06):
  74. The epsilon used by the rms normalization layers.
  75. use_cache (`bool`, *optional*, defaults to `True`):
  76. Whether or not the model should return the last key/values attentions (not used by all models). Only
  77. relevant if `config.is_decoder=True`.
  78. pad_token_id (`int`, *optional*, defaults to 0):
  79. Padding token id.
  80. eos_token_id (`int`, *optional*, defaults to 1):
  81. End of stream token id.
  82. bos_token_id (`int`, *optional*, defaults to 2):
  83. Beginning of stream token id.
  84. rope_theta (`float`, *optional*, defaults to 1000000.0):
  85. The base period of the RoPE embeddings.
  86. rope_scaling (`Dict`, *optional*):
  87. Dictionary containing the scaling configuration for the RoPE embeddings used in global attention.
  88. NOTE: if you apply new rope type and you expect the model to work on longer `max_position_embeddings`, we
  89. recommend you to update this value accordingly.
  90. Expected contents:
  91. `rope_type` (`str`):
  92. The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
  93. 'llama3'], with 'default' being the original RoPE implementation.
  94. `factor` (`float`, *optional*):
  95. Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
  96. most scaling types, a `factor` of x will enable the model to handle sequences of length x *
  97. original maximum pre-trained length.
  98. `original_max_position_embeddings` (`int`, *optional*):
  99. Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
  100. pretraining.
  101. `attention_factor` (`float`, *optional*):
  102. Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
  103. computation. If unspecified, it defaults to value recommended by the implementation, using the
  104. `factor` field to infer the suggested value.
  105. `beta_fast` (`float`, *optional*):
  106. Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
  107. ramp function. If unspecified, it defaults to 32.
  108. `beta_slow` (`float`, *optional*):
  109. Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
  110. ramp function. If unspecified, it defaults to 1.
  111. `short_factor` (`List[float]`, *optional*):
  112. Only used with 'longrope'. The scaling factor to be applied to short contexts (<
  113. `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
  114. size divided by the number of attention heads divided by 2
  115. `long_factor` (`List[float]`, *optional*):
  116. Only used with 'longrope'. The scaling factor to be applied to long contexts (<
  117. `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
  118. size divided by the number of attention heads divided by 2
  119. `low_freq_factor` (`float`, *optional*):
  120. Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
  121. `high_freq_factor` (`float`, *optional*):
  122. Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
  123. rope_local_base_freq (float, *optional*, defaults to 10000.0):
  124. The base period of the RoPE embeddings for local attention.
  125. attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
  126. Whether to use a bias in the query, key, value and output projection layers during self-attention.
  127. attention_dropout (`float`, *optional*, defaults to 0.0):
  128. The dropout ratio for the attention probabilities.
  129. sliding_window (`int`, *optional*, defaults to 512):
  130. This is the size of the sliding window used by local attention layers.
  131. layer_types (`Optional`, *optional*):
  132. A sequence of strings defining the attention type for that layer as either "sliding_attention" or
  133. "full_attention". If not provided, `layer_types` will de inferred from `num_hidden_layers` using a pattern
  134. of four "sliding_attention" layers followed one "full_attention". The last layer in the model should always
  135. be a "full_attention" layer.
  136. final_logit_softcapping (`float`, *optional*, defaults to 30.0):
  137. Scaling factor when applying tanh softcapping on the logits.
  138. altup_active_idx (`int`, *optional*, defaults to 0):
  139. The index of the prediction from which AltUp will compute additional predictions or correct
  140. altup_coef_clip (`float`, *optional*, defaults to 120.0):
  141. The maximum amplitude of an AltUp prediction or correction coefficient weight.
  142. altup_correct_scale (`bool`, *optional*, defaults to `True`):
  143. If True, apply the `AltUp.correct_output_scale` to the corrected prediction at `altup_active_idx`.
  144. altup_num_inputs (`int`, *optional*, defaults to 4):
  145. The number of predictions that AltUp should be make given the input sequence.
  146. num_kv_shared_layers (`int`, *optional*, defaults to 15):
  147. The number of layer that share KV cache values. During the forward pass, the last `num_kv_shared_layers`
  148. layers in the model "share" the KV values in that each local and global layer in this range uses the KV
  149. cache values computed for the last local or global layer, respectively, before entering this range. The
  150. value should be a multiple of the attention pattern size (see `layer_types` parameter).
  151. laurel_rank (int, *optional*, defaults to 64):
  152. The intermediate size for the linear projections in the Learned Augmented Residual Layer.
  153. activation_sparsity_pattern (Sequence[float], *optional*):
  154. The sparsity factor used to extract the top-k activations for a given layer. The provided Sequence must
  155. explicitly provide a sparsity value for each layer in the model. By default, the first 10 layers are
  156. sparse with a sparsity factor of 0.95 and the rest are dense.
  157. ```python
  158. >>> from transformers import Gemma3nTextModel, Gemma3nTextConfig
  159. >>> # Initializing a Gemma3nText gemma3n_text-E4B style configuration
  160. >>> configuration = Gemma3nTextConfig()
  161. >>> # Initializing a model from the gemma3n_text-E4B style configuration
  162. >>> model = Gemma3nTextModel(configuration)
  163. >>> # Accessing the model configuration
  164. >>> configuration = model.config
  165. ```
  166. """
  167. model_type = "gemma3n_text"
  168. keys_to_ignore_at_inference = ["past_key_values"]
  169. base_model_tp_plan = {
  170. "layers.*.self_attn.q_proj": "colwise",
  171. "layers.*.self_attn.k_proj": "colwise",
  172. "layers.*.self_attn.v_proj": "colwise",
  173. "layers.*.self_attn.o_proj": "rowwise",
  174. "layers.*.mlp.gate_proj": "colwise",
  175. "layers.*.mlp.up_proj": "colwise",
  176. "layers.*.mlp.down_proj": "rowwise",
  177. }
  178. base_model_pp_plan = {
  179. "embed_tokens": (["input_ids"], ["inputs_embeds"]),
  180. "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
  181. "norm": (["hidden_states"], ["hidden_states"]),
  182. }
  183. def __init__(
  184. self,
  185. vocab_size: int = 262_400,
  186. vocab_size_per_layer_input: int = 262_144,
  187. hidden_size: int = 2048,
  188. hidden_size_per_layer_input: int = 256,
  189. intermediate_size: Union[int, Sequence[int]] = 16_384,
  190. num_hidden_layers: int = 35,
  191. num_attention_heads: int = 8,
  192. num_key_value_heads: int = 2,
  193. head_dim: int = 256,
  194. hidden_activation: str = "gelu_pytorch_tanh",
  195. max_position_embeddings: int = 32_768,
  196. initializer_range: float = 0.02,
  197. rms_norm_eps: float = 1e-6,
  198. use_cache: bool = True,
  199. pad_token_id: int = 0,
  200. eos_token_id: int = 1,
  201. bos_token_id: int = 2,
  202. rope_theta: float = 1_000_000.0,
  203. rope_scaling: Optional[dict[str, Any]] = None,
  204. rope_local_base_freq: float = 10_000.0,
  205. attention_bias: bool = False,
  206. attention_dropout: float = 0.0,
  207. sliding_window: int = 512,
  208. layer_types: Optional[Sequence[str]] = None,
  209. final_logit_softcapping: float = 30.0,
  210. altup_active_idx: int = 0,
  211. altup_coef_clip: float = 120.0,
  212. altup_correct_scale: bool = True,
  213. altup_num_inputs: int = 4,
  214. num_kv_shared_layers: int = 15,
  215. laurel_rank: int = 64,
  216. activation_sparsity_pattern: Optional[Union[float, Sequence[float]]] = None,
  217. **kwargs,
  218. ):
  219. super().__init__(
  220. pad_token_id=pad_token_id,
  221. bos_token_id=bos_token_id,
  222. eos_token_id=eos_token_id,
  223. **kwargs,
  224. )
  225. if isinstance(intermediate_size, Sequence) and (intsize_len := len(intermediate_size)) != num_hidden_layers:
  226. raise ValueError(
  227. "intermediate_size must have an explicit intermediate size for every layer or one for all layers. "
  228. f"Expected {num_hidden_layers} values but got {intsize_len}."
  229. )
  230. elif not isinstance(intermediate_size, Sequence):
  231. intermediate_size = [intermediate_size] * num_hidden_layers
  232. self.vocab_size = vocab_size
  233. self.vocab_size_per_layer_input = vocab_size_per_layer_input
  234. self.max_position_embeddings = max_position_embeddings
  235. self.hidden_size = hidden_size
  236. self.intermediate_size = intermediate_size
  237. self.num_hidden_layers = num_hidden_layers
  238. self.num_attention_heads = num_attention_heads
  239. self.head_dim = head_dim
  240. self.num_key_value_heads = num_key_value_heads
  241. self.initializer_range = initializer_range
  242. self.rms_norm_eps = rms_norm_eps
  243. self.use_cache = use_cache
  244. self.rope_theta = rope_theta
  245. self.attention_bias = attention_bias
  246. self.attention_dropout = attention_dropout
  247. self.hidden_activation = hidden_activation
  248. self.sliding_window = sliding_window
  249. self.final_logit_softcapping = final_logit_softcapping
  250. self.layer_types = layer_types
  251. self.rope_local_base_freq = rope_local_base_freq
  252. self.rope_scaling = rope_scaling
  253. rope_config_validation(self)
  254. if layer_types is None:
  255. self.layer_types = [
  256. "full_attention" if (i + 1) % 5 == 0 else "sliding_attention" for i in range(self.num_hidden_layers)
  257. ]
  258. else:
  259. self.layer_types = layer_types
  260. layer_type_validation(self.layer_types, self.num_hidden_layers)
  261. self.hidden_size_per_layer_input = hidden_size_per_layer_input
  262. self.num_kv_shared_layers = num_kv_shared_layers
  263. self.altup_active_idx = altup_active_idx
  264. self.altup_coef_clip = altup_coef_clip
  265. self.altup_correct_scale = altup_correct_scale
  266. self.altup_num_inputs = altup_num_inputs
  267. self.laurel_rank = laurel_rank
  268. if activation_sparsity_pattern is None:
  269. num_sparse_layers = 10 if num_hidden_layers > 10 else 0
  270. activation_sparsity_pattern = [0.95] * num_sparse_layers + [0.0] * (num_hidden_layers - num_sparse_layers)
  271. if (len_asp := len(activation_sparsity_pattern)) != num_hidden_layers:
  272. raise ValueError(
  273. "activation_sparsity_pattern must have an explicit activation sparsity value for every layer."
  274. f"Expected {num_hidden_layers} values but got {len_asp}."
  275. )
  276. self.activation_sparsity_pattern = activation_sparsity_pattern
  277. class Gemma3nAudioConfig(PretrainedConfig):
  278. r"""
  279. This is the configuration class to store the configuration of a [`Gemma3nAudioEncoder`]. It is used to instantiate
  280. an `Gemma3nAudioEncoder` model according to the specified arguments, defining the model architecture. Instantiating
  281. a configuration with the defaults will yield a similar configuration to that of the Gemma 3n E4B, e.g.,
  282. [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).
  283. Configuration objects that inherit from [`Gemma3nAudioConfig`] and can be used to control the model outputs. Read
  284. the documentation from [`Gemma3nAudioConfig`] for more information.
  285. Args:
  286. vocab_size (`int`, *optional*, defaults to 128):
  287. Vocabulary size of the additional hard-token embeddings for audio model. These augment the embeddings
  288. included in the `Gemma3nTextModel` to provide, e.g., the end of audio and audio soft token placeholder
  289. tokens when converting `input_ids` to embeddings in the `Gemma3nForConditionalGeneration` model.
  290. vocab_offset (`int`, *optional*, defaults to 262272):
  291. Offset between the tokenizer vocab index for the token ids embedded by `Gemma3nMultimodalEmbedder` and the
  292. 0-indexed `Gemma3nMultimodalEmbedder.embedding` table.
  293. input_feat_size (`int`, *optional*, defaults to 128):
  294. The number of channels in each mel-spectrogram frame.
  295. hidden_size (`int`, *optional*, defaults to 1536):
  296. Dimension of the hidden representations.
  297. rms_norm_eps (`float`, *optional*, defaults to 1e-06):
  298. The epsilon used by the rms normalization layers.
  299. gradient_clipping (`float`, *optional*, defaults to 10000000000.0):
  300. Clipping value used to stabilize extremely large gradient values.
  301. conf_attention_chunk_size (`int`, *optional*, defaults to 12):
  302. The sub-sequence size for local attention processing inside the Conformer ("conf") section of the
  303. Universal Speech Model.
  304. conf_attention_context_left (`int`, *optional*, defaults to 13):
  305. The left context size of the local attention inside the Conformer ("conf") section of the
  306. Universal Speech Model.
  307. conf_attention_context_right (`int`, *optional*, defaults to 0):
  308. The right context size of the local attention inside the Conformer ("conf") section of the
  309. Universal Speech Model.
  310. conf_attention_logit_cap (`float`, *optional*, defaults to 50.0):
  311. Logit cap applied during local attention inside the Conformer ("conf") section of the
  312. Universal Speech Model.
  313. conf_num_attention_heads (`int`, *optional*, defaults to 8):
  314. The number of attention heads in local attention inside the Conformer ("conf") section of the
  315. Universal Speech Model.
  316. conf_num_hidden_layers (`int`, *optional*, defaults to 12):
  317. The number of layers that use local attention inside the Conformer ("conf") section of the
  318. Universal Speech Model.
  319. conf_conv_kernel_size (`int`, *optional*, defaults to 5):
  320. Convolution kernel size for the conformer block inside the Conformer ("conf") section of the
  321. Universal Speech Model.
  322. conf_reduction_factor (`int`, *optional*, defaults to 4):
  323. Reduction factor used in the conformer block inside the Conformer ("conf") section of the
  324. Universal Speech Model.
  325. conf_residual_weight (`float`, *optional*, defaults to 0.5):
  326. Residual connection weight inside the Conformer ("conf") section of the
  327. Universal Speech Model.
  328. sscp_conv_channel_size (`tuple(int, int)`, *optional*, defaults to `(128, 32)`):
  329. The channel sizes for the first and second convolutional layers in the Sub-sample Convolution Projection
  330. ("sscp") section of the Universal Speech Model.
  331. sscp_conv_group_norm_eps (`float`, *optional*, defaults to 0.001):
  332. Epsilon used in group normalization in the subsample convolution projection in the Sub-sample Convolution
  333. Projection ("sscp") section of the Universal Speech Model.
  334. sscp_conv_kernel_size (`tuple(tuple(int, int), tuple(int, int))`, *optional*, defaults to `((3, 3), (3, 3))`):
  335. Kernel sizes of the two convolutional layers in the subsample convolution projection in the Sub-sample
  336. Convolution Projection ("sscp") section of the Universal Speech Model. The kernel sizes are specified as a
  337. tuple of height and width for each layer, where the height corresponds to the time dimension and the width
  338. corresponds to the frequency dimension.
  339. sscp_conv_stride_size (`tuple(tuple(int, int), tuple(int, int))`, *optional*, defaults to `((2, 2), (2, 2))`):
  340. Stride sizes of the two convolutional layers in the subsample convolution projection in the Sub-sample
  341. Convolution Projection ("sscp") section of the Universal Speech Model. The stride sizes are specified as a
  342. tuple of height and width for each layer, where the height corresponds to the time dimension and the width
  343. corresponds to the frequency dimension.
  344. Example:
  345. ```python
  346. >>> from transformers import Gemma3nAudioConfig, Gemma3nAudioEncoder
  347. >>> # Initializing a Gemma3nAudioEncoder gemma3n_audio-E4B-style configuration
  348. >>> configuration = Gemma3nAudioConfig()
  349. >>> # Initializing a model from the gemma3n_audio-E4B style configuration
  350. >>> model = Gemma3nAudioEncoder(configuration)
  351. >>> # Accessing the model configuration
  352. >>> configuration = model.config
  353. ```
  354. """
  355. model_type = "gemma3n_audio"
  356. def __init__(
  357. self,
  358. vocab_size: int = 128,
  359. vocab_offset: int = 262_144 + 128, # text vocab size + vision vocab size
  360. input_feat_size: int = 128,
  361. hidden_size: int = 1536,
  362. rms_norm_eps: float = 1e-6,
  363. gradient_clipping: float = 10_000_000_000.0,
  364. conf_attention_chunk_size: int = 12,
  365. conf_attention_context_left: int = 13,
  366. conf_attention_context_right: int = 0,
  367. conf_attention_logit_cap: float = 50.0,
  368. conf_num_attention_heads: int = 8,
  369. conf_num_hidden_layers: int = 12,
  370. conf_conv_kernel_size: int = 5,
  371. conf_reduction_factor: int = 4,
  372. conf_residual_weight: float = 0.5,
  373. sscp_conv_channel_size: tuple[int, int] = (128, 32),
  374. sscp_conv_group_norm_eps: float = 1e-3,
  375. sscp_conv_kernel_size: tuple[tuple[int, int], tuple[int, int]] = (
  376. (3, 3),
  377. (3, 3),
  378. ),
  379. sscp_conv_stride_size: tuple[tuple[int, int], tuple[int, int]] = (
  380. (2, 2),
  381. (2, 2),
  382. ),
  383. **kwargs,
  384. ):
  385. super().__init__(**kwargs)
  386. self.input_feat_size = input_feat_size
  387. self.hidden_size = hidden_size
  388. self.rms_norm_eps = rms_norm_eps
  389. self.vocab_size = vocab_size
  390. self.vocab_offset = vocab_offset
  391. self.gradient_clipping = gradient_clipping
  392. self.conf_attention_chunk_size = conf_attention_chunk_size
  393. self.conf_attention_context_left = conf_attention_context_left
  394. self.conf_attention_context_right = conf_attention_context_right
  395. self.conf_attention_logit_cap = conf_attention_logit_cap
  396. self.conf_num_attention_heads = conf_num_attention_heads
  397. self.conf_num_hidden_layers = conf_num_hidden_layers
  398. self.conf_conv_kernel_size = conf_conv_kernel_size
  399. self.conf_reduction_factor = conf_reduction_factor
  400. self.conf_residual_weight = conf_residual_weight
  401. self.sscp_conv_channel_size = sscp_conv_channel_size
  402. self.sscp_conv_group_norm_eps = sscp_conv_group_norm_eps
  403. self.sscp_conv_kernel_size = sscp_conv_kernel_size
  404. self.sscp_conv_stride_size = sscp_conv_stride_size
  405. class Gemma3nVisionConfig(PretrainedConfig):
  406. r"""
  407. This is the configuration class to store the configuration for a timm backbone [`TimmWrapper`]. It is used to
  408. instantiate an timm model model according to the specified arguments, defining the model architecture.
  409. Instantiating a configuration with the defaults will yield a similar configuration to that of the Gemma 3n E4B
  410. vision tower, e.g. [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).
  411. Configuration objects inherit from [`Gemma3nVisionConfig`] and can be used to control the model outputs. Read the
  412. documentation from [`Gemma3nVisionConfig`] for more information.
  413. Config loads imagenet label descriptions and stores them in `id2label` attribute, `label2id` attribute for default
  414. imagenet models is set to `None` due to occlusions in the label descriptions.
  415. Args:
  416. initializer_range (`float`, *optional*, defaults to 0.02):
  417. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  418. do_pooling (`bool`, *optional*, defaults to `False`):
  419. Whether to do pooling for the last_hidden_state in `TimmWrapper` or not.
  420. architecture (`str`, *optional*, defaults to `"mobilenetv5_300m_enc"`):
  421. Determines vision architecture for TimmWrapper.
  422. hidden_size (`int`, *optional*, defaults to 2048):
  423. Dimension of the hidden representations.
  424. vocab_size (`int`, *optional*, defaults to 128):
  425. Vocabulary size of the additional hard-token embeddings for vision model.
  426. vocab_offset (`int`, *optional*, defaults to 262144):
  427. Offset between the tokenizer vocab index for the token ids embedded by `Gemma3nMultimodalEmbedder` and the
  428. 0-indexed `Gemma3nMultimodalEmbedder.embedding` table.
  429. rms_norm_eps (`float`, *optional*, defaults to 1e-06):
  430. The epsilon used by the rms normalization layers.
  431. Example:
  432. ```python
  433. >>> from transformers import Gemma3nVisionConfig, TimmWrapper
  434. >>> # Initializing a TimmWrapper gemma3n_vision-E4B-style configuration
  435. >>> configuration = Gemma3nVisionConfig()
  436. >>> # Initializing a gemma3n_vision-E4B-style TimmWrapper from the configuration
  437. >>> model = TimmWrapper(configuration)
  438. >>> # Accessing the model configuration
  439. >>> configuration = model.config
  440. ```
  441. """
  442. model_type = "gemma3n_vision"
  443. def __init__(
  444. self,
  445. initializer_range: float = 0.02,
  446. do_pooling: bool = False,
  447. architecture: str = "mobilenetv5_300m_enc",
  448. hidden_size: int = 2048,
  449. vocab_size: int = 128,
  450. vocab_offset: int = 262_144,
  451. rms_norm_eps: float = 1e-06,
  452. model_args: Optional[dict] = None,
  453. **kwargs,
  454. ):
  455. super().__init__(**kwargs)
  456. self.architecture = architecture
  457. self.initializer_range = initializer_range
  458. self.do_pooling = do_pooling
  459. self.model_args = model_args # named "model_args" for BC with timm
  460. self.hidden_size = hidden_size
  461. self.vocab_size = vocab_size
  462. self.vocab_offset = vocab_offset
  463. self.rms_norm_eps = rms_norm_eps
  464. @classmethod
  465. def from_dict(cls, config_dict: dict[str, Any], **kwargs):
  466. label_names = config_dict.get("label_names")
  467. is_custom_model = "num_labels" in kwargs or "id2label" in kwargs
  468. # if no labels added to config, use imagenet labeller in timm
  469. if label_names is None and not is_custom_model:
  470. requires_backends(cls, ["timm"])
  471. imagenet_subset = infer_imagenet_subset(config_dict)
  472. if imagenet_subset:
  473. dataset_info = ImageNetInfo(imagenet_subset)
  474. synsets = dataset_info.label_names()
  475. label_descriptions = dataset_info.label_descriptions(as_dict=True)
  476. label_names = [label_descriptions[synset] for synset in synsets]
  477. if label_names is not None and not is_custom_model:
  478. kwargs["id2label"] = dict(enumerate(label_names))
  479. # if all label names are unique, create label2id mapping as well
  480. if len(set(label_names)) == len(label_names):
  481. kwargs["label2id"] = {name: i for i, name in enumerate(label_names)}
  482. else:
  483. kwargs["label2id"] = None
  484. # timm config stores the `num_classes` attribute in both the root of config and in the "pretrained_cfg" dict.
  485. # We are removing these attributes in order to have the native `transformers` num_labels attribute in config
  486. # and to avoid duplicate attributes
  487. num_labels_in_kwargs = kwargs.pop("num_labels", None)
  488. num_labels_in_dict = config_dict.pop("num_classes", None)
  489. # passed num_labels has priority over num_classes in config_dict
  490. kwargs["num_labels"] = num_labels_in_kwargs or num_labels_in_dict
  491. # pop num_classes from "pretrained_cfg",
  492. # it is not necessary to have it, only root one is used in timm
  493. if "pretrained_cfg" in config_dict and "num_classes" in config_dict["pretrained_cfg"]:
  494. config_dict["pretrained_cfg"].pop("num_classes", None)
  495. return super().from_dict(config_dict, **kwargs)
  496. def to_dict(self) -> dict[str, Any]:
  497. output = super().to_dict()
  498. output.setdefault("num_classes", self.num_labels)
  499. output.setdefault("label_names", list(self.id2label.values()))
  500. output.pop("id2label", None)
  501. output.pop("label2id", None)
  502. return output
  503. class Gemma3nConfig(PretrainedConfig):
  504. r"""
  505. This is the configuration class to store the configuration of a [`Gemma3nForConditionalGeneration`]. It is used to
  506. instantiate a Gemma3nForConditionalGeneration according to the specified arguments, defining the model
  507. architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
  508. Gemma3n-E4B.
  509. e.g. [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B)
  510. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  511. documentation from [`PretrainedConfig`] for more information.
  512. Args:
  513. text_config (`Union[Gemma3nTextConfig, dict]`, *optional*):
  514. The config object of the text backbone.
  515. vision_config (`Union[AutoConfig, dict]`, *optional*):
  516. Custom vision config or dict.
  517. audio_config (`Union[AutoConfig, dict]`, *optional*):
  518. Custom audio config or dict.
  519. audio_soft_tokens_per_image (`int`, *optional*, defaults to 188):
  520. The number of soft tokens per audio clip.
  521. vision_soft_tokens_per_image (`int`, *optional*, defaults to 256):
  522. The number of soft tokens per image.
  523. boi_token_id (`int`, *optional*, defaults to 255999):
  524. The begin-of-image token index to wrap the image prompt.
  525. eoi_token_id (`int`, *optional*, defaults to 262144):
  526. The end-of-image token index to wrap the image prompt.
  527. image_token_id (`int`, *optional*, defaults to 262145):
  528. The image token index to encode the image prompt.
  529. boa_token_id (`int`, *optional*, defaults to 256000):
  530. The begin-of-audio token index to wrap the audio prompt.
  531. eoa_token_id (`int`, *optional*, defaults to 262272):
  532. The end-of-audio token index to wrap the audio prompt.
  533. audio_token_id (`int`, *optional*, defaults to 262273):
  534. The audio token index to encode the audio prompt.
  535. initializer_range (`float`, *optional*, defaults to 0.02):
  536. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  537. Example:
  538. ```python
  539. >>> from transformers import Gemma3nForConditionalGeneration, Gemma3nConfig, Gemma3nTextConfig
  540. >>> # Initializing a MobileNet vision config, which is loaded from TIMM
  541. >>> vision_config = Gemma3nVisionConfig()
  542. >>> # Initializing a Gemma3n Audio config
  543. >>> audio_config = Gemma3nAudioConfig()
  544. >>> # Initializing a Gemma3n Text config
  545. >>> text_config = Gemma3nTextConfig()
  546. >>> # Initializing a Gemma3n gemma-3-4b style configuration
  547. >>> configuration = Gemma3nConfig(text_config, vision_config, audio_config)
  548. >>> # Initializing a model from the gemma-3-4b style configuration
  549. >>> model = Gemma3nTextConfig(configuration)
  550. >>> # Accessing the model configuration
  551. >>> configuration = model.config
  552. ```"""
  553. model_type = "gemma3n"
  554. sub_configs = {
  555. "text_config": Gemma3nTextConfig,
  556. "vision_config": Gemma3nVisionConfig,
  557. "audio_config": Gemma3nAudioConfig,
  558. }
  559. def __init__(
  560. self,
  561. text_config: Optional[Union[Gemma3nTextConfig, dict[str, Any]]] = None,
  562. vision_config: Optional[Union[Gemma3nVisionConfig, dict[str, Any]]] = None,
  563. audio_config: Optional[Union[Gemma3nAudioConfig, dict[str, Any]]] = None,
  564. audio_soft_tokens_per_image: int = 188,
  565. vision_soft_tokens_per_image: int = 256,
  566. boi_token_id: int = 255_999,
  567. eoi_token_id: int = 262_144,
  568. image_token_id: int = 262_145,
  569. boa_token_id: int = 256_000,
  570. eoa_token_id: int = 262_272,
  571. audio_token_id: int = 262_273,
  572. initializer_range: float = 0.02,
  573. **kwargs,
  574. ):
  575. super().__init__(**kwargs)
  576. if isinstance(text_config, dict):
  577. text_config = Gemma3nTextConfig(**text_config)
  578. elif text_config is None:
  579. text_config = Gemma3nTextConfig()
  580. logger.info("text_config is None. Using default Gemma3nTextConfig.")
  581. if isinstance(vision_config, dict):
  582. vision_config = Gemma3nVisionConfig(**vision_config)
  583. elif vision_config is None:
  584. vision_config = Gemma3nVisionConfig()
  585. logger.info("vision_config is None. Using default Gemma3nVisionConfig.")
  586. if isinstance(audio_config, dict):
  587. audio_config = Gemma3nAudioConfig(**audio_config)
  588. elif audio_config is None:
  589. audio_config = Gemma3nAudioConfig()
  590. logger.info("audio_config is None. Using default Gemma3nAudioConfig.")
  591. self.text_config = text_config
  592. self.vision_config = vision_config
  593. self.audio_config = audio_config
  594. self.audio_soft_tokens_per_image = audio_soft_tokens_per_image
  595. self.vision_soft_tokens_per_image = vision_soft_tokens_per_image
  596. self.boi_token_id = boi_token_id
  597. self.eoi_token_id = eoi_token_id
  598. self.image_token_id = image_token_id
  599. self.boa_token_id = boa_token_id
  600. self.eoa_token_id = eoa_token_id
  601. self.audio_token_id = audio_token_id
  602. self.initializer_range = initializer_range
  603. __all__ = ["Gemma3nAudioConfig", "Gemma3nConfig", "Gemma3nTextConfig", "Gemma3nVisionConfig"]