configuration_dbrx.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. # coding=utf-8
  2. # Copyright 2024 Databricks Mosaic Research and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """DBRX model configuration"""
  16. from typing import Any, Optional
  17. from ...configuration_utils import PretrainedConfig
  18. from ...utils import logging
  19. logger = logging.get_logger(__name__)
  class DbrxAttentionConfig(PretrainedConfig):
      """Configuration class for [`DbrxAttention`].

      It is used to instantiate attention layers according to the specified
      arguments, defining the layers' architecture.

      Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
      documentation from [`PretrainedConfig`] for more information.

      Args:
          attn_pdrop (`float`, *optional*, defaults to 0.0):
              The dropout probability for the attention layers.
          clip_qkv (`float`, *optional*):
              If set, clip the queries, keys, and values in the attention layer to this value.
          kv_n_heads (`int`, *optional*, defaults to 1):
              For grouped_query_attention only, allow user to specify number of kv heads.
          rope_theta (`float`, *optional*, defaults to 10000.0):
              The base frequency for rope.
          kwargs (`dict[str, Any]`, *optional*):
              Extra keyword arguments. They are forwarded unchanged to [`PretrainedConfig`]. Only `model_type`,
              `attn_implementation`, `transformers_version`, `_commit_hash`, `torch_dtype` and `dtype` are
              accepted.

      Raises:
          ValueError: If `kwargs` contains any key other than the accepted ones listed above.
      """

      base_config_key = "attn_config"

      def __init__(
          self,
          attn_pdrop: float = 0.0,
          clip_qkv: Optional[float] = None,
          kv_n_heads: int = 1,
          rope_theta: float = 10000.0,
          **kwargs: Any,
      ):
          # The base class receives every extra keyword argument before validation happens below.
          super().__init__(**kwargs)
          self.attn_pdrop = attn_pdrop
          self.clip_qkv = clip_qkv
          self.kv_n_heads = kv_n_heads
          self.rope_theta = rope_theta

          # These keys are tolerated and discarded; presumably they are bookkeeping entries found in
          # serialized configs -- TODO confirm against `PretrainedConfig`.
          for k in ("model_type", "attn_implementation", "transformers_version", "_commit_hash", "torch_dtype", "dtype"):
              kwargs.pop(k, None)

          # Whatever is left was not recognised; the name `kwargs` is kept because the message prints it.
          if kwargs:
              raise ValueError(f"Found unknown {kwargs=}")
  class DbrxFFNConfig(PretrainedConfig):
      """Configuration class for [`DbrxFFN`].

      It is used to instantiate feedforward layers according to the specified
      arguments, defining the layers' architecture.

      Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
      documentation from [`PretrainedConfig`] for more information.

      Args:
          ffn_act_fn (`dict`, *optional*, defaults to `None`):
              A dict specifying activation function for the FFN. The dict should have a key 'name' with the value
              being the name of the activation function along with any additional keyword arguments. If `None`,
              then set to `{"name": "silu"}`.
          ffn_hidden_size (`int`, *optional*, defaults to 3584):
              The hidden size of the feedforward network.
          moe_num_experts (`int`, *optional*, defaults to 4):
              The number of experts in the mixture of experts layer.
          moe_top_k (`int`, *optional*, defaults to 1):
              The number of experts to use in the mixture of experts layer.
          moe_jitter_eps (`float`, *optional*, defaults to `None`):
              If not `None`, the jitter epsilon for the mixture of experts layer.
          moe_loss_weight (`float`, *optional*, defaults to 0.01):
              The loss weight for the mixture of experts layer.
          moe_normalize_expert_weights (`float`, *optional*, defaults to 1.0):
              The normalization factor for the expert weights.
          kwargs (`dict[str, Any]`, *optional*):
              Extra keyword arguments. Only `model_type`, `attn_implementation`, `transformers_version`,
              `_commit_hash`, `torch_dtype` and `dtype` are accepted, and they are discarded.

      Raises:
          ValueError: If `kwargs` contains any key other than the accepted ones listed above.
      """

      base_config_key = "ffn_config"

      def __init__(
          self,
          ffn_act_fn: Optional[dict] = None,
          ffn_hidden_size: int = 3584,
          moe_num_experts: int = 4,
          moe_top_k: int = 1,
          moe_jitter_eps: Optional[float] = None,
          moe_loss_weight: float = 0.01,
          moe_normalize_expert_weights: Optional[float] = 1.0,
          **kwargs: Any,
      ):
          # NOTE(review): unlike `DbrxAttentionConfig` in this file, the extra keyword arguments are not
          # forwarded to the base initializer here -- confirm whether that difference is intentional.
          super().__init__()

          # A fresh dict is built per instance so the default is never shared between configs.
          if ffn_act_fn is None:
              ffn_act_fn = {"name": "silu"}
          self.ffn_act_fn = ffn_act_fn
          self.ffn_hidden_size = ffn_hidden_size
          self.moe_num_experts = moe_num_experts
          self.moe_top_k = moe_top_k
          self.moe_jitter_eps = moe_jitter_eps
          self.moe_loss_weight = moe_loss_weight
          self.moe_normalize_expert_weights = moe_normalize_expert_weights

          # These keys are tolerated and discarded; presumably they are bookkeeping entries found in
          # serialized configs -- TODO confirm against `PretrainedConfig`.
          for k in ("model_type", "attn_implementation", "transformers_version", "_commit_hash", "torch_dtype", "dtype"):
              kwargs.pop(k, None)

          # Whatever is left was not recognised; the name `kwargs` is kept because the message prints it.
          if kwargs:
              raise ValueError(f"Found unknown {kwargs=}")
  class DbrxConfig(PretrainedConfig):
      r"""
      This is the configuration class to store the configuration of a [`DbrxModel`]. It is used to instantiate a Dbrx
      model according to the specified arguments, defining the model architecture. Instantiating a configuration with
      the defaults will yield a different configuration to that of the
      [databricks/dbrx-instruct](https://huggingface.co/databricks/dbrx-instruct) architecture.

      Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
      documentation from [`PretrainedConfig`] for more information.

      Args:
          d_model (`int`, *optional*, defaults to 2048):
              Dimensionality of the embeddings and hidden states.
          n_heads (`int`, *optional*, defaults to 16):
              Number of attention heads for each attention layer.
          n_layers (`int`, *optional*, defaults to 24):
              Number of hidden layers.
          max_seq_len (`int`, *optional*, defaults to 2048):
              The maximum sequence length of the model.
          vocab_size (`int`, *optional*, defaults to 32000):
              Vocabulary size of the Dbrx model. Defines the maximum number of different tokens that can be
              represented by the `input_ids` passed when calling [`DbrxModel`].
          resid_pdrop (`float`, *optional*, defaults to 0.0):
              The dropout probability applied to the attention output before combining with residual.
          emb_pdrop (`float`, *optional*, defaults to 0.0):
              The dropout probability for the embedding layer.
          attn_config (`dict` or `DbrxAttentionConfig`, *optional*):
              Configuration of the model's attention module. A `dict` is expanded into a [`DbrxAttentionConfig`];
              `None` selects a default [`DbrxAttentionConfig`]; any other value is stored as given.
          ffn_config (`dict` or `DbrxFFNConfig`, *optional*):
              Configuration of the model's FFN module. A `dict` is expanded into a [`DbrxFFNConfig`]; `None`
              selects a default [`DbrxFFNConfig`]; any other value is stored as given.
          use_cache (`bool`, *optional*, defaults to `True`):
              Whether or not the model should return the last key/values attentions (not used by all models).
          initializer_range (`float`, *optional*, defaults to 0.02):
              The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
          output_router_logits (`bool`, *optional*, defaults to `False`):
              Whether or not the router logits should be returned by the model. Enabling this will also
              allow the model to output the auxiliary loss.
          kwargs (`dict[str, Any]`, *optional*):
              Remaining keyword arguments, forwarded to [`PretrainedConfig`]. A truthy `tie_word_embeddings`
              is rejected with a `ValueError`.

      Example:

      ```python
      >>> from transformers import DbrxConfig, DbrxModel

      >>> # Initializing a Dbrx configuration
      >>> configuration = DbrxConfig(n_layers=2, d_model=256, n_heads=8, vocab_size=128)

      >>> # Initializing a model (with random weights) from the configuration
      >>> model = DbrxModel(configuration)

      >>> # Accessing the model configuration
      >>> configuration = model.config
      ```
      """

      model_type = "dbrx"
      sub_configs = {"attn_config": DbrxAttentionConfig, "ffn_config": DbrxFFNConfig}
      # Generic attribute names mapped onto the DBRX-specific attribute names set in `__init__`.
      attribute_map = {
          "num_attention_heads": "n_heads",
          "hidden_size": "d_model",
          "num_hidden_layers": "n_layers",
          "max_position_embeddings": "max_seq_len",
      }

      def __init__(
          self,
          d_model: int = 2048,
          n_heads: int = 16,
          n_layers: int = 24,
          max_seq_len: int = 2048,
          vocab_size: int = 32000,
          resid_pdrop: float = 0.0,
          emb_pdrop: float = 0.0,
          attn_config: Optional[DbrxAttentionConfig] = None,
          ffn_config: Optional[DbrxFFNConfig] = None,
          use_cache: bool = True,
          initializer_range: float = 0.02,
          output_router_logits: bool = False,
          **kwargs: Any,
      ):
          # Normalise the attention sub-config first: build a default one, expand a dict, or keep the
          # caller's object untouched.
          if attn_config is None:
              attn_config = DbrxAttentionConfig()
          elif isinstance(attn_config, dict):
              attn_config = DbrxAttentionConfig(**attn_config)
          self.attn_config = attn_config

          # Same normalisation for the feedforward sub-config.
          if ffn_config is None:
              ffn_config = DbrxFFNConfig()
          elif isinstance(ffn_config, dict):
              ffn_config = DbrxFFNConfig(**ffn_config)
          self.ffn_config = ffn_config

          self.d_model = d_model
          self.n_heads = n_heads
          self.n_layers = n_layers
          self.max_seq_len = max_seq_len
          self.vocab_size = vocab_size
          self.resid_pdrop = resid_pdrop
          self.emb_pdrop = emb_pdrop
          self.use_cache = use_cache
          self.initializer_range = initializer_range
          self.output_router_logits = output_router_logits
          # Mirror the attention sub-config's KV head count at the top level.
          self.num_key_value_heads = self.attn_config.kv_n_heads

          # Pull the flag out of kwargs so it can be checked, then hand it to the base class explicitly.
          tie_embeddings = kwargs.pop("tie_word_embeddings", False)
          if tie_embeddings:
              raise ValueError("tie_word_embeddings is not supported for DBRX models.")

          super().__init__(tie_word_embeddings=tie_embeddings, **kwargs)
  193. __all__ = ["DbrxConfig"]