modular_gemma.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492
  1. # coding=utf-8
  2. # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
  3. #
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. from typing import TYPE_CHECKING, Any, Optional
  17. import sentencepiece as spm
  18. import torch
  19. from torch import nn
  20. from ...cache_utils import Cache, DynamicCache
  21. from ...configuration_utils import PretrainedConfig
  22. from ...masking_utils import create_causal_mask
  23. from ...modeling_outputs import BaseModelOutputWithPast
  24. from ...modeling_utils import PreTrainedModel
  25. from ...processing_utils import Unpack
  26. from ...tokenization_utils import AddedToken, PreTrainedTokenizer
  27. from ...utils import TransformersKwargs, logging
  28. from ..llama.modeling_llama import (
  29. LlamaForCausalLM,
  30. LlamaForSequenceClassification,
  31. LlamaForTokenClassification,
  32. LlamaMLP,
  33. LlamaModel,
  34. LlamaPreTrainedModel,
  35. LlamaRotaryEmbedding,
  36. )
  37. from ..llama.tokenization_llama import LlamaTokenizer
  38. if TYPE_CHECKING:
  39. from ...tokenization_utils_base import TextInput
  40. VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
  41. SPIECE_UNDERLINE = "▁"
  42. logger = logging.get_logger(__name__)
  43. class GemmaConfig(PretrainedConfig):
  44. r"""
  45. This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma
  46. model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
  47. defaults will yield a similar configuration to that of the Gemma-7B.
  48. e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
  49. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  50. documentation from [`PretrainedConfig`] for more information.
  51. Args:
  52. vocab_size (`int`, *optional*, defaults to 256000):
  53. Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
  54. `inputs_ids` passed when calling [`GemmaModel`]
  55. hidden_size (`int`, *optional*, defaults to 3072):
  56. Dimension of the hidden representations.
  57. intermediate_size (`int`, *optional*, defaults to 24576):
  58. Dimension of the MLP representations.
  59. num_hidden_layers (`int`, *optional*, defaults to 28):
  60. Number of hidden layers in the Transformer decoder.
  61. num_attention_heads (`int`, *optional*, defaults to 16):
  62. Number of attention heads for each attention layer in the Transformer decoder.
  63. num_key_value_heads (`int`, *optional*, defaults to 16):
  64. This is the number of key_value heads that should be used to implement Grouped Query Attention. If
  65. `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
  66. `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
  67. converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
  68. by meanpooling all the original heads within that group. For more details, check out [this
  69. paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
  70. `num_attention_heads`.
  71. head_dim (`int`, *optional*, defaults to 256):
  72. The attention head dimension.
  73. hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
  74. The legacy activation function. It is overwritten by the `hidden_activation`.
  75. hidden_activation (`str` or `function`, *optional*):
  76. The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
  77. if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
  78. max_position_embeddings (`int`, *optional*, defaults to 8192):
  79. The maximum sequence length that this model might ever be used with.
  80. initializer_range (`float`, *optional*, defaults to 0.02):
  81. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  82. rms_norm_eps (`float`, *optional*, defaults to 1e-06):
  83. The epsilon used by the rms normalization layers.
  84. use_cache (`bool`, *optional*, defaults to `True`):
  85. Whether or not the model should return the last key/values attentions (not used by all models). Only
  86. relevant if `config.is_decoder=True`.
  87. pad_token_id (`int`, *optional*, defaults to 0):
  88. Padding token id.
  89. eos_token_id (`int`, *optional*, defaults to 1):
  90. End of stream token id.
  91. bos_token_id (`int`, *optional*, defaults to 2):
  92. Beginning of stream token id.
  93. tie_word_embeddings (`bool`, *optional*, defaults to `True`):
  94. Whether to tie weight embeddings
  95. rope_theta (`float`, *optional*, defaults to 10000.0):
  96. The base period of the RoPE embeddings.
  97. attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
  98. Whether to use a bias in the query, key, value and output projection layers during self-attention.
  99. attention_dropout (`float`, *optional*, defaults to 0.0):
  100. The dropout ratio for the attention probabilities.
  101. ```python
  102. >>> from transformers import GemmaModel, GemmaConfig
  103. >>> # Initializing a Gemma gemma-7b style configuration
  104. >>> configuration = GemmaConfig()
  105. >>> # Initializing a model from the gemma-7b style configuration
  106. >>> model = GemmaModel(configuration)
  107. >>> # Accessing the model configuration
  108. >>> configuration = model.config
  109. ```"""
  110. model_type = "gemma"
  111. keys_to_ignore_at_inference = ["past_key_values"]
  112. base_model_tp_plan = {
  113. "layers.*.self_attn.q_proj": "colwise",
  114. "layers.*.self_attn.k_proj": "colwise",
  115. "layers.*.self_attn.v_proj": "colwise",
  116. "layers.*.self_attn.o_proj": "rowwise",
  117. "layers.*.mlp.gate_proj": "colwise",
  118. "layers.*.mlp.up_proj": "colwise",
  119. "layers.*.mlp.down_proj": "rowwise",
  120. }
  121. base_model_pp_plan = {
  122. "embed_tokens": (["input_ids"], ["inputs_embeds"]),
  123. "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
  124. "norm": (["hidden_states"], ["hidden_states"]),
  125. }
  126. def __init__(
  127. self,
  128. vocab_size=256000,
  129. hidden_size=3072,
  130. intermediate_size=24576,
  131. num_hidden_layers=28,
  132. num_attention_heads=16,
  133. num_key_value_heads=16,
  134. head_dim=256,
  135. hidden_act="gelu_pytorch_tanh",
  136. hidden_activation=None,
  137. max_position_embeddings=8192,
  138. initializer_range=0.02,
  139. rms_norm_eps=1e-6,
  140. use_cache=True,
  141. pad_token_id=0,
  142. eos_token_id=1,
  143. bos_token_id=2,
  144. tie_word_embeddings=True,
  145. rope_theta=10000.0,
  146. attention_bias=False,
  147. attention_dropout=0.0,
  148. **kwargs,
  149. ):
  150. self.vocab_size = vocab_size
  151. self.max_position_embeddings = max_position_embeddings
  152. self.hidden_size = hidden_size
  153. self.intermediate_size = intermediate_size
  154. self.num_hidden_layers = num_hidden_layers
  155. self.num_attention_heads = num_attention_heads
  156. self.head_dim = head_dim
  157. self.num_key_value_heads = num_key_value_heads
  158. self.hidden_act = hidden_act
  159. self.hidden_activation = hidden_activation
  160. self.initializer_range = initializer_range
  161. self.rms_norm_eps = rms_norm_eps
  162. self.use_cache = use_cache
  163. self.rope_theta = rope_theta
  164. self.attention_bias = attention_bias
  165. self.attention_dropout = attention_dropout
  166. super().__init__(
  167. pad_token_id=pad_token_id,
  168. bos_token_id=bos_token_id,
  169. eos_token_id=eos_token_id,
  170. tie_word_embeddings=tie_word_embeddings,
  171. **kwargs,
  172. )
  173. class GemmaTokenizer(LlamaTokenizer, PreTrainedTokenizer):
  174. """
  175. Construct a Gemma tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is
  176. no padding token in the original model.
  177. Args:
  178. vocab_file (`str`):
  179. Path to the vocabulary file.
  180. unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
  181. The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
  182. token instead.
  183. bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<bos>"`):
  184. The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
  185. eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<eos>"`):
  186. The end of sequence token.
  187. pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<pad>"`):
  188. A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by
  189. attention mechanisms or loss computation.
  190. sp_model_kwargs (`dict[str, Any]`, `Optional`, *optional*):
  191. Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
  192. SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
  193. to set:
  194. - `enable_sampling`: Enable subword regularization.
  195. - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
  196. - `nbest_size = {0,1}`: No sampling is performed.
  197. - `nbest_size > 1`: samples from the nbest_size results.
  198. - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
  199. using forward-filtering-and-backward-sampling algorithm.
  200. - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
  201. BPE-dropout.
  202. add_bos_token (`bool`, *optional*, defaults to `True`):
  203. Whether or not to add an `bos_token` at the start of sequences.
  204. add_eos_token (`bool`, *optional*, defaults to `False`):
  205. Whether or not to add an `eos_token` at the end of sequences.
  206. clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
  207. Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like
  208. extra spaces.
  209. use_default_system_prompt (`bool`, *optional*, defaults to `False`):
  210. Whether or not the default system prompt for Gemma should be used.
  211. spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
  212. Whether or not to add spaces between special tokens.
  213. """
  214. def __init__(
  215. self,
  216. vocab_file,
  217. unk_token="<unk>",
  218. bos_token="<bos>",
  219. eos_token="<eos>",
  220. pad_token="<pad>",
  221. sp_model_kwargs: Optional[dict[str, Any]] = None,
  222. add_bos_token=True,
  223. add_eos_token=False,
  224. clean_up_tokenization_spaces=False,
  225. use_default_system_prompt=False,
  226. spaces_between_special_tokens=False,
  227. **kwargs,
  228. ):
  229. self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
  230. bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
  231. eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
  232. unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
  233. pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
  234. self.vocab_file = vocab_file
  235. self.add_bos_token = add_bos_token
  236. self.add_eos_token = add_eos_token
  237. self.use_default_system_prompt = use_default_system_prompt
  238. self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
  239. self.sp_model.Load(vocab_file)
  240. PreTrainedTokenizer.__init__(
  241. self,
  242. bos_token=bos_token,
  243. eos_token=eos_token,
  244. unk_token=unk_token,
  245. pad_token=pad_token,
  246. add_bos_token=add_bos_token,
  247. add_eos_token=add_eos_token,
  248. sp_model_kwargs=sp_model_kwargs,
  249. clean_up_tokenization_spaces=clean_up_tokenization_spaces,
  250. use_default_system_prompt=use_default_system_prompt,
  251. spaces_between_special_tokens=spaces_between_special_tokens,
  252. **kwargs,
  253. )
  254. def get_spm_processor(self):
  255. raise AttributeError("Not needed for Gemma")
  256. def unk_token_length(self):
  257. raise AttributeError("Not needed for Gemma")
  258. def tokenize(self, text: "TextInput", **kwargs) -> list[str]:
  259. """
  260. Args:
  261. text: TextInput
  262. Simply calls PreTrainedTokenizer's method
  263. """
  264. return PreTrainedTokenizer.tokenize(self, text, **kwargs)
  265. def _tokenize(self, text, **kwargs):
  266. """
  267. Args:
  268. text: TextInput
  269. Returns a tokenized string. The Gemma tokenizer never adds a prefix space.
  270. """
  271. return self.sp_model.encode(text, out_type=str)
  272. def _decode(
  273. self,
  274. token_ids: list[int],
  275. skip_special_tokens: bool = False,
  276. spaces_between_special_tokens: bool = False,
  277. **kwargs,
  278. ) -> str:
  279. sub_texts = []
  280. current_sub_text = []
  281. for ids in token_ids:
  282. if skip_special_tokens and ids in self.all_special_ids:
  283. continue
  284. if ids in self._added_tokens_decoder:
  285. if current_sub_text:
  286. sub_texts.append(self.sp_model.decode(current_sub_text))
  287. sub_texts.append(self._added_tokens_decoder[ids].content)
  288. current_sub_text = []
  289. else:
  290. current_sub_text.append(ids)
  291. if current_sub_text:
  292. sub_texts.append(self.sp_model.decode(current_sub_text))
  293. if spaces_between_special_tokens:
  294. sub_texts = " ".join(sub_texts)
  295. else:
  296. sub_texts = "".join(sub_texts)
  297. return sub_texts.replace(SPIECE_UNDERLINE, " ")
  298. def convert_tokens_to_string(self, tokens):
  299. """Converts a sequence of tokens (string) in a single string."""
  300. current_sub_tokens = []
  301. out_string = ""
  302. for token in tokens:
  303. # make sure that special tokens are not decoded using sentencepiece model
  304. if token in self._added_tokens_encoder:
  305. out_string += self.sp_model.decode(current_sub_tokens) + token
  306. current_sub_tokens = []
  307. else:
  308. current_sub_tokens.append(token)
  309. out_string += self.sp_model.decode(current_sub_tokens)
  310. return out_string
  311. class GemmaRMSNorm(nn.Module):
  312. def __init__(self, dim: int, eps: float = 1e-6):
  313. super().__init__()
  314. self.eps = eps
  315. self.weight = nn.Parameter(torch.zeros(dim))
  316. def _norm(self, x):
  317. return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
  318. def forward(self, x):
  319. output = self._norm(x.float())
  320. # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
  321. # See https://github.com/huggingface/transformers/pull/29402
  322. output = output * (1.0 + self.weight.float())
  323. return output.type_as(x)
  324. def extra_repr(self):
  325. return f"{tuple(self.weight.shape)}, eps={self.eps}"
  326. class GemmaMLP(LlamaMLP):
  327. def __init__(self, config):
  328. super().__init__(config)
  329. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  330. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  331. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  332. class GemmaRotaryEmbedding(LlamaRotaryEmbedding):
  333. pass
  334. class GemmaPreTrainedModel(LlamaPreTrainedModel):
  335. def _init_weights(self, module):
  336. PreTrainedModel._init_weights(self, module)
  337. # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
  338. if "RMSNorm" in module.__class__.__name__:
  339. module.weight.data.zero_()
  340. class GemmaModel(LlamaModel):
  341. def forward(
  342. self,
  343. input_ids: Optional[torch.LongTensor] = None,
  344. attention_mask: Optional[torch.Tensor] = None,
  345. position_ids: Optional[torch.LongTensor] = None,
  346. past_key_values: Optional[Cache] = None,
  347. inputs_embeds: Optional[torch.FloatTensor] = None,
  348. use_cache: Optional[bool] = None,
  349. cache_position: Optional[torch.LongTensor] = None,
  350. **kwargs: Unpack[TransformersKwargs],
  351. ) -> BaseModelOutputWithPast:
  352. if (input_ids is None) ^ (inputs_embeds is not None):
  353. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  354. if inputs_embeds is None:
  355. inputs_embeds = self.embed_tokens(input_ids)
  356. if use_cache and past_key_values is None:
  357. past_key_values = DynamicCache(config=self.config)
  358. if cache_position is None:
  359. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  360. cache_position = torch.arange(
  361. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  362. )
  363. if position_ids is None:
  364. position_ids = cache_position.unsqueeze(0)
  365. causal_mask = create_causal_mask(
  366. config=self.config,
  367. input_embeds=inputs_embeds,
  368. attention_mask=attention_mask,
  369. cache_position=cache_position,
  370. past_key_values=past_key_values,
  371. position_ids=position_ids,
  372. )
  373. # embed positions
  374. hidden_states = inputs_embeds
  375. # create position embeddings to be shared across the decoder layers
  376. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  377. # normalized
  378. # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
  379. # See https://github.com/huggingface/transformers/pull/29402
  380. normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
  381. hidden_states = hidden_states * normalizer
  382. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  383. hidden_states = decoder_layer(
  384. hidden_states,
  385. attention_mask=causal_mask,
  386. position_ids=position_ids,
  387. past_key_values=past_key_values,
  388. use_cache=use_cache,
  389. cache_position=cache_position,
  390. position_embeddings=position_embeddings,
  391. **kwargs,
  392. )
  393. hidden_states = self.norm(hidden_states)
  394. return BaseModelOutputWithPast(
  395. last_hidden_state=hidden_states,
  396. past_key_values=past_key_values if use_cache else None,
  397. )
class GemmaForCausalLM(LlamaForCausalLM):
    # NOTE(review): `forward` takes only `**super_kwargs` and no `self`. This appears to be the modular-file
    # convention for "keep the parent's signature and implementation, replace only the docstring". As plain Python,
    # calling it on an instance would raise TypeError, because `self` is passed positionally. Confirm against the
    # modular converter before changing the signature.
    def forward(**super_kwargs):
        r"""
        Example:

        ```python
        >>> from transformers import AutoTokenizer, GemmaForCausalLM

        >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")

        >>> prompt = "What is your favorite condiment?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "What is your favorite condiment?"
        ```"""
        return super().forward(**super_kwargs)
  414. class GemmaForSequenceClassification(LlamaForSequenceClassification):
  415. pass
  416. class GemmaForTokenClassification(LlamaForTokenClassification):
  417. pass
  418. __all__ = [
  419. "GemmaConfig",
  420. "GemmaTokenizer",
  421. "GemmaModel",
  422. "GemmaForCausalLM",
  423. "GemmaForSequenceClassification",
  424. "GemmaForTokenClassification",
  425. "GemmaPreTrainedModel",
  426. ]