configuration_bamba.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. # coding=utf-8
  2. # Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Bamba model configuration"""
  16. from ...configuration_utils import PretrainedConfig
  17. from ...utils import logging
  18. logger = logging.get_logger(__name__)
  19. class BambaConfig(PretrainedConfig):
  20. r"""
  21. This is the configuration class to store the configuration of a [`BambaModel`]. It is used to instantiate a
  22. BambaModel model according to the specified arguments, defining the model architecture. Instantiating a configuration
  23. with defaults taken from [ibm-fms/Bamba-9.8b-2.2T-hf](https://huggingface.co/ibm-fms/Bamba-9.8b-2.2T-hf).
  24. The BambaModel is a hybrid [mamba2](https://github.com/state-spaces/mamba) architecture with SwiGLU.
  25. The checkpoints are jointly trained by IBM, Princeton, and UIUC.
  26. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  27. documentation from [`PretrainedConfig`] for more information.
  28. Args:
  29. vocab_size (`int`, *optional*, defaults to 128000):
  30. Vocabulary size of the Bamba model. Defines the number of different tokens that can be represented by the
  31. `inputs_ids` passed when calling [`BambaModel`]
  32. tie_word_embeddings (`bool`, *optional*, defaults to `False`):
  33. Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
  34. model has an output word embedding layer.
  35. hidden_size (`int`, *optional*, defaults to 4096):
  36. Dimension of the hidden representations.
  37. intermediate_size (`int`, *optional*, defaults to 14336):
  38. Dimension of the MLP representations.
  39. num_hidden_layers (`int`, *optional*, defaults to 32):
  40. Number of hidden layers in the Transformer encoder.
  41. num_attention_heads (`int`, *optional*, defaults to 32):
  42. Number of attention heads for each attention layer in the Transformer encoder.
  43. num_key_value_heads (`int`, *optional*, defaults to 8):
  44. This is the number of key_value heads that should be used to implement Grouped Query Attention. If
  45. `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
  46. `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
  47. converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
  48. by meanpooling all the original heads within that group. For more details, check out [this
  49. paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`.
  50. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
  51. The non-linear activation function (function or string) in the decoder.
  52. initializer_range (`float`, *optional*, defaults to 0.02):
  53. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  54. rms_norm_eps (`float`, *optional*, defaults to 1e-05):
  55. The epsilon used by the rms normalization layers.
  56. use_cache (`bool`, *optional*, defaults to `True`):
  57. Whether or not the model should return the last key/values attentions (not used by all models). Only
  58. relevant if `config.is_decoder=True`.
  59. num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
  60. Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an
  61. integer value, only last `num_logits_to_keep` logits will be calculated. Default is 1 because only the
  62. logits of the last prompt token are needed for generation. For long sequences, the logits for the entire
  63. sequence may use a lot of memory so, setting `num_logits_to_keep=1` will reduce memory footprint
  64. significantly.
  65. pad_token_id (`int`, *optional*, defaults to 0):
  66. The id of the padding token.
  67. bos_token_id (`int`, *optional*, defaults to 1):
  68. The id of the "beginning-of-sequence" token.
  69. eos_token_id (`int`, *optional*, defaults to 2):
  70. The id of the "end-of-sequence" token.
  71. max_position_embeddings (`int`, *optional*, defaults to 262144):
  72. Max cached sequence length for the model
  73. attention_dropout (`float`, *optional*, defaults to 0.0):
  74. The dropout ratio for the attention probabilities.
  75. attn_layer_indices (`list`, *optional*):
  76. Specifies the layer indices that will have full attention. Must contain values at most num_hidden_layers.
  77. mamba_n_heads (`int`, *optional*, defaults to 128):
  78. The number of mamba heads used in the v2 implementation.
  79. mamba_d_head (`int`, *optional*, defaults to `"auto"`):
  80. Head embedding dimension size
  81. mamba_n_groups (`int`, *optional*, defaults to 1):
  82. The number of the mamba groups used in the v2 implementation.
  83. mamba_d_state (`int`, *optional*, defaults to 256):
  84. The dimension the mamba state space latents
  85. mamba_d_conv (`int`, *optional*, defaults to 4):
  86. The size of the mamba convolution kernel
  87. mamba_expand (`int`, *optional*, defaults to 2):
  88. Expanding factor (relative to hidden_size) used to determine the mamba intermediate size
  89. mamba_chunk_size (`int`, *optional*, defaults to 256):
  90. The chunks in which to break the sequence when doing prefill/training
  91. mamba_conv_bias (`bool`, *optional*, defaults to `True`):
  92. Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block.
  93. mamba_proj_bias (`bool`, *optional*, defaults to `False`):
  94. Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the mamba mixer block
  95. z_loss_coefficient (`float`, *optional*, defaults to 0.0):
  96. Coefficient for auxiliary z-loss used to control logit growth during training
  97. """
  98. model_type = "bamba"
  99. keys_to_ignore_at_inference = ["past_key_values"]
  100. def __init__(
  101. self,
  102. vocab_size=128000,
  103. tie_word_embeddings=False,
  104. hidden_size=4096,
  105. intermediate_size=14336,
  106. num_hidden_layers=32,
  107. num_attention_heads=32,
  108. num_key_value_heads=8,
  109. hidden_act="silu",
  110. initializer_range=0.02,
  111. rms_norm_eps=1e-5,
  112. use_cache=True,
  113. num_logits_to_keep=1,
  114. pad_token_id=0,
  115. bos_token_id=1,
  116. eos_token_id=2,
  117. max_position_embeddings=262144,
  118. attention_dropout=0.0,
  119. attn_layer_indices=None,
  120. mamba_n_heads=128,
  121. mamba_d_head="auto",
  122. mamba_n_groups=1,
  123. mamba_d_state=256,
  124. mamba_d_conv=4,
  125. mamba_expand=2,
  126. mamba_chunk_size=256,
  127. mamba_conv_bias=True,
  128. mamba_proj_bias=False,
  129. z_loss_coefficient=0.0,
  130. **kwargs,
  131. ):
  132. self.vocab_size = vocab_size
  133. self.tie_word_embeddings = tie_word_embeddings
  134. self.hidden_size = hidden_size
  135. self.intermediate_size = intermediate_size
  136. self.num_hidden_layers = num_hidden_layers
  137. self.num_attention_heads = num_attention_heads
  138. self.max_position_embeddings = max_position_embeddings
  139. self.attention_dropout = attention_dropout
  140. self.attention_bias = False
  141. self.mlp_bias = False
  142. # for backward compatibility
  143. if num_key_value_heads is None:
  144. num_key_value_heads = num_attention_heads
  145. self.num_key_value_heads = num_key_value_heads
  146. self.hidden_act = hidden_act
  147. self.initializer_range = initializer_range
  148. self.rms_norm_eps = rms_norm_eps
  149. self.use_cache = use_cache
  150. self.num_logits_to_keep = num_logits_to_keep
  151. self.attn_layer_indices = attn_layer_indices
  152. self.rope_theta = 10000.0
  153. self.rope_scaling = None
  154. self.partial_rotary_factor = 0.5
  155. mamba_intermediate = mamba_expand * hidden_size
  156. if mamba_intermediate % mamba_n_heads != 0:
  157. raise ValueError("mamba_n_heads must divide mamba_expand * hidden_size")
  158. # for the mamba_v2, must satisfy the following
  159. if mamba_d_head == "auto":
  160. mamba_d_head = mamba_intermediate // mamba_n_heads
  161. if mamba_d_head * mamba_n_heads != mamba_intermediate:
  162. raise ValueError("The dimensions for the Mamba head state do not match the model intermediate_size")
  163. self.mamba_n_heads = mamba_n_heads
  164. self.mamba_d_head = mamba_d_head
  165. self.mamba_n_groups = mamba_n_groups
  166. self.mamba_d_state = mamba_d_state
  167. self.mamba_d_conv = mamba_d_conv
  168. self.mamba_expand = mamba_expand
  169. self.mamba_chunk_size = mamba_chunk_size
  170. self.mamba_conv_bias = mamba_conv_bias
  171. self.mamba_proj_bias = mamba_proj_bias
  172. self.z_loss_coefficient = z_loss_coefficient
  173. super().__init__(
  174. pad_token_id=pad_token_id,
  175. bos_token_id=bos_token_id,
  176. eos_token_id=eos_token_id,
  177. tie_word_embeddings=tie_word_embeddings,
  178. **kwargs,
  179. )
  180. @property
  181. def layers_block_type(self):
  182. return [
  183. "attention" if (self.attn_layer_indices and i in self.attn_layer_indices) else "mamba"
  184. for i in range(self.num_hidden_layers)
  185. ]
  186. __all__ = ["BambaConfig"]