modular_vaultgemma.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  # coding=utf-8
  # Copyright 2025 the HuggingFace Team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
  # You may obtain a copy of the License at
  #
  # http://www.apache.org/licenses/LICENSE-2.0
  #
  # Unless required by applicable law or agreed to in writing, software
  # distributed under the License is distributed on an "AS IS" BASIS,
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
  15. from typing import Optional
  16. import torch
  17. from ...cache_utils import Cache
  18. from ..gemma2.configuration_gemma2 import Gemma2Config
  19. from ..gemma2.modeling_gemma2 import Gemma2DecoderLayer, Gemma2ForCausalLM
  class VaultGemmaConfig(Gemma2Config):
      # VaultGemma configuration: a bare subclass of `Gemma2Config`. Nothing is
      # added or overridden here, so every field and default comes from the parent.
      pass
  class VaultGemmaDecoderLayer(Gemma2DecoderLayer):
      """Decoder layer derived from `Gemma2DecoderLayer` without its two post-norms.

      The constructor deletes `post_attention_layernorm` and
      `post_feedforward_layernorm` once the parent has been initialised, and
      `forward` normalises only the *input* of each sub-block (self-attention,
      then MLP) before adding the sub-block output back onto the residual.
      """

      def __init__(self, **super_kwargs):
          """Initialise the parent layer, then remove the norms `forward` never calls.

          Args:
              **super_kwargs: Forwarded unchanged to the parent constructor.
          """
          super().__init__(**super_kwargs)
          # `forward` below never references these two attributes; deleting them
          # keeps them off the instance after the parent constructor has run.
          del self.post_attention_layernorm
          del self.post_feedforward_layernorm

      def forward(
          self,
          hidden_states: torch.Tensor,
          position_embeddings: tuple[torch.Tensor, torch.Tensor],
          attention_mask: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          past_key_values: Optional[Cache] = None,
          output_attentions: Optional[bool] = False,
          use_cache: Optional[bool] = False,
          cache_position: Optional[torch.LongTensor] = None,
          **kwargs,
      ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
          """Run the attention sub-block and the MLP sub-block, each with a residual add.

          Args:
              hidden_states: Layer input; also used as the first residual.
              position_embeddings: Passed through to `self.self_attn`.
              attention_mask: Passed through to `self.self_attn`.
              position_ids: Passed through to `self.self_attn`.
              past_key_values: Passed through to `self.self_attn`.
              output_attentions: When truthy, the attention weights returned by
                  `self.self_attn` are appended to the output tuple.
              use_cache: Passed through to `self.self_attn`.
              cache_position: Passed through to `self.self_attn`.
              **kwargs: Extra keyword arguments forwarded to `self.self_attn`.

          Returns:
              A tuple whose first element is the updated hidden states. If
              `output_attentions` is truthy, the self-attention weights follow as
              the second element.
          """
          residual = hidden_states

          hidden_states = self.input_layernorm(hidden_states)

          # Self Attention
          hidden_states, self_attn_weights = self.self_attn(
              hidden_states=hidden_states,
              position_embeddings=position_embeddings,
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=past_key_values,
              output_attentions=output_attentions,
              use_cache=use_cache,
              cache_position=cache_position,
              **kwargs,
          )
          # No norm is applied to the attention output before the residual add.
          hidden_states = residual + hidden_states

          # Feed-forward sub-block: norm on the way in, plain residual add on the way out.
          residual = hidden_states
          hidden_states = self.pre_feedforward_layernorm(hidden_states)
          hidden_states = self.mlp(hidden_states)
          hidden_states = residual + hidden_states

          outputs = (hidden_states,)

          if output_attentions:
              outputs += (self_attn_weights,)

          return outputs
  class VaultGemmaForCausalLM(Gemma2ForCausalLM):
      # VaultGemma causal-LM head model: a bare subclass of `Gemma2ForCausalLM`.
      # Nothing is added or overridden here; all behaviour comes from the parent.
      pass
  # Public names exported by this module. The two entries marked `noqa: F822`
  # are not defined in the visible part of this file, hence the suppression of
  # the undefined-export lint; presumably they are supplied elsewhere (e.g. by
  # generated code) -- TODO confirm.
  __all__ = [
      "VaultGemmaConfig",
      "VaultGemmaForCausalLM",
      "VaultGemmaModel",  # noqa: F822
      "VaultGemmaPreTrainedModel",  # noqa: F822
  ]