configuration_bitnet.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. # coding=utf-8
  2. # Copyright 2025 The BitNet Team and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and limitations under the License.
  14. """BitNet model configuration"""
  15. from ...configuration_utils import PretrainedConfig
  16. from ...utils import logging
  17. logger = logging.get_logger(__name__)
  18. class BitNetConfig(PretrainedConfig):
  19. r"""
  20. This is the configuration class to store the configuration of a [`BitNetModel`]. It is used to instantiate an BitNet
  21. model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
  22. defaults will yield a similar configuration to that of
  23. BitNet b1.58 2B4T [microsoft/bitnet-b1.58-2B-4T](https://huggingface.co/microsoft/bitnet-b1.58-2B-4T).
  24. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  25. documentation from [`PretrainedConfig`] for more information.
  26. Args:
  27. vocab_size (`int`, *optional*, defaults to 128256):
  28. Vocabulary size of the BitNet model. Defines the number of different tokens that can be represented by the
  29. `inputs_ids` passed when calling [`BitNetModel`]
  30. hidden_size (`int`, *optional*, defaults to 2560):
  31. Dimension of the hidden representations.
  32. intermediate_size (`int`, *optional*, defaults to 6912):
  33. Dimension of the MLP representations.
  34. num_hidden_layers (`int`, *optional*, defaults to 30):
  35. Number of hidden layers in the Transformer decoder.
  36. num_attention_heads (`int`, *optional*, defaults to 20):
  37. Number of attention heads for each attention layer in the Transformer decoder.
  38. num_key_value_heads (`int`, *optional*, defaults to 5):
  39. This is the number of key_value heads that should be used to implement Grouped Query Attention. If
  40. `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
  41. `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
  42. converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
  43. by meanpooling all the original heads within that group. For more details, check out [this
  44. paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
  45. `num_attention_heads`.
  46. hidden_act (`str` or `function`, *optional*, defaults to `"relu2"`):
  47. The non-linear activation function (function or string) in the decoder.
  48. max_position_embeddings (`int`, *optional*, defaults to 2048):
  49. The maximum sequence length that this model might ever be used with.
  50. initializer_range (`float`, *optional*, defaults to 0.02):
  51. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  52. rms_norm_eps (`float`, *optional*, defaults to 1e-05):
  53. The epsilon used by the rms normalization layers.
  54. use_cache (`bool`, *optional*, defaults to `True`):
  55. Whether or not the model should return the last key/values attentions (not used by all models). Only
  56. relevant if `config.is_decoder=True`.
  57. pad_token_id (`int`, *optional*):
  58. Padding token id.
  59. bos_token_id (`int`, *optional*, defaults to 128000):
  60. Beginning of stream token id.
  61. eos_token_id (`int`, *optional*, defaults to 128001):
  62. End of stream token id.
  63. tie_word_embeddings (`bool`, *optional*, defaults to `False`):
  64. Whether to tie weight embeddings
  65. rope_theta (`float`, *optional*, defaults to 500000.0):
  66. The base period of the RoPE embeddings.
  67. attention_bias (`bool`, *optional*, defaults to `False`):
  68. Whether to use a bias in the query, key, value and output projection layers during self-attention.
  69. attention_dropout (`float`, *optional*, defaults to 0.0):
  70. The dropout ratio for the attention probabilities.
  71. ```python
  72. >>> from transformers import BitNetModel, BitNetConfig
  73. >>> # Initializing a BitNet style configuration
  74. >>> configuration = BitNetConfig()
  75. >>> # Initializing a model from the BitNet style configuration
  76. >>> model = BitNetModel(configuration)
  77. >>> # Accessing the model configuration
  78. >>> configuration = model.config
  79. ```"""
  80. model_type = "bitnet"
  81. keys_to_ignore_at_inference = ["past_key_values"]
  82. def __init__(
  83. self,
  84. vocab_size=128256,
  85. hidden_size=2560,
  86. intermediate_size=6912,
  87. num_hidden_layers=30,
  88. num_attention_heads=20,
  89. num_key_value_heads=5,
  90. hidden_act="relu2",
  91. max_position_embeddings=2048,
  92. initializer_range=0.02,
  93. rms_norm_eps=1e-5,
  94. use_cache=True,
  95. pad_token_id=None,
  96. bos_token_id=128000,
  97. eos_token_id=128001,
  98. tie_word_embeddings=False,
  99. rope_theta=500000.0,
  100. attention_bias=False,
  101. attention_dropout=0.0,
  102. **kwargs,
  103. ):
  104. self.vocab_size = vocab_size
  105. self.max_position_embeddings = max_position_embeddings
  106. self.hidden_size = hidden_size
  107. self.intermediate_size = intermediate_size
  108. self.num_hidden_layers = num_hidden_layers
  109. self.num_attention_heads = num_attention_heads
  110. # for backward compatibility
  111. if num_key_value_heads is None:
  112. num_key_value_heads = num_attention_heads
  113. self.num_key_value_heads = num_key_value_heads
  114. self.hidden_act = hidden_act
  115. self.initializer_range = initializer_range
  116. self.rms_norm_eps = rms_norm_eps
  117. self.use_cache = use_cache
  118. self.rope_theta = rope_theta
  119. self.attention_bias = attention_bias
  120. self.attention_dropout = attention_dropout
  121. super().__init__(
  122. pad_token_id=pad_token_id,
  123. bos_token_id=bos_token_id,
  124. eos_token_id=eos_token_id,
  125. tie_word_embeddings=tie_word_embeddings,
  126. **kwargs,
  127. )
  128. __all__ = ["BitNetConfig"]