modeling_gemma3n.py 110 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
7227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376237723782379238023812382238323842385238623872388238923902391239223932394
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/gemma3n/modular_gemma3n.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_gemma3n.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
  9. #
  10. #
  11. # Licensed under the Apache License, Version 2.0 (the "License");
  12. # you may not use this file except in compliance with the License.
  13. # You may obtain a copy of the License at
  14. #
  15. # http://www.apache.org/licenses/LICENSE-2.0
  16. #
  17. # Unless required by applicable law or agreed to in writing, software
  18. # distributed under the License is distributed on an "AS IS" BASIS,
  19. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  20. # See the License for the specific language governing permissions and
  21. # limitations under the License.
  22. import copy
  23. import math
  24. from collections.abc import Callable, Sequence
  25. from dataclasses import dataclass
  26. from typing import Optional, Union
  27. import torch
  28. import torch.nn as nn
  29. import torch.nn.functional as F
  30. from ...activations import ACT2FN
  31. from ...cache_utils import Cache, DynamicCache
  32. from ...generation import GenerationMixin
  33. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  34. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  35. from ...modeling_layers import GradientCheckpointingLayer
  36. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  37. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  38. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  39. from ...processing_utils import Unpack
  40. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging
  41. from ...utils.deprecation import deprecate_kwarg
  42. from ..auto import AutoModel
  43. from .configuration_gemma3n import Gemma3nAudioConfig, Gemma3nConfig, Gemma3nTextConfig, Gemma3nVisionConfig
# Module-level logger obtained from the library's `logging` utilities, named after this module.
logger = logging.get_logger(__name__)
  45. @dataclass
  46. @auto_docstring(
  47. custom_intro="""
  48. Base class for Gemma3n outputs, with hidden states and attentions.
  49. """
  50. )
  51. class Gemma3nModelOutputWithPast(BaseModelOutputWithPast):
  52. r"""
  53. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  54. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  55. Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
  56. `past_key_values` input) to speed up sequential decoding.
  57. image_hidden_states (`torch.FloatTensor`, *optional*):
  58. A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
  59. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
  60. audio_hidden_states (`torch.FloatTensor`, *optional*):
  61. A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
  62. audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
  63. """
  64. image_hidden_states: Optional[torch.FloatTensor] = None
  65. audio_hidden_states: Optional[torch.FloatTensor] = None
  66. @dataclass
  67. @auto_docstring(
  68. custom_intro="""
  69. Base class for Gemma3n causal language model (or autoregressive) outputs.
  70. """
  71. )
  72. class Gemma3nCausalLMOutputWithPast(ModelOutput):
  73. r"""
  74. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  75. Language modeling loss (for next-token prediction).
  76. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
  77. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  78. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  79. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  80. Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
  81. `past_key_values` input) to speed up sequential decoding.
  82. image_hidden_states (`torch.FloatTensor`, *optional*):
  83. A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
  84. image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
  85. audio_hidden_states (`torch.FloatTensor`, *optional*):
  86. A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
  87. audio_hidden_states of the model produced by the audio encoder and after projecting the last hidden state.
  88. """
  89. loss: Optional[torch.FloatTensor] = None
  90. logits: Optional[torch.FloatTensor] = None
  91. past_key_values: Optional[Cache] = None
  92. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  93. attentions: Optional[tuple[torch.FloatTensor]] = None
  94. image_hidden_states: Optional[torch.FloatTensor] = None
  95. audio_hidden_states: Optional[torch.FloatTensor] = None
  96. class Gemma3nRMSNorm(nn.Module):
  97. def __init__(self, dim: int, eps: float = 1e-6, with_scale: bool = True):
  98. super().__init__()
  99. self.eps = eps
  100. self.with_scale = with_scale
  101. if self.with_scale:
  102. self.weight = nn.Parameter(torch.ones(dim))
  103. else:
  104. self.register_buffer("weight", torch.tensor(1.0), persistent=False)
  105. def _norm(self, x):
  106. return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
  107. def forward(self, x: torch.Tensor) -> torch.Tensor:
  108. # Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16)
  109. # See https://github.com/huggingface/transformers/pull/29402
  110. output = self._norm(x.float()) * self.weight.float()
  111. return output.type_as(x)
  112. def extra_repr(self):
  113. return f"{tuple(self.weight.shape)}, eps={self.eps}"
  114. # ==== Audio Encoder ====
  115. class Gemma3nAudioRelativePositionEmbedding(nn.Module):
  116. def __init__(self, config: Gemma3nAudioConfig):
  117. super().__init__()
  118. self.config = config
  119. self.num_heads = self.config.conf_num_attention_heads
  120. self.channels = self.config.hidden_size
  121. self.head_dim = self.channels // self.num_heads
  122. self.max_backward = max(0, self.config.conf_attention_context_left - 1)
  123. self.max_forward = self.config.conf_attention_context_right
  124. self.pos_proj = nn.Linear(self.channels, self.num_heads * self.head_dim, bias=False)
  125. min_timescale = 1.0
  126. max_timescale = 1.0e4
  127. num_timescales = self.channels // 2
  128. log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1)
  129. inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
  130. self.register_buffer(
  131. "inv_timescales",
  132. inv_timescales.float().unsqueeze(0).unsqueeze(0),
  133. persistent=False,
  134. )
  135. def _get_timing_signal_1d_pos(self, position: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
  136. position = position.float().unsqueeze(-1)
  137. scaled_time = position * self.inv_timescales.to(device=position.device, dtype=torch.float32)
  138. timing_signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1)
  139. return timing_signal.type(dtype)
  140. def _relative_shift(
  141. self,
  142. term_bd_before_shift: torch.Tensor,
  143. batch_size: int,
  144. num_heads: int,
  145. num_query_blocks: int,
  146. query_block_size: int,
  147. key_context_size: int,
  148. max_span_plus_1: int,
  149. ) -> torch.Tensor:
  150. """Performs the relative shift.
  151. Args:
  152. term_bd_before_shift: Tensor of shape [B, N, U, W, F_span]. batch_size
  153. (B), num_heads (N), num_query_blocks (U), query_block_size (W),
  154. key_context_size (C = W+L+R), max_span_plus_1 (F_span = L+R+1).
  155. Returns:
  156. Tensor of shape [B, N, U, W, C].
  157. """
  158. # term_bd_before_shift shape: [B, N, U, W, F_span]
  159. # Target shape after shift: [B, N, U, W, C]
  160. # Padding amount for the last dimension (F_span) to become (C + 1)
  161. # C = key_context_size
  162. # F_span = max_span_plus_1
  163. pad_amount_last_dim = (key_context_size + 1) - max_span_plus_1
  164. # PyTorch F.pad expects (pad_left, pad_right, pad_top, pad_bottom ...)
  165. # We only pad the last dimension on the right.
  166. padding_tuple = (0, pad_amount_last_dim)
  167. term_bd_padded = nn.functional.pad(term_bd_before_shift, padding_tuple)
  168. # Shape after pad: [B, N, U, W, C+1]
  169. # Reshape for slicing (emulating JAX's behavior)
  170. # [B, N, U, W * (C+1)]
  171. term_bd_reshaped = term_bd_padded.reshape(
  172. (
  173. batch_size,
  174. num_heads,
  175. num_query_blocks,
  176. query_block_size * (key_context_size + 1),
  177. )
  178. )
  179. # Slice to effective [B, N, U, W * C]
  180. term_bd_sliced = term_bd_reshaped[:, :, :, : query_block_size * key_context_size]
  181. # Reshape back to [B, N, U, W, C]
  182. term_bd_shifted = term_bd_sliced.reshape(
  183. (
  184. batch_size,
  185. num_heads,
  186. num_query_blocks,
  187. query_block_size,
  188. key_context_size,
  189. )
  190. )
  191. return term_bd_shifted
  192. def forward(self, queries: torch.Tensor, keys: torch.Tensor) -> torch.Tensor:
  193. # queries: [B, U, W, N, H] (batch, num_query_blocks, query_block_size, num_heads, head_dim)
  194. # keys: [B, U, C, N, H] (batch, num_query_blocks, key_context_size, num_heads, head_dim)
  195. # C = W + L + R (key_context_size)
  196. # F_span = L + R + 1 (max_span + 1)
  197. batch_size, num_query_blocks, query_block_size, num_heads, head_dim = queries.shape
  198. _, _, key_context_size, _, _ = keys.shape
  199. # Relative positions for sinusoidal embeddings: [L, L-1, ..., -R]
  200. # Length is L+R+1 = self.max_span + 1
  201. pos_indices = torch.arange(self.max_backward, -self.max_forward - 1, -1, device=queries.device).unsqueeze(
  202. 0
  203. ) # Shape [1, F_span]
  204. max_span_plus_1 = pos_indices.shape[1] # F_span
  205. sin_emb_timing_signal = self._get_timing_signal_1d_pos(
  206. pos_indices, dtype=queries.dtype
  207. ) # Shape [1, F_span, self.channels]
  208. # Project sinusoidal embeddings: [1, F_span, self.channels] -> [1, F_span, N*H]
  209. projected_sin_emb = self.pos_proj(sin_emb_timing_signal)
  210. # Reshape to [1, F_span, N, H] then squeeze to [F_span, N, H]
  211. sin_emb = projected_sin_emb.reshape(1, max_span_plus_1, self.num_heads, self.head_dim).squeeze(
  212. 0
  213. ) # Shape [F, N, H]
  214. # term_ac: Query-Key content interaction
  215. # queries: [B, U, W, N, H] -> permute to [B, N, U, W, H] for matmul
  216. # keys: [B, U, C, N, H] -> permute to [B, N, U, H, C] for matmul
  217. queries_p = queries.permute(0, 3, 1, 2, 4) # [B, N, U, W, H]
  218. keys_p_t = keys.permute(0, 3, 1, 4, 2) # [B, N, U, H, C]
  219. term_ac = torch.matmul(queries_p, keys_p_t) # [B, N, U, W, C]
  220. # term_bd: Query-Position interaction
  221. # Original einsum: term_bd_unshifed = torch.einsum('buwnh,fnh->bnuwf', queries, sin_emb)
  222. # queries shape: [B, U, W, N, H]
  223. # sin_emb shape: [F, N, H]
  224. # Target output shape: [B, N, U, W, F]
  225. # Permute queries to [B, N, U, W, H] for easier broadcasting with sin_emb
  226. q_permuted = queries.permute(0, 3, 1, 2, 4)
  227. # Permute sin_emb to [N, H, F] to prepare for matmul
  228. # sin_emb original is [F, N, H]
  229. s_permuted = sin_emb.permute(1, 2, 0) # Shape: [N, H, F]
  230. # Reshape queries for matmul: [B, N, U*W, H]
  231. q_reshaped = q_permuted.reshape(batch_size, num_heads, num_query_blocks * query_block_size, head_dim)
  232. # Perform matmul: [B, N, U*W, H] @ [N, H, F]
  233. # s_permuted ([N, H, F]) will be broadcast to [B, N, H, F]
  234. # Result: [B, N, U*W, F]
  235. term_bd_unshifed_matmul = torch.matmul(q_reshaped, s_permuted)
  236. # Reshape to target [B, N, U, W, F]
  237. term_bd_unshifed = term_bd_unshifed_matmul.reshape(
  238. batch_size,
  239. num_heads,
  240. num_query_blocks,
  241. query_block_size,
  242. max_span_plus_1,
  243. )
  244. # Apply relative shift to term_bd_unshifed
  245. term_bd_shifted = self._relative_shift(
  246. term_bd_unshifed,
  247. batch_size,
  248. num_heads,
  249. num_query_blocks,
  250. query_block_size,
  251. key_context_size,
  252. max_span_plus_1,
  253. ) # Shape [B, N, U, W, C]
  254. return term_ac + term_bd_shifted
class Gemma3nAudioAttention(nn.Module):
    """Blockwise local self-attention used by the audio encoder.

    The time axis is cut into non-overlapping query blocks of ``conf_attention_chunk_size`` frames. Each block
    attends to a key/value context of ``chunk_size + max_past_horizon + max_future_horizon`` frames around it.
    Logits come from ``Gemma3nAudioRelativePositionEmbedding`` (content + relative-position terms). They are
    soft-capped with ``tanh`` at ``conf_attention_logit_cap``. Before the softmax they are masked both by the
    local attention window and by the caller's padding mask.
    """

    def __init__(self, config: Gemma3nAudioConfig):
        super().__init__()
        self.config = config
        self.num_heads = self.config.conf_num_attention_heads
        self.hidden_size = self.config.hidden_size
        self.head_dim = self.hidden_size // self.num_heads
        self.chunk_size = self.config.conf_attention_chunk_size
        self.max_future_horizon = self.config.conf_attention_context_right
        self.max_past_horizon = max(0, self.config.conf_attention_context_left - 1)
        self.attention_logits_soft_cap = self.config.conf_attention_logit_cap
        self.context_size = self.chunk_size + self.max_past_horizon + self.max_future_horizon
        self.relative_position_embedding = Gemma3nAudioRelativePositionEmbedding(config)
        self.per_dim_scale = nn.Parameter(torch.zeros((self.head_dim,)))
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        # forward() multiplies queries by q_scale * softplus(per_dim_scale). Dividing by softplus(0) here means
        # that while per_dim_scale is at its zero initialisation the effective query scale is head_dim**-0.5.
        q_scale = self.head_dim**-0.5
        r_softplus_0 = 1.0 / torch.nn.functional.softplus(torch.tensor(0.0))
        self.register_buffer("q_scale", (q_scale * r_softplus_0).clone().detach(), persistent=False)
        # Local window mask of shape [chunk_size, context_size]. Query w of a block may see context slot c only
        # when w <= c <= w + max_past_horizon + max_future_horizon. That is, the key lies between
        # max_past_horizon frames before and max_future_horizon frames after the query.
        lower_causal_mask = torch.tril(
            torch.ones((self.context_size, self.chunk_size), dtype=torch.bool),
            diagonal=0,
        ).T
        upper_causal_mask = torch.tril(
            torch.ones((self.chunk_size, self.context_size), dtype=torch.bool),
            diagonal=self.max_past_horizon + self.max_future_horizon,
        )
        local_causal_valid_mask = torch.ones((self.chunk_size, self.context_size), dtype=torch.bool)
        local_causal_valid_mask = local_causal_valid_mask * lower_causal_mask * upper_causal_mask
        self.register_buffer("local_causal_valid_mask", local_causal_valid_mask, persistent=False)
        self.register_buffer(
            "softcap",
            torch.tensor(self.attention_logits_soft_cap).float(),
            persistent=False,
        )

    def _pad_dim1(self, x: torch.Tensor, pad_left: int, pad_right: int) -> torch.Tensor:
        """Zero-pad ``x`` along dim 1 with ``pad_left`` entries before and ``pad_right`` entries after.

        The padding is created with ``new_zeros``, so it has ``x``'s dtype/device (False for boolean inputs).
        """
        batch, _, *tail_shape = x.shape
        left = x.new_zeros((batch, pad_left, *tail_shape))
        right = x.new_zeros((batch, pad_right, *tail_shape))
        x = torch.cat([left, x, right], dim=1)
        return x

    def _convert_to_block(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Turns a sequence to non overlapping blocks.

        Args:
            hidden_states: a tensor of [batch, time, ...].

        Returns:
            A tensor of [batch, num_blocks, block_size, ...] with the time axis zero-padded at the end up to a
            multiple of block_size, where output[:, i, ...] are x[:, i*block_size:(i+1)*block_size, ...].
        """
        shape = hidden_states.shape
        b, t = shape[:2]
        num_blocks = (t + self.chunk_size - 1) // self.chunk_size  # ceil(t / chunk_size)
        if (padding_len := num_blocks * self.chunk_size - t) > 0:
            hidden_states = self._pad_dim1(hidden_states, 0, padding_len)
        # Despite its name, this is the target shape of a reshape; no permutation happens.
        permute_dims = (b, num_blocks, self.chunk_size) + shape[2:]
        hidden_states = hidden_states.reshape(permute_dims).contiguous()
        return hidden_states

    def _extract_block_context(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Extracts temporal context for every block.

        Args:
            hidden_states: a tensor of [batch, time, ...].

        Returns:
            A tensor of [batch, num_blocks, context_size, ...], where
            context_size = block_size + left_context + right_context and output[:, i, ...] are
            x[:, start - left_context : end + right_context, ...] with start = i * block_size and
            end = (i + 1) * block_size. Frames outside the sequence are zero (False for boolean inputs).
            For a 2-D [batch, time] input the result is [batch, num_blocks, context_size].
        """
        pad_left = self.max_past_horizon
        # _pad_dim1 always zero-pads dim 1 (time) and takes explicit left/right amounts.
        # The left pad supplies max_past_horizon frames of history for the first block. The right pad
        # (max_future_horizon + chunk_size - 1) ensures that, for time >= 1, unfold() below produces exactly
        # ceil(time / chunk_size) windows, matching _convert_to_block. This mirrors the
        # (left_context, right_context + block_size - 1) padding of JAX's signal.frame with pad_mode='valid'.
        pad_right = self.max_future_horizon + self.chunk_size - 1
        hidden_states = self._pad_dim1(hidden_states, pad_left, pad_right)
        frame_len = self.context_size
        frame_step = self.chunk_size
        # Sliding windows over dim 1 (time): window length context_size, stride chunk_size.
        # unfold appends the window axis last:
        #   [B, T_padded]       -> [B, num_blocks, context_size]
        #   [B, T_padded, N, H] -> [B, num_blocks, N, H, context_size]
        x_unfolded = hidden_states.unfold(dimension=1, size=frame_len, step=frame_step)
        # When trailing feature dims exist, move the window axis to position 2 so keys/values come out as
        # [B, U, C, N, H], the layout Gemma3nAudioRelativePositionEmbedding expects for keys.
        if hidden_states.ndim > 2 and x_unfolded.ndim > 3:  # Check if inner dimensions (like N, H) exist
            x_unfolded = torch.movedim(x_unfolded, source=-1, destination=2)
        return x_unfolded.contiguous()

    def forward(self, hidden_states: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
        """Run blockwise local attention.

        Args:
            hidden_states: input frames. Presumably [batch, time, hidden_size]; the (1, 1, 1, head_dim) broadcast
                of the per-dim scale below implies 4-D projected queries [batch, time, num_heads, head_dim].
                TODO confirm with callers.
            mask: boolean padding mask that is True at padded (invalid) frames; it is inverted below.
                Presumably [batch, time]. A [batch, time, 1] mask is also accepted via the reshape branch.

        Returns:
            Context vectors of shape [batch, time, num_heads, head_dim]; the block padding is sliced off.

        Raises:
            ValueError: if the blocked validity mask cannot be brought to [batch, num_blocks, context_size].
        """
        # Port of a JAX sl.Dense projection (einsum "...a,abcd->...bcd"). Here each nn.Linear output is
        # reshaped to [..., num_heads, head_dim].
        qkv_shape = (*hidden_states.shape[:-1], self.num_heads, self.head_dim)
        query_states = self.q_proj(hidden_states).reshape(qkv_shape).contiguous()
        key_states = self.k_proj(hidden_states).reshape(qkv_shape).contiguous()
        value_states = self.v_proj(hidden_states).reshape(qkv_shape).contiguous()
        # Learned per-dimension query scale, kept positive via softplus.
        per_dim_scale_sp = torch.nn.functional.softplus(self.per_dim_scale)
        broadcast_shape = (1, 1, 1, self.head_dim)
        per_dim_scale_sp_broadcast = per_dim_scale_sp.view(broadcast_shape)
        query_states = query_states * self.q_scale * per_dim_scale_sp_broadcast
        batch_size, q_time = query_states.shape[:2]
        # Queries: [B, U, W, N, H]. Keys/values: [B, U, C, N, H] (overlapping context windows).
        query_blocks = self._convert_to_block(query_states)
        key_blocks = self._extract_block_context(key_states)
        value_blocks = self._extract_block_context(value_states)
        num_query_blocks = query_blocks.shape[1]
        # 1. Create a mask indicating originally valid positions.
        original_valid_mask = ~mask  # True for valid, False for padded
        # 2. Extract blocks from this validity mask (padding added by _pad_dim1 is False, i.e. invalid).
        extracted_valid_mask_blocks = self._extract_block_context(original_valid_mask)
        # For a [B, T] mask the result is already [B, U, C]. This branch only fires for a 4-D result whose
        # trailing dims multiply to C, e.g. a [B, T, 1] mask, which unfold + movedim turn into [B, U, C, 1].
        # It flattens that back to [B, U, C].
        if (
            extracted_valid_mask_blocks.ndim == 4
            and extracted_valid_mask_blocks.shape[2] * extracted_valid_mask_blocks.shape[3] == self.context_size
        ):
            extracted_valid_mask_blocks = extracted_valid_mask_blocks.reshape(
                batch_size, num_query_blocks, self.context_size
            )
        # Any other mask layout is rejected here.
        if extracted_valid_mask_blocks.shape != (
            batch_size,
            num_query_blocks,
            self.context_size,
        ):
            raise ValueError(
                "Shape of extracted_valid_mask_blocks"
                f" {extracted_valid_mask_blocks.shape} is not ({batch_size},"
                f" {num_query_blocks}, {self.context_size}) after potential reshape."
            )
        # 3. Expand dimensions for broadcasting with logits and causal mask.
        # Target shape for broadcasting with logits [B,N,U,W,C]
        # extracted_valid_mask_blocks to [B, 1, U, 1, C]
        condition_from_input_validity = extracted_valid_mask_blocks.unsqueeze(1).unsqueeze(-2)
        # self.local_causal_valid_mask is [W, C], True where allowed by local window.
        # Expand to [1, 1, 1, W, C]
        condition_from_causality = self.local_causal_valid_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0)
        # 4. Combine the two conditions.
        # final_condition will be True where a key is *both* originally valid *and* inside the local window.
        # Broadcasts to [B, 1, U, W, C]
        final_condition_for_where = torch.logical_and(
            condition_from_input_validity,
            condition_from_causality.to(condition_from_input_validity.device),  # Ensure same device
        )
        # Content + relative-position logits, shape [B, N, U, W, C].
        logits = self.relative_position_embedding(query_blocks, key_blocks)
        # Soft-cap: logits = cap * tanh(logits / cap), bounding them to (-cap, cap).
        # Ensure softcap is on the same device as logits
        softcap_val = self.softcap.to(logits.device)
        logits = logits / softcap_val
        logits = torch.tanh(logits)
        logits = logits * softcap_val
        # Apply the combined mask: disallowed positions get the dtype's most negative finite value.
        # final_condition_for_where will broadcast with logits [B,N,U,W,C]
        logits = torch.where(final_condition_for_where, logits, torch.finfo(logits.dtype).min)
        # Softmax over the context axis in float32, then back to the value dtype.
        probabilities = torch.nn.functional.softmax(logits, dim=-1, dtype=torch.float32).to(dtype=value_blocks.dtype)
        # context_vectors is adapted from jax.numpy.einsum("BNuwc,BucNH->BuwNH", ...), done as a batched bmm:
        # [B*U*N, W, C] @ [B*U*N, C, H] -> [B*U*N, W, H] -> [B, U, W, N, H].
        b_dim, n_dim, u_dim, w_dim, c_dim = probabilities.shape
        h_dim = value_blocks.shape[-1]
        prob_bun = probabilities.permute(0, 2, 1, 3, 4).reshape(-1, w_dim, c_dim)
        v_bun = value_blocks.permute(0, 1, 3, 2, 4).reshape(-1, c_dim, h_dim)
        result_bmm = torch.bmm(prob_bun, v_bun)
        context_vectors = result_bmm.reshape(b_dim, u_dim, n_dim, w_dim, h_dim).permute(0, 1, 3, 2, 4)
        # Merge blocks back into a time axis: [B, U*W, N, H].
        context_vectors = context_vectors.reshape(
            (
                batch_size,
                num_query_blocks * self.chunk_size,
                self.num_heads,
                self.head_dim,
            )
        )
        # Drop the frames that _convert_to_block appended to fill the last block.
        context_vectors = context_vectors[:, :q_time]
        return context_vectors
  441. class Gemma3nAudioCumulativeGroupNorm(nn.Module):
  442. """Applies Group Normalization cumulatively over the time dimension.
  443. This layer normalizes the input by calculating the mean and variance
  444. cumulatively over the time dimension (dim 1). The statistics are computed
  445. over all feature dimensions (specified by `feature_dims` and `num_channels`)
  446. for elements marked as valid by the optional `mask`.
  447. If a `mask` is provided (True for valid, False for invalid/padded),
  448. invalid time steps do not contribute to the statistics calculation, and
  449. their corresponding output values are zeroed out.
  450. Scale and bias, if enabled, are applied per-channel (last dimension).
  451. This behavior is similar to JAX's `GroupNormalization` with `num_groups=1`
  452. and `cumulative=True`.
  453. """
  454. def __init__(
  455. self,
  456. num_channels: int, # Number of channels (size of the last dimension)
  457. feature_dims: Sequence[int], # Sizes of non-channel feature dimensions, e.g., (H, W) for input [B,T,H,W,C]
  458. eps: float = 1e-3,
  459. ):
  460. super().__init__()
  461. self.num_channels = num_channels
  462. self.feature_dims = tuple(feature_dims)
  463. self.eps = eps
  464. # Scale parameter depends only on the channel dimension
  465. self.weight = nn.Parameter(torch.ones(num_channels))
  466. # Axes for normalization: all dimensions except Batch (0) and Time (1).
  467. # For input [B, T, *feature_dims, C], these are dims from 2 onwards.
  468. self.reduction_axes = tuple(range(2, 2 + len(self.feature_dims) + 1))
  469. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  470. """Applies cumulative group norm, optionally using a mask.
  471. Args:
  472. hidden_states: Input tensor, shape [B, T, *feature_dims, C].
  473. Returns:
  474. Normalized tensor with the same shape as x.
  475. """
  476. expected_input_suffix = self.feature_dims + (self.num_channels,)
  477. if hidden_states.shape[2:] != expected_input_suffix:
  478. raise ValueError(
  479. f"Input tensor shape suffix {hidden_states.shape[2:]} does not match expected"
  480. f" suffix (feature_dims + num_channels) {expected_input_suffix}"
  481. )
  482. input_dtype = hidden_states.dtype
  483. # Calculations are performed in float32 for numerical stability.
  484. calc_dtype = torch.float32
  485. x_calc = hidden_states.to(calc_dtype)
  486. # Prepare a broadcastable mask (`mask_calc`).
  487. # If no mask is provided, treat all elements as valid
  488. # (mask_calc is all ones).
  489. # Otherwise, expand the [B, T] mask to [B, T, 1, ..., 1] for broadcasting.
  490. mask_calc = torch.ones_like(x_calc, dtype=calc_dtype)
  491. # Cumulative Statistics Calculation
  492. # 1. Sum of values over reduction axes at each time step.
  493. sum_values_at_t = torch.sum(x_calc, dim=self.reduction_axes, keepdim=True)
  494. # 2. Cumulative sum of values over time.
  495. cum_sum_values = torch.cumsum(sum_values_at_t, dim=1)
  496. # 3. Count of valid elements in the normalization group at each time step.
  497. # (A "group" here consists of all features at a given Batch, Time).
  498. elements_in_group_at_t = torch.sum(mask_calc, dim=self.reduction_axes, keepdim=True)
  499. # 4. Cumulative count of valid elements over time.
  500. cum_count_elements = torch.cumsum(elements_in_group_at_t, dim=1)
  501. # Avoid division by zero if all preceding elements were masked.
  502. safe_cum_count_elements = torch.clamp(cum_count_elements, min=1.0)
  503. # 5. Cumulative mean.
  504. cum_mean = cum_sum_values / safe_cum_count_elements
  505. # 6. Sum of squared differences from the cumulative mean.
  506. # Only sum for valid elements: (x_calc - cum_mean)^2 * mask_calc.
  507. # Using x_calc here for the difference, as cum_mean already accounts for masking.
  508. squared_diff_from_mean = (x_calc - cum_mean).pow(2)
  509. sum_sq_diff_at_t = torch.sum(squared_diff_from_mean, dim=self.reduction_axes, keepdim=True)
  510. # 7. Cumulative sum of squared differences over time.
  511. cum_sum_sq_diff = torch.cumsum(sum_sq_diff_at_t, dim=1)
  512. # 8. Cumulative variance.
  513. cum_variance = cum_sum_sq_diff / safe_cum_count_elements
  514. # Normalize the input using the calculated cumulative statistics:
  515. # (x - E[x]) / sqrt(Var[x] + eps)
  516. normalized_x = (x_calc - cum_mean) * torch.rsqrt(cum_variance + self.eps)
  517. # Apply affine transformation (scale and bias) if enabled.
  518. # Scale and bias are applied per-channel (last dimension).
  519. scale = self.weight.to(calc_dtype)
  520. # Reshape for broadcasting: [C] -> [1, ..., 1, C]
  521. scale_view_shape = [1] * (hidden_states.dim() - 1) + [self.num_channels]
  522. normalized_x = normalized_x * scale.view(scale_view_shape)
  523. # Zero out outputs for time steps that were originally masked (where mask_calc is 0).
  524. # This ensures padded/invalid positions in the input result in zero output.
  525. final_output = normalized_x * mask_calc
  526. return final_output.to(input_dtype)
  527. class Gemma3nAudioSSCPConvBlock(nn.Module):
  528. """A single convolution block for the SubSampleConvProjection.
  529. This block consists of a 2D convolution, followed by CumulativeGroupNorm,
  530. and a ReLU activation. It handles manual padding for the convolution.
  531. """
  532. def __init__(
  533. self,
  534. config: Gemma3nAudioConfig,
  535. idx: int,
  536. input_freq_dim: int, # Changed from input_spatial_dim
  537. manual_padding: tuple[int, int, int, int] = (0, 0, 0, 0),
  538. ):
  539. super().__init__()
  540. self.config = config
  541. self.manual_padding = manual_padding
  542. # in_channels is 1 for the first block, or C_out from previous block's conv
  543. in_channels = 1 if idx == 0 else self.config.sscp_conv_channel_size[idx - 1]
  544. out_channels = self.config.sscp_conv_channel_size[idx]
  545. kernel_h, kernel_w = self.config.sscp_conv_kernel_size[idx]
  546. stride_h, stride_w = self.config.sscp_conv_stride_size[idx]
  547. self.conv = nn.Conv2d(
  548. in_channels=in_channels,
  549. out_channels=out_channels,
  550. kernel_size=(
  551. kernel_h,
  552. kernel_w,
  553. ), # Kernel (kH, kW) operates on (Time, Freq_dim)
  554. stride=(stride_h, stride_w),
  555. padding=(0, 0), # Manual padding is used
  556. bias=False,
  557. )
  558. # Calculate output frequency dimension (f_out_conv) after this convolution.
  559. # input_freq_dim is the unpadded width (feature dimension).
  560. # self.manual_padding is (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom)
  561. f_in_padded = input_freq_dim + self.manual_padding[0] + self.manual_padding[1]
  562. f_out_conv = (f_in_padded - kernel_w) // stride_w + 1
  563. self.norm = Gemma3nAudioCumulativeGroupNorm(
  564. num_channels=out_channels, # Channels of the conv output
  565. feature_dims=(f_out_conv,), # The frequency dimension size after conv
  566. eps=self.config.sscp_conv_group_norm_eps,
  567. )
  568. self.activation = nn.ReLU()
  569. def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
  570. # Input audio_encodings is [B, C_in, T_in, F_in] (e.g., C_in=1)
  571. # manual_padding is (pad_F_left, pad_F_right, pad_T_top, pad_T_bottom)
  572. # F.pad applies to last two dims: F_in then T_in
  573. audio_encodings_padded = F.pad(audio_encodings, self.manual_padding, mode="constant", value=0.0).to(
  574. self.conv.weight.dtype
  575. )
  576. # Expected padded shape for F_in, k_w=3, pad_F=(1,1) -> F_padded = F_in+2
  577. # Expected padded shape for T_in, k_h=3, pad_T=(0,2) -> T_padded = T_in+2
  578. audio_encodings_conv = self.conv(audio_encodings_padded)
  579. # Expected conv output shape: [B, C_out, T_out, F_out]
  580. # Input to norm is [B, T_out, F_out, C_out]
  581. x_for_norm = audio_encodings_conv.permute(0, 2, 3, 1).contiguous()
  582. x_normed = self.norm(x_for_norm)
  583. # Output of norm is [B, T_out, F_out, C_out], permute back to [B, C_out, T_out, F_out]
  584. audio_encodings_normed = x_normed.permute(0, 3, 1, 2).contiguous()
  585. return self.activation(audio_encodings_normed)
  586. class Gemma3nAudioSubSampleConvProjection(nn.Module):
  587. def __init__(self, config: Gemma3nAudioConfig):
  588. super().__init__()
  589. self.config = config
  590. current_f_for_block_input = config.input_feat_size # Start with original feature dim
  591. calculated_block_padding = []
  592. calculated_f_out_dims = [] # Tracking frequency dimension output sizes
  593. for i in range(2): # Assuming 2 conv layers as per sscp_conv_... arrays
  594. kernel_h, kernel_w = config.sscp_conv_kernel_size[i]
  595. stride_h, stride_w = config.sscp_conv_stride_size[i]
  596. # Padding for Time (Height for Conv2d) - REVERSE_CAUSAL like
  597. # JAX 'reverse_causal' padding is (0, kernel_size - 1)
  598. pad_t_top = 0
  599. pad_t_bottom = kernel_h - 1
  600. # Frequency Padding (Width for Conv2d)
  601. # Based on JAX effective padding (1,1) for F_in=10, K_w=3, S_w=2
  602. # and the successful test configuration.
  603. # If kernel/stride/input_freq for frequency changes, this might need re-evaluation
  604. # to match generic JAX 'SAME' behavior if it differs.
  605. pad_f_left = 1
  606. pad_f_right = 1
  607. manual_padding_tuple = (
  608. pad_f_left,
  609. pad_f_right,
  610. pad_t_top,
  611. pad_t_bottom,
  612. )
  613. calculated_block_padding.append(manual_padding_tuple)
  614. # Calculate output frequency dimension after this convolution
  615. # This uses the actual padding applied and kernel/stride.
  616. f_in_padded = current_f_for_block_input + pad_f_left + pad_f_right
  617. f_out_after_conv = (f_in_padded - kernel_w) // stride_w + 1 # Assuming dilation_w = 1
  618. calculated_f_out_dims.append(f_out_after_conv)
  619. current_f_for_block_input = f_out_after_conv
  620. self.conv_0 = Gemma3nAudioSSCPConvBlock(
  621. idx=0,
  622. input_freq_dim=config.input_feat_size, # Pass original feature dim
  623. config=config,
  624. manual_padding=calculated_block_padding[0],
  625. )
  626. self.conv_1 = Gemma3nAudioSSCPConvBlock(
  627. idx=1,
  628. input_freq_dim=calculated_f_out_dims[0], # Output freq dim from conv_0
  629. config=config,
  630. manual_padding=calculated_block_padding[1],
  631. )
  632. final_c_out = config.sscp_conv_channel_size[-1]
  633. final_f_out = calculated_f_out_dims[-1] # Final frequency dimension
  634. self.input_proj_in_features = final_c_out * final_f_out
  635. self.input_proj_linear = nn.Linear(self.input_proj_in_features, self.config.hidden_size, bias=False)
  636. def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
  637. # audio_encodings is [B, T, F_in]
  638. # Reshape to [B, 1, T, F_in] (Batch, Channels=1, Height=Time, Width=F_in)
  639. audio_encodings_reshaped = audio_encodings.unsqueeze(1)
  640. x = self.conv_0(audio_encodings_reshaped)
  641. x = self.conv_1(x)
  642. # x from conv_1 is [B, C_out_1, T_out_1, F_out_1]
  643. b, c_out, t_out, f_out = x.shape
  644. # Permute to [B, T_out_1, F_out_1, C_out_1] then flatten F_out_1 and C_out_1
  645. x_permuted = x.permute(0, 2, 3, 1).contiguous()
  646. output_flattened = x_permuted.view(b, t_out, f_out * c_out)
  647. output = self.input_proj_linear(output_flattened)
  648. return output
  649. class Gemma3nAudioConformerAttention(nn.Module):
  650. def __init__(self, config: Gemma3nAudioConfig):
  651. super().__init__()
  652. self.config = config
  653. self.post_in_features = self.config.hidden_size
  654. self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
  655. self.pre_attn_norm = Gemma3nRMSNorm(self.config.hidden_size)
  656. self.attn = Gemma3nAudioAttention(config)
  657. self.post = nn.Linear(self.post_in_features, self.config.hidden_size, bias=False)
  658. self.post_norm = Gemma3nRMSNorm(self.config.hidden_size)
  659. def forward(self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor) -> torch.Tensor:
  660. audio_encodings_input_to_attn = audio_encodings
  661. audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
  662. audio_encodings_norm = self.pre_attn_norm(audio_encodings)
  663. # Output of self.attn is [B, T, NumHeads, HeadDim]
  664. audio_encodings_attn_out = self.attn(audio_encodings_norm, audio_mel_mask)
  665. # Reshape from [B, T, NumHeads, HeadDim] to [B, T, NumHeads * HeadDim]
  666. # NumHeads * HeadDim = hidden_size
  667. b, t, num_heads, head_dim = audio_encodings_attn_out.shape
  668. audio_encodings_reshaped = audio_encodings_attn_out.reshape(b, t, num_heads * head_dim)
  669. audio_encodings = self.post(audio_encodings_reshaped)
  670. audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
  671. return audio_encodings_input_to_attn + self.post_norm(audio_encodings)
  672. class Gemma3nAudioConformerFeedForward(nn.Module):
  673. def __init__(self, config: Gemma3nAudioConfig):
  674. super().__init__()
  675. self.config = config
  676. self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
  677. self.pre_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
  678. self.ffw_layer_1 = nn.Linear(self.config.hidden_size, self.config.hidden_size * 4, bias=False)
  679. self.ffw_layer_2 = nn.Linear(self.config.hidden_size * 4, self.config.hidden_size, bias=False)
  680. self.post_layer_norm = Gemma3nRMSNorm(self.config.hidden_size)
  681. self.post_layer_scale = torch.tensor(self.config.conf_residual_weight)
  682. def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
  683. residual = audio_encodings
  684. audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
  685. audio_encodings = self.pre_layer_norm(audio_encodings)
  686. audio_encodings: torch.Tensor = self.ffw_layer_1(audio_encodings)
  687. audio_encodings = nn.functional.silu(audio_encodings)
  688. audio_encodings: torch.Tensor = self.ffw_layer_2(audio_encodings)
  689. audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
  690. audio_encodings = self.post_layer_norm(audio_encodings)
  691. return residual + (audio_encodings * self.post_layer_scale)
  692. class Gemma3nAudioConformerLightConv1d(nn.Module):
  693. def __init__(self, config: Gemma3nAudioConfig):
  694. super().__init__()
  695. self.config = config
  696. self.pre_layer_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
  697. self.linear_start = nn.Linear(self.config.hidden_size, self.config.hidden_size * 2, bias=False)
  698. self.depthwise_conv1d = nn.Conv1d(
  699. in_channels=self.config.hidden_size,
  700. out_channels=self.config.hidden_size,
  701. kernel_size=self.config.conf_conv_kernel_size,
  702. stride=1,
  703. padding=0, # Manual causal padding
  704. groups=self.config.hidden_size, # Depthwise
  705. bias=False,
  706. )
  707. self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
  708. self.conv_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
  709. self.linear_end = nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)
  710. self.causal_padding = self.config.conf_conv_kernel_size - 1
  711. def forward(self, audio_encodings: torch.Tensor) -> torch.Tensor:
  712. audio_encodings_residual = audio_encodings # Save for residual connection
  713. audio_encodings = self.pre_layer_norm(audio_encodings)
  714. audio_encodings = self.linear_start(audio_encodings)
  715. audio_encodings = torch.nn.functional.glu(audio_encodings, dim=-1)
  716. # Permute for Conv1d: [B, T, D] -> [B, D, T]
  717. audio_encodings_permuted = audio_encodings.permute(0, 2, 1)
  718. # Apply manual causal padding
  719. audio_encodings_permuted_padded = F.pad(audio_encodings_permuted, (self.causal_padding, 0))
  720. audio_encodings = self.depthwise_conv1d(audio_encodings_permuted_padded)
  721. # Permute back: [B, D, T_out] -> [B, T_out, D]
  722. audio_encodings = audio_encodings.permute(0, 2, 1)
  723. audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
  724. audio_encodings = self.conv_norm(audio_encodings)
  725. audio_encodings = nn.functional.silu(audio_encodings)
  726. audio_encodings = self.linear_end(audio_encodings)
  727. output = audio_encodings + audio_encodings_residual
  728. return output
  729. class Gemma3nAudioConformerBlock(nn.Module):
  730. def __init__(self, config: Gemma3nAudioConfig):
  731. super().__init__()
  732. self.config = config
  733. self.ffw_layer_start = Gemma3nAudioConformerFeedForward(self.config)
  734. self.attention = Gemma3nAudioConformerAttention(self.config)
  735. self.lconv1d = Gemma3nAudioConformerLightConv1d(self.config)
  736. self.ffw_layer_end = Gemma3nAudioConformerFeedForward(self.config)
  737. self.register_buffer("gradient_clipping", torch.tensor(self.config.gradient_clipping), persistent=False)
  738. self.norm = Gemma3nRMSNorm(self.config.hidden_size)
  739. def forward(self, audio_encodings: torch.Tensor, audio_mel_mask: torch.BoolTensor) -> torch.Tensor:
  740. audio_encodings = self.ffw_layer_start(audio_encodings)
  741. audio_encodings = self.attention(audio_encodings, audio_mel_mask)
  742. validity_mask_for_lconv = ~audio_mel_mask # True for valid
  743. audio_encodings_for_lconv_input = audio_encodings * validity_mask_for_lconv.unsqueeze(-1).to(
  744. audio_encodings.dtype
  745. )
  746. audio_encodings = self.lconv1d(audio_encodings_for_lconv_input)
  747. audio_encodings = self.ffw_layer_end(audio_encodings)
  748. audio_encodings = torch.clamp(audio_encodings, -self.gradient_clipping, self.gradient_clipping)
  749. output = self.norm(audio_encodings)
  750. return output
  751. class Gemma3nAudioEncoder(PreTrainedModel):
  752. """
  753. An audio encoder based on the [Universal Speech Model](https://huggingface.co/papers/2303.01037) architecture.
  754. """
  755. config: Gemma3nAudioConfig
  756. main_input_name = "audio_mel"
  757. def __init__(self, config: Gemma3nAudioConfig):
  758. super().__init__(config)
  759. self.config = config
  760. self.subsample_conv_projection = Gemma3nAudioSubSampleConvProjection(config)
  761. self.conformer = nn.ModuleList(
  762. [Gemma3nAudioConformerBlock(config) for _ in range(config.conf_num_hidden_layers)]
  763. )
  764. def forward(
  765. self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor
  766. ) -> tuple[torch.Tensor, torch.BoolTensor]:
  767. """Encodes a batch of MELs.
  768. Args:
  769. audio_mel: a torch.Tensor of shape [batch, num_frames, num_channels,
  770. mel_bins].
  771. Returns:
  772. audio_encodings: a torch.Tensor of shape
  773. `[batch_size, self.config.audio_soft_tokens_per_image,
  774. self.config.audio_config.hidden_size]`
  775. audio_mel_mask: a torch.BoolTensor of shape [batch, num_frames].
  776. """
  777. audio_encodings = self.subsample_conv_projection(audio_mel) # audio_encodings: [B, T_sub, D]
  778. # Subsample the input audio_mel_mask to match the time dimension of audio_encodings (T_sub)
  779. t_sub = audio_encodings.shape[1]
  780. time_stride_product = 1
  781. for stride_pair_idx in range(len(self.config.sscp_conv_stride_size)):
  782. time_stride_product *= self.config.sscp_conv_stride_size[stride_pair_idx][0]
  783. # Create indices for gathering from the original mask.
  784. # These indices map to original time steps corresponding to the start of each
  785. # receptive field in the subsampled output.
  786. indices = torch.arange(t_sub, device=audio_mel_mask.device) * time_stride_product
  787. indices = torch.clamp(indices, max=audio_mel_mask.shape[1] - 1) # Ensure indices are valid
  788. # Expand indices for batch compatibility if B > 1 and indices is 1D.
  789. if audio_mel_mask.ndim > 1 and indices.ndim == 1:
  790. indices = indices.unsqueeze(0).expand(audio_mel_mask.shape[0], -1) # [B, T_sub]
  791. elif (
  792. audio_mel_mask.ndim == indices.ndim
  793. and audio_mel_mask.shape[0] == 1
  794. and indices.shape[0] != 1
  795. and t_sub == indices.shape[0]
  796. ):
  797. # Handle case where B=1 but indices became [T_sub] instead of [1, T_sub]
  798. indices = indices.unsqueeze(0)
  799. current_mask = torch.gather(audio_mel_mask, 1, indices) # [B, T_sub]
  800. for block in self.conformer:
  801. audio_encodings = block(audio_encodings, current_mask) # Pass the processed mask
  802. if self.config.conf_reduction_factor > 1:
  803. audio_encodings = audio_encodings[:, :: self.config.conf_reduction_factor]
  804. # Reduce the mask as well
  805. current_mask = current_mask[:, :: self.config.conf_reduction_factor]
  806. audio_encodings = audio_encodings.masked_fill(current_mask.unsqueeze(-1), 0.0)
  807. return audio_encodings, current_mask
  808. class Gemma3nTextScaledWordEmbedding(nn.Embedding):
  809. """
  810. This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
  811. """
  812. def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
  813. super().__init__(num_embeddings, embedding_dim, padding_idx)
  814. self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
  815. def forward(self, input_ids: torch.Tensor):
  816. return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
  817. class Gemma3nTextLaurelBlock(nn.Module):
  818. """Learned Augmented Residual Layer"""
  819. def __init__(self, config: Gemma3nTextConfig):
  820. super().__init__()
  821. self.config = config
  822. self.linear_left = nn.Linear(self.config.hidden_size, self.config.laurel_rank, bias=False)
  823. self.linear_right = nn.Linear(self.config.laurel_rank, self.config.hidden_size, bias=False)
  824. self.post_laurel_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
  825. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  826. laurel_hidden_states: torch.Tensor = self.linear_left(hidden_states)
  827. laurel_hidden_states: torch.Tensor = self.linear_right(laurel_hidden_states)
  828. normed_laurel_hidden_states = self.post_laurel_norm(laurel_hidden_states)
  829. return hidden_states + normed_laurel_hidden_states
  830. class Gemma3nTextMLP(nn.Module):
  831. def __init__(self, config: Gemma3nTextConfig, layer_idx: int = 0):
  832. super().__init__()
  833. self.config = config
  834. self.hidden_size = config.hidden_size
  835. self.intermediate_size = config.intermediate_size[layer_idx]
  836. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  837. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  838. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  839. self.act_fn = ACT2FN[config.hidden_activation]
  840. self.activation_sparsity = config.activation_sparsity_pattern[layer_idx]
  841. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  842. gate_proj = self.gate_proj(hidden_states)
  843. if self.activation_sparsity > 0.0:
  844. gate_proj = self._gaussian_topk(gate_proj)
  845. activations = self.act_fn(gate_proj)
  846. up_proj = self.up_proj(hidden_states)
  847. down_proj = self.down_proj(activations * up_proj)
  848. return down_proj
  849. def _gaussian_topk(self, inputs: torch.Tensor) -> torch.Tensor:
  850. target_sparsity_tensor = torch.tensor(self.activation_sparsity, dtype=torch.float32, device=inputs.device)
  851. # normal_dist and std_multiplier are adapted from jax.scipy.stats.norm.ppf().
  852. #
  853. # References:
  854. # * https://docs.jax.dev/en/latest/_autosummary/jax.scipy.stats.norm.ppf.html
  855. # * https://pytorch.org/docs/stable/distributions.html#torch.distributions.normal.Normal
  856. # * https://pytorch.org/docs/stable/distributions.html#torch.distributions.transformed_distribution.TransformedDistribution.icdf
  857. normal_dist = torch.distributions.normal.Normal(0, 1)
  858. std_multiplier: torch.Tensor = normal_dist.icdf(target_sparsity_tensor)
  859. std_multiplier = std_multiplier.type(inputs.dtype)
  860. inputs_mean = torch.mean(inputs, dim=-1, keepdim=True)
  861. inputs_std = torch.std(inputs, dim=-1, keepdim=True, unbiased=False)
  862. cutoff_x = inputs_mean + inputs_std * std_multiplier
  863. return nn.functional.relu(inputs - cutoff_x)
  864. class Gemma3nTextAltUp(nn.Module):
  865. """Alternating Updates (AltUp)
  866. The AltUp module wraps transformer layers. The `predict` step modifies the
  867. input to the transformer layer, and the `correct` step propagates the output
  868. of the transformer layer to the sparsely updated dimensions.
  869. See more in the research paper:
  870. https://proceedings.neurips.cc/paper_files/paper/2023/file/f2059277ac6ce66e7e5543001afa8bb5-Paper-Conference.pdf
  871. """
  872. def __init__(self, config: Gemma3nTextConfig):
  873. super().__init__()
  874. self.config = config
  875. self.correct_output_scale = nn.Parameter(torch.zeros(self.config.hidden_size))
  876. self.correction_coefs = nn.Linear(self.config.altup_num_inputs, self.config.altup_num_inputs, bias=False)
  877. self.prediction_coefs = nn.Linear(self.config.altup_num_inputs, self.config.altup_num_inputs**2, bias=False)
  878. self.modality_router = nn.Linear(self.config.hidden_size, self.config.altup_num_inputs, bias=False)
  879. self.router_norm = Gemma3nRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
  880. self.register_buffer("router_input_scale", torch.tensor(self.config.hidden_size**-1.0), persistent=False)
  881. def compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor:
  882. router_inputs = self.router_norm(x) * self.router_input_scale
  883. routed = self.modality_router(router_inputs)
  884. return torch.tanh(routed.float()).type_as(x)
  885. def predict(self, hidden_states: torch.Tensor) -> torch.Tensor:
  886. """Predicts the output of a layer using a trainable map.
  887. Args:
  888. hidden_states: A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` derived by
  889. stacking the input embeddings and preprocessing the last `num_altup_inputs - 1` matrices.
  890. Returns:
  891. A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` containing the predictions.
  892. """
  893. modalities = self.compute_router_modalities(hidden_states[self.config.altup_active_idx])
  894. if self.training and self.config.altup_coef_clip is not None:
  895. self.prediction_coefs.weight.data.clamp_(-self.config.altup_coef_clip, self.config.altup_coef_clip)
  896. # Project and then transpose all 2D matrices contained so that mulmat gives the correct result
  897. all_coefs: torch.Tensor = (
  898. self.prediction_coefs(modalities)
  899. .reshape(*modalities.shape[:-1], self.config.altup_num_inputs, self.config.altup_num_inputs)
  900. .permute(0, 1, 3, 2)
  901. )
  902. # permute hidden_states to [batch_size, num_tokens, hidden_size, altup_num_inputs]
  903. predictions = torch.matmul(hidden_states.permute(1, 2, 3, 0), all_coefs)
  904. predictions = predictions.permute(3, 0, 1, 2) # undo the permute
  905. predictions += hidden_states # add the original input
  906. return predictions.contiguous().type_as(hidden_states)
  907. def correct(self, predictions: torch.Tensor, activated: torch.Tensor) -> torch.Tensor:
  908. """Corrects the predictions relative to the
  909. Args:
  910. predictions: A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` derived by
  911. stacking the input embeddings and preprocessing the last `num_altup_inputs - 1` matrices.
  912. activated: A 3D tensor of shape `[batch_size, num_tokens, hidden_size]` containing the activated inputs.
  913. Returns:
  914. A 4D tensor of shape `[num_altup_inputs, batch_size, num_tokens, hidden_size]` correcting the original
  915. predictions relative to the activated input embeddings.
  916. """
  917. modalities = self.compute_router_modalities(activated)
  918. innovation = activated - predictions[self.config.altup_active_idx] # (batch, num_tokens, hidden_size)
  919. innovation = innovation.repeat(self.config.altup_num_inputs, 1, 1, 1) # Repeat on dim0 to match predictions
  920. if self.config.altup_coef_clip is not None:
  921. self.correction_coefs.weight.data.clamp_(-self.config.altup_coef_clip, self.config.altup_coef_clip)
  922. # all_coefs adapted from jax.numpy.einsum("...p,pi->...i", ...)
  923. # Permute to (altup_num_inputs, batch_size, num_tokens) as the last dim is a scalar applied to each altup input
  924. # and expand on dim1 for broadcastability
  925. all_coefs: torch.Tensor = self.correction_coefs(modalities) + 1.0
  926. all_coefs = all_coefs.permute(2, 0, 1).unsqueeze(-1)
  927. corrected = torch.mul(innovation, all_coefs)
  928. corrected += predictions # add the original input
  929. return corrected.contiguous().type_as(activated)
  930. def forward(self, corrected: torch.Tensor) -> torch.Tensor:
  931. """
  932. This is only defined as the `forward` so that accelerate hooks can move correctly `correct_output_scale`
  933. (which is a nn.Parameter, not a Module) between devices when offloading. It is otherwise only used in
  934. `scale_corrected_output`
  935. """
  936. return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
  937. def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
  938. """Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size]."""
  939. return self.forward(corrected)
  940. class Gemma3nTextRotaryEmbedding(nn.Module):
  941. inv_freq: torch.Tensor # fix linting for `register_buffer`
  942. def __init__(self, config: Gemma3nTextConfig, device=None):
  943. super().__init__()
  944. # BC: "rope_type" was originally "type"
  945. if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
  946. self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
  947. else:
  948. self.rope_type = "default"
  949. self.max_seq_len_cached = config.max_position_embeddings
  950. self.original_max_seq_len = config.max_position_embeddings
  951. self.config = config
  952. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  953. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
  954. self.register_buffer("inv_freq", inv_freq, persistent=False)
  955. self.original_inv_freq = self.inv_freq
  956. @torch.no_grad()
  957. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  958. def forward(self, x, position_ids):
  959. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  960. position_ids_expanded = position_ids[:, None, :].float()
  961. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  962. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  963. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  964. emb = torch.cat((freqs, freqs), dim=-1)
  965. cos = emb.cos() * self.attention_scaling
  966. sin = emb.sin() * self.attention_scaling
  967. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  968. def rotate_half(x):
  969. """Rotates half the hidden dims of the input."""
  970. x1 = x[..., : x.shape[-1] // 2]
  971. x2 = x[..., x.shape[-1] // 2 :]
  972. return torch.cat((-x2, x1), dim=-1)
  973. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  974. """
  975. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  976. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  977. """
  978. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  979. if n_rep == 1:
  980. return hidden_states
  981. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  982. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  983. def eager_attention_forward(
  984. module: nn.Module,
  985. query: torch.Tensor,
  986. key: torch.Tensor,
  987. value: torch.Tensor,
  988. attention_mask: Optional[torch.Tensor],
  989. dropout: float = 0.0,
  990. scaling: Optional[float] = None,
  991. softcap: Optional[float] = None,
  992. **kwargs,
  993. ) -> tuple[torch.Tensor, torch.Tensor]:
  994. if scaling is None:
  995. scaling = module.head_dim**-0.5
  996. key_states = repeat_kv(key, module.num_key_value_groups)
  997. value_states = repeat_kv(value, module.num_key_value_groups)
  998. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  999. if softcap is not None:
  1000. attn_weights = attn_weights / softcap
  1001. attn_weights = torch.tanh(attn_weights)
  1002. attn_weights = attn_weights * softcap
  1003. if attention_mask is not None: # no matter the length, we just slice it
  1004. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  1005. attn_weights = attn_weights + causal_mask
  1006. # upcast attention to fp32
  1007. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  1008. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  1009. attn_output = torch.matmul(attn_weights, value_states)
  1010. attn_output = attn_output.transpose(1, 2).contiguous()
  1011. return attn_output, attn_weights
  1012. def apply_rotary_pos_emb(
  1013. x: torch.Tensor,
  1014. cos: torch.Tensor,
  1015. sin: torch.Tensor,
  1016. position_ids: Optional[torch.Tensor] = None,
  1017. unsqueeze_dim: int = 1,
  1018. ):
  1019. """Applies Rotary Position Embedding to the query and key tensors.
  1020. Args:
  1021. x (`torch.Tensor`): The tensor to embed.
  1022. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  1023. sin (`torch.Tensor`): The sine part of the rotary embedding.
  1024. position_ids (`torch.Tensor`, *optional*):
  1025. Deprecated and unused.
  1026. unsqueeze_dim (`int`, *optional*, defaults to 1):
  1027. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  1028. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  1029. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  1030. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  1031. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  1032. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  1033. Returns:
  1034. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  1035. """
  1036. cos = cos.unsqueeze(unsqueeze_dim)
  1037. sin = sin.unsqueeze(unsqueeze_dim)
  1038. return (x * cos) + (rotate_half(x) * sin)
  1039. class Gemma3nTextAttention(nn.Module):
  1040. """Multi-headed attention from 'Attention Is All You Need' paper"""
  1041. def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
  1042. super().__init__()
  1043. self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
  1044. self.config = config
  1045. self.layer_idx = layer_idx
  1046. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  1047. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  1048. self.attention_dropout = self.config.attention_dropout
  1049. self.is_causal = True
  1050. self.q_proj = nn.Linear(
  1051. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  1052. )
  1053. self.k_proj = nn.Linear(
  1054. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  1055. )
  1056. self.v_proj = nn.Linear(
  1057. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  1058. )
  1059. self.o_proj = nn.Linear(
  1060. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  1061. )
  1062. self.sliding_window = config.sliding_window if self.is_sliding else None
  1063. self.q_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
  1064. self.k_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
  1065. self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, with_scale=False)
  1066. first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers
  1067. self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
  1068. prev_layers = config.layer_types[:first_kv_shared_layer_idx]
  1069. if self.is_kv_shared_layer:
  1070. # For shared layers, find the last non-shared layer of the same type before sharing starts
  1071. self.kv_shared_layer_index = len(prev_layers) - 1 - prev_layers[::-1].index(config.layer_types[layer_idx])
  1072. self.store_full_length_kv = False
  1073. else:
  1074. self.kv_shared_layer_index = None
  1075. # For non-shared layers, store full-length kv if this is the last non-shared layer of its type
  1076. self.store_full_length_kv = layer_idx == len(prev_layers) - 1 - prev_layers[::-1].index(
  1077. config.layer_types[layer_idx]
  1078. )
  1079. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  1080. def forward(
  1081. self,
  1082. hidden_states: torch.Tensor,
  1083. position_embeddings: torch.Tensor,
  1084. attention_mask: Optional[torch.Tensor],
  1085. past_key_values: Optional[Cache] = None,
  1086. cache_position: Optional[torch.LongTensor] = None,
  1087. **kwargs: Unpack[FlashAttentionKwargs],
  1088. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  1089. input_shape = hidden_states.shape[:-1]
  1090. hidden_shape = (*input_shape, -1, self.config.head_dim)
  1091. cos, sin = position_embeddings
  1092. query_states = self.q_proj(hidden_states).view(hidden_shape)
  1093. query_states = self.q_norm(query_states)
  1094. query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
  1095. query_states = query_states.transpose(1, 2)
  1096. # For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer
  1097. if self.is_kv_shared_layer and past_key_values is not None:
  1098. key_states, value_states = past_key_values.shared_layers[self.kv_shared_layer_index]
  1099. # Device of past layer may be different from current one
  1100. key_states = key_states.to(query_states.device)
  1101. value_states = value_states.to(query_states.device)
  1102. else:
  1103. key_states = self.k_proj(hidden_states).view(hidden_shape)
  1104. key_states = self.k_norm(key_states)
  1105. key_states = apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2)
  1106. key_states = key_states.transpose(1, 2)
  1107. value_states = self.v_proj(hidden_states).view(hidden_shape)
  1108. value_states = self.v_norm(value_states)
  1109. value_states = value_states.transpose(1, 2)
  1110. if past_key_values is not None:
  1111. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  1112. cache_kwargs = {
  1113. "sin": sin,
  1114. "cos": cos,
  1115. "cache_position": cache_position,
  1116. "sliding_window": self.sliding_window,
  1117. }
  1118. if not self.is_kv_shared_layer:
  1119. key_states, value_states = past_key_values.update(
  1120. key_states, value_states, self.layer_idx, cache_kwargs
  1121. )
  1122. if self.store_full_length_kv:
  1123. if not hasattr(past_key_values, "shared_layers"):
  1124. past_key_values.shared_layers = {}
  1125. past_key_values.shared_layers[self.layer_idx] = key_states, value_states
  1126. attention_interface: Callable = eager_attention_forward
  1127. if self.config._attn_implementation != "eager":
  1128. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  1129. attn_output, attn_weights = attention_interface(
  1130. self,
  1131. query_states,
  1132. key_states,
  1133. value_states,
  1134. attention_mask,
  1135. dropout=self.attention_dropout if self.training else 0.0,
  1136. scaling=1.0,
  1137. sliding_window=self.sliding_window,
  1138. **kwargs,
  1139. )
  1140. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  1141. attn_output = self.o_proj(attn_output)
  1142. return attn_output, attn_weights
class Gemma3nTextDecoderLayer(GradientCheckpointingLayer):
    """One Gemma 3n text decoder layer.

    Flow: AltUp `predict` over the stacked hidden-state streams -> attention plus a laurel-block branch on the
    active prediction -> feed-forward MLP -> AltUp `correct` -> gated mixing of this layer's per-layer input into
    the corrected streams.
    """

    def __init__(self, config: Gemma3nTextConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx
        # Used by the text model to pick the matching mask ("full_attention" / "sliding_attention").
        self.attention_type = config.layer_types[layer_idx]
        self.self_attn = Gemma3nTextAttention(config, layer_idx)
        self.mlp = Gemma3nTextMLP(config, layer_idx=layer_idx)
        self.input_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.pre_feedforward_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)

        self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
        self.act_fn = ACT2FN[config.hidden_activation]

        self.altup = Gemma3nTextAltUp(config)
        self.laurel = Gemma3nTextLaurelBlock(config)
        # Gate: hidden_size -> hidden_size_per_layer_input; projection: back to hidden_size.
        self.per_layer_input_gate = nn.Linear(self.hidden_size, self.hidden_size_per_layer_input, bias=False)
        self.per_layer_projection = nn.Linear(self.hidden_size_per_layer_input, self.hidden_size, bias=False)
        self.post_per_layer_input_norm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings_global: torch.Tensor,
        position_embeddings_local: torch.Tensor,
        per_layer_input: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, ...]:
        """Run the layer on the stacked AltUp streams.

        Args:
            hidden_states: Stacked streams, indexed along dim 0 by stream (the text model stacks them as
                `[num_altup_inputs, batch, seq_len, hidden_size]`).
            position_embeddings_global: `(cos, sin)` used by full-attention layers.
            position_embeddings_local: `(cos, sin)` used by sliding-window layers.
            per_layer_input: This layer's slice of the per-layer inputs (the text model passes
                `per_layer_inputs[:, :, layer_idx, :]`).
            attention_mask, position_ids, past_key_values, cache_position, **kwargs: Forwarded to `self_attn`.
            output_attentions: When True, the attention weights are appended to the output tuple.
            use_cache: Forwarded to `self_attn`.

        Returns:
            `(corrected_predictions,)`, or `(corrected_predictions, attn_weights)` when `output_attentions` is set.
        """
        predictions = self.altup.predict(hidden_states)
        active_prediction = predictions[self.config.altup_active_idx]

        active_prediction_normed = self.input_layernorm(active_prediction)
        laurel_output = self.laurel(active_prediction_normed)

        # apply global RoPE to non-sliding layer only
        if self.self_attn.is_sliding:
            position_embeddings = position_embeddings_local
        else:
            position_embeddings = position_embeddings_global

        attn, self_attn_weights = self.self_attn(
            hidden_states=active_prediction_normed,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        attn = self.post_attention_layernorm(attn)

        # Residual on the (un-normalized) active prediction, then average with the laurel branch
        # using a 1/sqrt(2) scale.
        attn_gated = active_prediction + attn
        attn_laurel = (attn_gated + laurel_output) / math.sqrt(2)

        attn_norm = self.pre_feedforward_layernorm(attn_laurel)
        attn_ffw = self.mlp(attn_norm)
        attn_ffw_norm = self.post_feedforward_layernorm(attn_ffw)
        attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm
        corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated)

        # clone(): `corrected_predictions` is updated in place below, so keep this value from aliasing it.
        first_prediction = corrected_predictions[self.config.altup_active_idx].clone()
        if self.config.altup_correct_scale:
            first_prediction = self.altup.scale_corrected_output(first_prediction)

        # per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...)
        first_prediction = self.per_layer_input_gate(first_prediction)
        first_prediction = self.act_fn(first_prediction)
        first_prediction = torch.multiply(first_prediction, per_layer_input)

        # per_layer_projection adapted from jax.numpy.einsum("btp,pd->btd", ...)
        first_prediction = self.per_layer_projection(first_prediction)
        first_prediction = self.post_per_layer_input_norm(first_prediction)
        # Adds the per-layer contribution to every stream except index 0.
        # NOTE(review): this is independent of `altup_active_idx`; presumably the active index is 0 — confirm.
        corrected_predictions[1:] += first_prediction

        outputs = (corrected_predictions,)

        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs
  1221. @auto_docstring
  1222. class Gemma3nPreTrainedModel(PreTrainedModel):
  1223. config: Gemma3nConfig
  1224. base_model_prefix = ""
  1225. supports_gradient_checkpointing = True
  1226. _no_split_modules = ["Gemma3nTextDecoderLayer"]
  1227. _skip_keys_device_placement = ["past_key_values"]
  1228. _supports_flash_attn = True
  1229. _supports_sdpa = True
  1230. _supports_flex_attn = True
  1231. _can_compile_fullgraph = True
  1232. _supports_attention_backend = True
  1233. _can_record_outputs = {
  1234. "hidden_states": Gemma3nTextDecoderLayer,
  1235. "attentions": Gemma3nTextAttention,
  1236. }
  1237. def _init_weights(self, module):
  1238. super()._init_weights(module)
  1239. if isinstance(module, Gemma3nAudioCumulativeGroupNorm):
  1240. module.weight.data.fill_(1.0)
  1241. elif isinstance(module, Gemma3nAudioAttention):
  1242. module.per_dim_scale.data.zero_()
  1243. elif isinstance(module, Gemma3nTextAltUp):
  1244. module.correct_output_scale.data.zero_()
  1245. @auto_docstring(custom_intro="The base Gemma 3n language model without a language modeling head.")
  1246. class Gemma3nTextModel(Gemma3nPreTrainedModel):
  1247. config: Gemma3nTextConfig
  1248. def __init__(self, config: Gemma3nTextConfig):
  1249. super().__init__(config)
  1250. self.padding_idx = config.pad_token_id
  1251. self.vocab_size = config.vocab_size
  1252. # Gemma3n downcasts the below to bfloat16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402
  1253. self.embed_tokens = Gemma3nTextScaledWordEmbedding(
  1254. config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5
  1255. )
  1256. self.layers = nn.ModuleList(
  1257. [Gemma3nTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  1258. )
  1259. self.norm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  1260. self.rotary_emb = Gemma3nTextRotaryEmbedding(config=config)
  1261. self.gradient_checkpointing = False
  1262. # TODO (raushan): Fix this after RoPE refactor. For now we hack it by
  1263. # reassigning thetas when we want to create a local RoPE layer. Config
  1264. # defaults should hold values for global RoPE.
  1265. config = copy.deepcopy(config)
  1266. config.rope_theta = config.rope_local_base_freq
  1267. config.rope_scaling = {"rope_type": "default"}
  1268. self.rotary_emb_local = Gemma3nTextRotaryEmbedding(config=config)
  1269. self.hidden_size = config.hidden_size
  1270. self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
  1271. self.embed_tokens_per_layer = Gemma3nTextScaledWordEmbedding(
  1272. config.vocab_size_per_layer_input,
  1273. config.num_hidden_layers * config.hidden_size_per_layer_input,
  1274. self.padding_idx,
  1275. embed_scale=config.hidden_size_per_layer_input**0.5,
  1276. )
  1277. self.per_layer_model_projection = nn.Linear(
  1278. self.hidden_size,
  1279. config.num_hidden_layers * config.hidden_size_per_layer_input,
  1280. bias=False,
  1281. )
  1282. self.per_layer_projection_norm = Gemma3nRMSNorm(config.hidden_size_per_layer_input, eps=config.rms_norm_eps)
  1283. self.altup_projections = nn.ModuleList(
  1284. [nn.Linear(self.hidden_size, self.hidden_size, bias=False) for _ in range(1, self.config.altup_num_inputs)]
  1285. )
  1286. self.altup_unembed_projections = nn.ModuleList(
  1287. [nn.Linear(self.hidden_size, self.hidden_size, bias=False) for _ in range(1, self.config.altup_num_inputs)]
  1288. )
  1289. self.register_buffer("per_layer_projection_scale", torch.tensor(self.hidden_size**-0.5), persistent=False)
  1290. self.register_buffer("per_layer_input_scale", torch.rsqrt(torch.tensor(2.0)), persistent=False)
  1291. # Initialize weights and apply final processing
  1292. self.post_init()
  1293. @can_return_tuple
  1294. @auto_docstring
  1295. def forward(
  1296. self,
  1297. input_ids: Optional[torch.LongTensor] = None,
  1298. per_layer_inputs: Optional[torch.Tensor] = None,
  1299. attention_mask: Optional[torch.Tensor] = None,
  1300. position_ids: Optional[torch.LongTensor] = None,
  1301. past_key_values: Optional[Cache] = None,
  1302. inputs_embeds: Optional[torch.FloatTensor] = None,
  1303. use_cache: Optional[bool] = None,
  1304. output_attentions: Optional[bool] = None,
  1305. output_hidden_states: Optional[bool] = None,
  1306. cache_position: Optional[torch.LongTensor] = None,
  1307. **kwargs: Unpack[TransformersKwargs],
  1308. ) -> BaseModelOutputWithPast:
  1309. r"""
  1310. per_layer_inputs (torch.Tensor, *optional*, defaults to None):
  1311. Pre-computed per-layer embeddings. If None, they are derived from input_ids if provided.
  1312. """
  1313. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1314. output_hidden_states = (
  1315. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1316. )
  1317. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1318. if (input_ids is None) ^ (inputs_embeds is not None):
  1319. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  1320. if self.gradient_checkpointing and self.training and use_cache:
  1321. logger.warning_once(
  1322. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
  1323. )
  1324. use_cache = False
  1325. if input_ids is not None:
  1326. inputs_embeds = self.embed_tokens(input_ids)
  1327. per_layer_inputs = self.get_per_layer_inputs(input_ids)
  1328. per_layer_inputs = self.project_per_layer_inputs(inputs_embeds, per_layer_inputs)
  1329. if use_cache and past_key_values is None and not self.training:
  1330. past_key_values = DynamicCache(config=self.config)
  1331. if cache_position is None:
  1332. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  1333. cache_position = torch.arange(
  1334. past_seen_tokens,
  1335. past_seen_tokens + inputs_embeds.shape[1],
  1336. device=inputs_embeds.device,
  1337. )
  1338. if position_ids is None:
  1339. position_ids = cache_position.unsqueeze(0)
  1340. # It may already have been prepared by e.g. `generate`
  1341. if not isinstance(causal_mask_mapping := attention_mask, dict):
  1342. # Prepare mask arguments
  1343. mask_kwargs = {
  1344. "config": self.config,
  1345. "input_embeds": inputs_embeds,
  1346. "attention_mask": attention_mask,
  1347. "cache_position": cache_position,
  1348. "past_key_values": past_key_values,
  1349. "position_ids": position_ids,
  1350. }
  1351. # Create the masks
  1352. causal_mask_mapping = {
  1353. "full_attention": create_causal_mask(**mask_kwargs),
  1354. "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
  1355. }
  1356. # embed positions
  1357. hidden_states_0 = inputs_embeds
  1358. # Initialize RoPE embeddings
  1359. position_embeddings_global = self.rotary_emb(hidden_states_0, position_ids)
  1360. position_embeddings_local = self.rotary_emb_local(hidden_states_0, position_ids)
  1361. # Expand hidden_states to support per-layer inputs
  1362. target_magnitude = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
  1363. epsilon_tensor = torch.tensor(1e-5)
  1364. temp_hidden_states = [hidden_states_0]
  1365. for i in range(1, self.config.altup_num_inputs):
  1366. # altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...)
  1367. altup_proj = self.altup_projections[i - 1](hidden_states_0)
  1368. current_hidden_state = altup_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
  1369. new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
  1370. new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
  1371. current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
  1372. temp_hidden_states.append(current_hidden_state)
  1373. hidden_states = torch.stack(temp_hidden_states, dim=0) # [num_altup_inputs, batch, seq_len, hidden_size]
  1374. # decoder layers
  1375. all_hidden_states = () if output_hidden_states else None
  1376. all_self_attns = () if output_attentions else None
  1377. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  1378. if output_hidden_states:
  1379. all_hidden_states += (hidden_states,)
  1380. causal_mask = causal_mask_mapping[decoder_layer.attention_type]
  1381. per_layer_input = per_layer_inputs[:, :, decoder_layer.layer_idx, :]
  1382. layer_outputs = decoder_layer(
  1383. hidden_states,
  1384. position_embeddings_global,
  1385. position_embeddings_local,
  1386. per_layer_input,
  1387. attention_mask=causal_mask,
  1388. position_ids=position_ids,
  1389. past_key_values=past_key_values,
  1390. output_attentions=output_attentions,
  1391. use_cache=use_cache,
  1392. cache_position=cache_position,
  1393. **kwargs,
  1394. )
  1395. hidden_states = layer_outputs[0]
  1396. if output_attentions:
  1397. all_self_attns += (layer_outputs[1],)
  1398. # add hidden states from the last decoder layer (but before reprojecting to stay consistent with layer output)
  1399. if output_hidden_states:
  1400. all_hidden_states += (hidden_states,)
  1401. # Per-layer inputs to single output
  1402. target_magnitude = torch.mean(hidden_states[0] ** 2, dim=-1, keepdim=True) ** 0.5
  1403. temp_hidden_states = [hidden_states[0]]
  1404. for i in range(1, self.config.altup_num_inputs):
  1405. # altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...)
  1406. altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i])
  1407. current_hidden_state = altup_unemb_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
  1408. new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
  1409. new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
  1410. current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
  1411. temp_hidden_states.append(current_hidden_state)
  1412. hidden_states = torch.stack(temp_hidden_states)
  1413. hidden_states = torch.mean(hidden_states, dim=0)
  1414. hidden_states = self.norm(hidden_states)
  1415. return BaseModelOutputWithPast(
  1416. last_hidden_state=hidden_states,
  1417. past_key_values=past_key_values,
  1418. hidden_states=all_hidden_states,
  1419. attentions=all_self_attns,
  1420. )
  1421. def get_per_layer_inputs(self, input_ids: torch.LongTensor) -> torch.Tensor:
  1422. return self.embed_tokens_per_layer(input_ids).reshape(
  1423. *input_ids.shape,
  1424. self.config.num_hidden_layers,
  1425. self.hidden_size_per_layer_input,
  1426. )
  1427. def project_per_layer_inputs(
  1428. self,
  1429. inputs_embeds: torch.Tensor,
  1430. per_layer_inputs: Optional[torch.Tensor] = None,
  1431. ) -> torch.Tensor:
  1432. per_layer_projection: torch.Tensor = self.per_layer_model_projection(inputs_embeds)
  1433. per_layer_projection *= self.per_layer_projection_scale.to(
  1434. dtype=inputs_embeds.dtype, device=per_layer_projection.device
  1435. )
  1436. per_layer_projection = per_layer_projection.reshape(
  1437. *inputs_embeds.shape[:-1],
  1438. self.config.num_hidden_layers,
  1439. self.hidden_size_per_layer_input,
  1440. )
  1441. per_layer_projection = self.per_layer_projection_norm(per_layer_projection)
  1442. if per_layer_inputs is None:
  1443. return per_layer_projection
  1444. if per_layer_projection.shape != per_layer_inputs.shape:
  1445. # per-layer inputs are sometimes padded with zeros, slice the relevant embeddings.
  1446. per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :]
  1447. return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.to(
  1448. dtype=inputs_embeds.dtype, device=per_layer_projection.device
  1449. )
  1450. @auto_docstring(custom_intro="The base Gemma 3n language model with a language modeling head.")
  1451. class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):
  1452. _tied_weights_keys = ["lm_head.weight"]
  1453. _tp_plan = {"lm_head": "colwise_rep"}
  1454. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  1455. config: Gemma3nTextConfig
  1456. base_model_prefix = "model"
  1457. _checkpoint_conversion_mapping = {"model.language_model": "model"}
  1458. def __init__(self, config: Gemma3nTextConfig):
  1459. super().__init__(config)
  1460. self.model = Gemma3nTextModel(config)
  1461. self.vocab_size = config.vocab_size
  1462. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  1463. # Initialize weights and apply final processing
  1464. self.post_init()
  1465. @can_return_tuple
  1466. @auto_docstring
  1467. def forward(
  1468. self,
  1469. input_ids: Optional[torch.LongTensor] = None,
  1470. attention_mask: Optional[torch.Tensor] = None,
  1471. position_ids: Optional[torch.LongTensor] = None,
  1472. past_key_values: Optional[Cache] = None,
  1473. inputs_embeds: Optional[torch.FloatTensor] = None,
  1474. labels: Optional[torch.LongTensor] = None,
  1475. use_cache: Optional[bool] = None,
  1476. output_attentions: Optional[bool] = None,
  1477. output_hidden_states: Optional[bool] = None,
  1478. cache_position: Optional[torch.LongTensor] = None,
  1479. logits_to_keep: Union[int, torch.Tensor] = 0,
  1480. **kwargs,
  1481. ) -> CausalLMOutputWithPast:
  1482. r"""
  1483. Example:
  1484. ```python
  1485. >>> from transformers import AutoTokenizer, Gemma3nForCausalLM
  1486. >>> model = Gemma3nForCausalLM.from_pretrained("google/gemma-2-9b")
  1487. >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
  1488. >>> prompt = "What is your favorite condiment?"
  1489. >>> inputs = tokenizer(prompt, return_tensors="pt")
  1490. >>> # Generate
  1491. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  1492. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  1493. "What is your favorite condiment?"
  1494. ```"""
  1495. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1496. output_hidden_states = (
  1497. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1498. )
  1499. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  1500. outputs: BaseModelOutputWithPast = self.model(
  1501. input_ids=input_ids,
  1502. attention_mask=attention_mask,
  1503. position_ids=position_ids,
  1504. past_key_values=past_key_values,
  1505. inputs_embeds=inputs_embeds,
  1506. use_cache=use_cache,
  1507. output_attentions=output_attentions,
  1508. output_hidden_states=output_hidden_states,
  1509. cache_position=cache_position,
  1510. **kwargs,
  1511. )
  1512. hidden_states = outputs.last_hidden_state
  1513. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  1514. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  1515. logits = self.lm_head(hidden_states[:, slice_indices, :])
  1516. if self.config.final_logit_softcapping is not None:
  1517. logits = logits / self.config.final_logit_softcapping
  1518. logits = torch.tanh(logits)
  1519. logits = logits * self.config.final_logit_softcapping
  1520. loss = None
  1521. if labels is not None:
  1522. loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
  1523. return CausalLMOutputWithPast(
  1524. loss=loss,
  1525. logits=logits,
  1526. past_key_values=outputs.past_key_values,
  1527. hidden_states=outputs.hidden_states,
  1528. attentions=outputs.attentions,
  1529. )
  1530. class Gemma3nMultimodalEmbedder(nn.Module):
  1531. """Embeds token ids or soft tokens for multimodal content into language model space."""
  1532. def __init__(
  1533. self,
  1534. multimodal_config: Union[Gemma3nAudioConfig, Gemma3nVisionConfig],
  1535. text_config: Gemma3nTextConfig,
  1536. ):
  1537. super().__init__()
  1538. self.multimodal_hidden_size = multimodal_config.hidden_size
  1539. self.eps = multimodal_config.rms_norm_eps
  1540. self.vocab_offset = multimodal_config.vocab_offset
  1541. self.vocab_size = multimodal_config.vocab_size
  1542. self.text_hidden_size = text_config.hidden_size
  1543. self.embedding = nn.Embedding(self.vocab_size, self.multimodal_hidden_size)
  1544. self.hard_embedding_norm = Gemma3nRMSNorm(self.multimodal_hidden_size, eps=self.eps)
  1545. self.soft_embedding_norm = Gemma3nRMSNorm(self.multimodal_hidden_size, eps=self.eps)
  1546. self.embedding_projection = nn.Linear(self.multimodal_hidden_size, self.text_hidden_size, bias=False)
  1547. self.embedding_post_projection_norm = Gemma3nRMSNorm(self.text_hidden_size, eps=self.eps, with_scale=False)
  1548. def forward(
  1549. self,
  1550. input_ids: Optional[torch.LongTensor] = None,
  1551. inputs_embeds: Optional[torch.Tensor] = None,
  1552. ) -> torch.Tensor:
  1553. """Embeds token ids or soft tokens for multimodal content into language model space.
  1554. Args:
  1555. input_ids: A torch.LongTensor containing the token ids to embed. Values should be in the range
  1556. `[vocab_offset, vocab_offset + vocab_size)`.
  1557. inputs_embeds: A torch.Tensor containing the soft tokens to embed.
  1558. Returns:
  1559. A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`.
  1560. """
  1561. if (input_ids is None) ^ (inputs_embeds is not None):
  1562. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  1563. if inputs_embeds is not None:
  1564. emb_norm = self.soft_embedding_norm(inputs_embeds)
  1565. else:
  1566. hard_emb = self.embedding(input_ids - self.vocab_offset)
  1567. emb_norm = self.hard_embedding_norm(hard_emb)
  1568. emb_norm_proj = self.embedding_projection(emb_norm)
  1569. return self.embedding_post_projection_norm(emb_norm_proj)
  1570. @auto_docstring(
  1571. custom_intro="""
  1572. The base Gemma 3n model comprising a vision backbone, an audio backbone, and a language model without a
  1573. language modeling head.
  1574. """
  1575. )
  1576. class Gemma3nModel(Gemma3nPreTrainedModel):
  1577. _checkpoint_conversion_mapping = {}
  1578. # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
  1579. accepts_loss_kwargs = False
  1580. def __init__(self, config: Gemma3nConfig):
  1581. super().__init__(config)
  1582. self.vision_tower = AutoModel.from_config(config=config.vision_config)
  1583. self.vocab_size = config.text_config.vocab_size
  1584. language_model = AutoModel.from_config(config=config.text_config)
  1585. self.language_model = language_model
  1586. self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
  1587. self.vocab_size_per_layer_input = config.text_config.vocab_size_per_layer_input
  1588. self.audio_tower = AutoModel.from_config(config.audio_config)
  1589. self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config, config.text_config)
  1590. self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config, config.text_config)
  1591. self.post_init()
  1592. def get_input_embeddings(self):
  1593. return self.language_model.get_input_embeddings()
  1594. def set_input_embeddings(self, value):
  1595. self.language_model.set_input_embeddings(value)
  1596. def set_decoder(self, decoder):
  1597. self.language_model = decoder
  1598. def get_decoder(self):
  1599. return self.language_model
  1600. def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
  1601. """
  1602. Projects the last hidden state from the vision model into language model space.
  1603. Args:
  1604. pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
  1605. The tensors corresponding to the input images.
  1606. Returns:
  1607. image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
  1608. """
  1609. vision_outputs = self.vision_tower(
  1610. pixel_values=pixel_values, do_pooling=False, return_dict=True
  1611. ).last_hidden_state
  1612. # Convert from (batch, channels, height, width) to (batch, height * width, channels) where:
  1613. # height == width and height * width == Gemma3nConfig.vision_soft_tokens_per_image.
  1614. vision_outputs = vision_outputs.reshape(
  1615. vision_outputs.shape[0],
  1616. self.config.vision_config.hidden_size,
  1617. self.config.vision_soft_tokens_per_image,
  1618. ).permute(0, 2, 1)
  1619. # Normalize and embed the soft tokens into language model space.
  1620. vision_outputs *= self.config.vision_config.hidden_size**0.5
  1621. return self.embed_vision(inputs_embeds=vision_outputs)
  1622. def get_placeholder_mask(
  1623. self,
  1624. input_ids: Optional[torch.LongTensor] = None,
  1625. inputs_embeds: Optional[torch.FloatTensor] = None,
  1626. image_features: Optional[torch.FloatTensor] = None,
  1627. audio_features: Optional[torch.FloatTensor] = None,
  1628. ):
  1629. """
  1630. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
  1631. equal to the length of multimodal features. If the lengths are different, an error is raised.
  1632. """
  1633. if input_ids is None:
  1634. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  1635. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  1636. )
  1637. special_image_mask = special_image_mask.all(-1)
  1638. special_audio_mask = (
  1639. inputs_embeds
  1640. == self.get_input_embeddings()(
  1641. torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
  1642. )
  1643. ).all(-1)
  1644. else:
  1645. special_image_mask = input_ids == self.config.image_token_id
  1646. special_audio_mask = input_ids == self.config.audio_token_id
  1647. n_image_tokens = special_image_mask.sum()
  1648. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  1649. if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel():
  1650. raise ValueError(
  1651. f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0] * image_features.shape[1]}"
  1652. )
  1653. n_audio_tokens = special_audio_mask.sum()
  1654. special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  1655. if audio_features is not None and inputs_embeds[special_audio_mask].numel() != audio_features.numel():
  1656. raise ValueError(
  1657. f"Audio features and image tokens do not match: tokens: {n_audio_tokens}, features {audio_features.shape[0] * audio_features.shape[1]}"
  1658. )
  1659. return special_image_mask, special_audio_mask
  1660. @can_return_tuple
  1661. def forward(
  1662. self,
  1663. input_ids: Optional[torch.LongTensor] = None, # text inputs
  1664. pixel_values: Optional[torch.FloatTensor] = None, # vision inputs
  1665. input_features: Optional[torch.FloatTensor] = None, # audio inputs
  1666. attention_mask: Optional[torch.Tensor] = None,
  1667. input_features_mask: Optional[torch.Tensor] = None,
  1668. position_ids: Optional[torch.LongTensor] = None,
  1669. past_key_values: Optional[Cache] = None,
  1670. token_type_ids: Optional[torch.LongTensor] = None,
  1671. cache_position: Optional[torch.LongTensor] = None,
  1672. inputs_embeds: Optional[torch.FloatTensor] = None,
  1673. labels: Optional[torch.LongTensor] = None,
  1674. use_cache: Optional[bool] = None,
  1675. output_attentions: Optional[bool] = None,
  1676. output_hidden_states: Optional[bool] = None,
  1677. **lm_kwargs,
  1678. ) -> Gemma3nCausalLMOutputWithPast:
  1679. r"""
  1680. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1681. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1682. config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  1683. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
  1684. Example:
  1685. ```python
  1686. >>> from PIL import Image
  1687. >>> import requests
  1688. >>> from transformers import AutoProcessor, Gemma3nForConditionalGeneration
  1689. >>> model = Gemma3nForConditionalGeneration.from_pretrained("google/gemma3n2-3b-mix-224")
  1690. >>> processor = AutoProcessor.from_pretrained("google/gemma3n2-3b-mix-224")
  1691. >>> prompt = "Where is the cat standing?"
  1692. >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
  1693. >>> image = Image.open(requests.get(url, stream=True).raw)
  1694. >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
  1695. >>> # Generate
  1696. >>> generate_ids = model.generate(**inputs,)
  1697. >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  1698. "Where is the cat standing?\nsnow"
  1699. ```
  1700. """
  1701. if (input_ids is None) ^ (inputs_embeds is not None):
  1702. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  1703. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1704. output_hidden_states = (
  1705. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1706. )
  1707. if input_ids is not None:
  1708. inputs_embeds = self.get_input_embeddings()(input_ids)
  1709. # Prepare per-layer inputs from inputs_ids
  1710. per_layer_inputs_mask = torch.logical_and(input_ids >= 0, input_ids < self.vocab_size_per_layer_input)
  1711. per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids))
  1712. per_layer_inputs = self.language_model.get_per_layer_inputs(per_layer_inputs_tokens)
  1713. # Handle vision tokens (>= embed_vision.vocab_offset and < embed_audio.vocab_offset)
  1714. vision_mask = torch.logical_and(
  1715. input_ids >= self.embed_vision.vocab_offset, input_ids < self.embed_audio.vocab_offset
  1716. )
  1717. dummy_vision_token_id = self.embed_vision.vocab_offset + self.embed_vision.vocab_size - 1
  1718. vision_input_ids = torch.where(vision_mask, input_ids, dummy_vision_token_id).to(inputs_embeds.device)
  1719. vision_embeds = self.embed_vision(input_ids=vision_input_ids)
  1720. expanded_vision_mask = vision_mask.unsqueeze(-1).expand_as(inputs_embeds)
  1721. inputs_embeds = torch.where(expanded_vision_mask, vision_embeds, inputs_embeds)
  1722. # Handle audio tokens (>= embed_audio.vocab_offset)
  1723. audio_mask = input_ids >= self.embed_audio.vocab_offset
  1724. dummy_audio_token_id = self.embed_audio.vocab_offset + self.embed_audio.vocab_size - 1
  1725. audio_input_ids = torch.where(audio_mask, input_ids, dummy_audio_token_id).to(inputs_embeds.device)
  1726. audio_embeds = self.embed_audio(input_ids=audio_input_ids)
  1727. expanded_audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds)
  1728. inputs_embeds = torch.where(expanded_audio_mask, audio_embeds, inputs_embeds)
  1729. else:
  1730. per_layer_inputs = None
  1731. # Merge text and images
  1732. if pixel_values is not None:
  1733. image_features = self.get_image_features(pixel_values)
  1734. image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
  1735. special_image_mask, _ = self.get_placeholder_mask(
  1736. input_ids, inputs_embeds=inputs_embeds, image_features=image_features
  1737. )
  1738. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
  1739. # Merge text and audio
  1740. if input_features is not None and input_features_mask is not None:
  1741. audio_features, audio_mask = self.get_audio_features(input_features, ~input_features_mask)
  1742. # The Gemma3nProcessor expects all audio will be 30s in length and inserts 188 audio soft tokens into the
  1743. # text to account for this. However, the audio preprocessing and encoder do not gurarantee they will
  1744. # produce 188 soft tokens; they will produce at most that many tokens, but they may produce fewer tokens
  1745. # depending on the length of the longest audio input in the batch. When we encounter this situation, we pad
  1746. # the audio feature out to 188 soft tokens with the emebedding of the last token in the embed_audio vocab.
  1747. audio_padding_toks = torch.tensor([[self.vocab_size - 1]], dtype=torch.long, device=audio_features.device)
  1748. audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks)
  1749. audio_features = torch.where(audio_mask.unsqueeze(-1), audio_padding_embs, audio_features)
  1750. audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
  1751. extra_padding_tokens = self.config.audio_soft_tokens_per_image - audio_seq_len
  1752. extra_padding_features = audio_padding_embs.expand(audio_batch_size, extra_padding_tokens, audio_embed_dim)
  1753. audio_features = torch.cat((audio_features, extra_padding_features), dim=1)
  1754. audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
  1755. _, special_audio_mask = self.get_placeholder_mask(
  1756. input_ids, inputs_embeds=inputs_embeds, audio_features=audio_features
  1757. )
  1758. inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
  1759. outputs = self.language_model(
  1760. input_ids=None,
  1761. per_layer_inputs=per_layer_inputs,
  1762. attention_mask=attention_mask,
  1763. position_ids=position_ids,
  1764. past_key_values=past_key_values,
  1765. inputs_embeds=inputs_embeds,
  1766. use_cache=use_cache,
  1767. output_attentions=output_attentions,
  1768. output_hidden_states=output_hidden_states,
  1769. return_dict=True,
  1770. cache_position=cache_position,
  1771. **lm_kwargs,
  1772. )
  1773. return Gemma3nModelOutputWithPast(
  1774. last_hidden_state=outputs.last_hidden_state,
  1775. past_key_values=outputs.past_key_values if use_cache else None,
  1776. hidden_states=outputs.hidden_states,
  1777. attentions=outputs.attentions,
  1778. image_hidden_states=image_features if pixel_values is not None else None,
  1779. audio_hidden_states=audio_features if input_features is not None else None,
  1780. )
  1781. def get_audio_features(
  1782. self, input_features: torch.Tensor, input_features_mask: torch.Tensor
  1783. ) -> tuple[torch.Tensor, torch.Tensor]:
  1784. """
  1785. Projects the last hidden state from the audio encoder into language model space.
  1786. Args:
  1787. input_features (`torch.FloatTensor]` of shape `(num_images, seq_length, num_features)`):
  1788. The tensors corresponding to the input audio.
  1789. input_features_mask (`torch.FloatTensor]` of shape `(num_images, seq_length)`):
  1790. The attention mask for the input audio.
  1791. Returns:
  1792. audio_features (`torch.Tensor`): Audio feature tensor of shape `(num_images, audio_length, embed_dim)`).
  1793. """
  1794. audio_outputs, audio_mask = self.audio_tower(input_features, input_features_mask)
  1795. return self.embed_audio(inputs_embeds=audio_outputs), audio_mask
  1796. @auto_docstring(
  1797. custom_intro="""
  1798. The base Gemma 3n model comprising a vision backbone, an audio backbone, a language model, and a language modeling
  1799. head.
  1800. """
  1801. )
  1802. class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
  1803. _checkpoint_conversion_mapping = {}
  1804. _tied_weights_keys = ["lm_head.weight"]
  1805. base_model_prefix = "model"
  1806. def __init__(self, config: Gemma3nConfig):
  1807. super().__init__(config)
  1808. self.model = Gemma3nModel(config)
  1809. self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
  1810. self.post_init()
  1811. def get_input_embeddings(self):
  1812. return self.model.get_input_embeddings()
  1813. def set_input_embeddings(self, value):
  1814. self.model.set_input_embeddings(value)
  1815. def set_decoder(self, decoder):
  1816. self.model.set_decoder(decoder)
  1817. def get_decoder(self):
  1818. return self.model.get_decoder()
  1819. def get_image_features(self, pixel_values):
  1820. return self.model.get_image_features(pixel_values)
  1821. # Make modules available through conditional class for BC
  1822. @property
  1823. def language_model(self):
  1824. return self.model.language_model
  1825. @property
  1826. def vision_tower(self):
  1827. return self.model.vision_tower
  1828. @property
  1829. def multi_modal_projector(self):
  1830. raise AttributeError("Use embed_vision instead of multi_modal_projector.")
  1831. @can_return_tuple
  1832. @auto_docstring
  1833. def forward(
  1834. self,
  1835. input_ids: Optional[torch.LongTensor] = None, # text inputs
  1836. pixel_values: Optional[torch.FloatTensor] = None, # vision inputs
  1837. input_features: Optional[torch.FloatTensor] = None, # audio inputs
  1838. attention_mask: Optional[torch.Tensor] = None,
  1839. input_features_mask: Optional[torch.Tensor] = None,
  1840. position_ids: Optional[torch.LongTensor] = None,
  1841. past_key_values: Optional[Cache] = None,
  1842. token_type_ids: Optional[torch.LongTensor] = None,
  1843. cache_position: Optional[torch.LongTensor] = None,
  1844. inputs_embeds: Optional[torch.FloatTensor] = None,
  1845. labels: Optional[torch.LongTensor] = None,
  1846. use_cache: Optional[bool] = None,
  1847. output_attentions: Optional[bool] = None,
  1848. output_hidden_states: Optional[bool] = None,
  1849. logits_to_keep: Union[int, torch.Tensor] = 0,
  1850. **lm_kwargs,
  1851. ) -> Gemma3nCausalLMOutputWithPast:
  1852. r"""
  1853. input_features_mask (torch.Tensor, *optional*, defaults to None):
  1854. The attention mask for the input audio.
  1855. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1856. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1857. config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are
  1858. ignored (masked), the loss is only computed for the tokens with labels in
  1859. `[0, ..., config.text_config.vocab_size]`.
  1860. Example:
  1861. ```python
  1862. >>> from PIL import Image
  1863. >>> import requests
  1864. >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
  1865. >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
  1866. >>> processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
  1867. >>> messages = [
  1868. ... {
  1869. ... "role": "system",
  1870. ... "content": [
  1871. ... {"type": "text", "text": "You are a helpful assistant."}
  1872. ... ]
  1873. ... },
  1874. ... {
  1875. ... "role": "user", "content": [
  1876. ... {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
  1877. ... {"type": "text", "text": "Where is the cat standing?"},
  1878. ... ]
  1879. ... },
  1880. ... ]
  1881. >>> inputs = processor.apply_chat_template(
  1882. ... messages,
  1883. ... tokenizer=True,
  1884. ... return_dict=True,
  1885. ... return_tensors="pt",
  1886. ... add_generation_prompt=True
  1887. ... )
  1888. >>> # Generate
  1889. >>> generate_ids = model.generate(**inputs)
  1890. >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  1891. "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
  1892. ```
  1893. """
  1894. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1895. output_hidden_states = (
  1896. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1897. )
  1898. outputs = self.model(
  1899. input_ids=input_ids,
  1900. pixel_values=pixel_values,
  1901. input_features=input_features,
  1902. attention_mask=attention_mask,
  1903. input_features_mask=input_features_mask,
  1904. position_ids=position_ids,
  1905. past_key_values=past_key_values,
  1906. token_type_ids=token_type_ids,
  1907. cache_position=cache_position,
  1908. inputs_embeds=inputs_embeds,
  1909. labels=labels,
  1910. use_cache=use_cache,
  1911. output_attentions=output_attentions,
  1912. output_hidden_states=output_hidden_states,
  1913. return_dict=True,
  1914. **lm_kwargs,
  1915. )
  1916. hidden_states = outputs.last_hidden_state
  1917. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  1918. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  1919. logits = self.lm_head(hidden_states[:, slice_indices, :])
  1920. if (final_logit_softcapping := self.config.get_text_config().final_logit_softcapping) is not None:
  1921. logits = logits / final_logit_softcapping
  1922. logits = torch.tanh(logits)
  1923. logits = logits * final_logit_softcapping
  1924. loss = None
  1925. if labels is not None:
  1926. # Upcast to float if we need to compute the loss to avoid potential precision issues
  1927. logits = logits.float()
  1928. shift_logits = logits[..., :-1, :]
  1929. shift_labels = labels[..., 1:]
  1930. if attention_mask is not None:
  1931. # we use the input attention mask to shift the logits and labels, because it is 2D.
  1932. # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
  1933. shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
  1934. shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
  1935. shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
  1936. else:
  1937. shift_logits = shift_logits.contiguous()
  1938. shift_labels = shift_labels.contiguous()
  1939. # Flatten the tokens
  1940. loss_fct = nn.CrossEntropyLoss()
  1941. flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
  1942. flat_labels = shift_labels.view(-1).to(shift_logits.device)
  1943. loss = loss_fct(flat_logits, flat_labels)
  1944. return Gemma3nCausalLMOutputWithPast(
  1945. loss=loss,
  1946. logits=logits,
  1947. past_key_values=outputs.past_key_values,
  1948. hidden_states=outputs.hidden_states,
  1949. attentions=outputs.attentions,
  1950. image_hidden_states=outputs.image_hidden_states,
  1951. audio_hidden_states=outputs.audio_hidden_states,
  1952. )
  1953. def prepare_inputs_for_generation(
  1954. self,
  1955. input_ids,
  1956. past_key_values=None,
  1957. inputs_embeds=None,
  1958. cache_position=None,
  1959. position_ids=None,
  1960. pixel_values=None,
  1961. input_features=None,
  1962. attention_mask=None,
  1963. input_features_mask=None,
  1964. token_type_ids=None,
  1965. use_cache=True,
  1966. logits_to_keep=None,
  1967. labels=None,
  1968. **kwargs,
  1969. ):
  1970. # Overwritten -- custom `position_ids` and `pixel_values` handling
  1971. model_inputs = super().prepare_inputs_for_generation(
  1972. input_ids,
  1973. past_key_values=past_key_values,
  1974. inputs_embeds=inputs_embeds,
  1975. attention_mask=attention_mask,
  1976. position_ids=position_ids,
  1977. cache_position=cache_position,
  1978. use_cache=use_cache,
  1979. logits_to_keep=logits_to_keep,
  1980. token_type_ids=token_type_ids,
  1981. **kwargs,
  1982. )
  1983. # If we're in cached decoding stage, multimodal inputs should be None because input ids do not contain special
  1984. # tokens anymore. Otherwise multimodal inputs should be passed to model.
  1985. # NOTE: use_cache=False always needs pixel_values, input_features, and input_features_mask
  1986. if cache_position[0] == 0:
  1987. model_inputs["pixel_values"] = pixel_values
  1988. model_inputs["input_features"] = input_features
  1989. model_inputs["input_features_mask"] = input_features_mask
  1990. return model_inputs
  1991. @property
  1992. def audio_tower(self):
  1993. return self.model.audio_tower
  1994. __all__ = [
  1995. "Gemma3nAudioEncoder",
  1996. "Gemma3nForCausalLM",
  1997. "Gemma3nForConditionalGeneration",
  1998. "Gemma3nModel",
  1999. "Gemma3nPreTrainedModel",
  2000. "Gemma3nTextModel",
  2001. ]