configuration_dia.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376
  1. # coding=utf-8
  2. # Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Dia model configuration"""
  16. from typing import Optional
  17. from ...configuration_utils import PretrainedConfig
  18. from ...modeling_rope_utils import rope_config_validation
  19. from ...utils import logging
# Module-level logger, created through the library's own `logging` utilities and keyed on this module's name.
logger = logging.get_logger(__name__)
  21. class DiaEncoderConfig(PretrainedConfig):
  22. r"""
  23. This is the configuration class to store the configuration of a [`DiaEncoder`]. It is used to instantiate a Dia
  24. encoder according to the specified arguments, defining the encoder architecture.
  25. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  26. documentation from [`PretrainedConfig`] for more information.
  27. Args:
  28. max_position_embeddings (`int`, *optional*, defaults to 1024):
  29. The maximum sequence length that this model might ever be used with.
  30. num_hidden_layers (`int`, *optional*, defaults to 12):
  31. Number of hidden layers in the Transformer encoder.
  32. hidden_size (`int`, *optional*, defaults to 1024):
  33. Dimensionality of the encoder layers and the pooler layer.
  34. num_attention_heads (`int`, *optional*, defaults to 16):
  35. Number of attention heads for each attention layer in the Transformer encoder.
  36. num_key_value_heads (`int`, *optional*, defaults to 16):
  37. Number of key and value heads for each attention layer in the Transformer encoder.
  38. head_dim (`int`, *optional*, defaults to 128):
  39. Dimensionality of the attention head.
  40. intermediate_size (`int`, *optional*, defaults to 4096):
  41. Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
  42. norm_eps (`float`, *optional*, defaults to 1e-05):
  43. The epsilon used by the normalization layers.
  44. vocab_size (`int`, *optional*, defaults to 256):
  45. Vocabulary size of the Dia model. Defines the number of different tokens that can be represented by the
  46. `inputs_ids` passed when calling [`DiaModel`].
  47. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
  48. The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
  49. `"relu"`, `"swish"` and `"gelu_new"` are supported.
  50. rope_theta (`float`, *optional*, defaults to 10000.0):
  51. The base period of the RoPE embeddings.
  52. rope_scaling (`dict`, *optional*):
  53. Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
  54. and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
  55. accordingly.
  56. Expected contents:
  57. `rope_type` (`str`):
  58. The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
  59. 'llama3'], with 'default' being the original RoPE implementation.
  60. `factor` (`float`, *optional*):
  61. Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
  62. most scaling types, a `factor` of x will enable the model to handle sequences of length x *
  63. original maximum pre-trained length.
  64. `original_max_position_embeddings` (`int`, *optional*):
  65. Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
  66. pretraining.
  67. `attention_factor` (`float`, *optional*):
  68. Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
  69. computation. If unspecified, it defaults to value recommended by the implementation, using the
  70. `factor` field to infer the suggested value.
  71. `beta_fast` (`float`, *optional*):
  72. Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
  73. ramp function. If unspecified, it defaults to 32.
  74. `beta_slow` (`float`, *optional*):
  75. Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
  76. ramp function. If unspecified, it defaults to 1.
  77. `short_factor` (`List[float]`, *optional*):
  78. Only used with 'longrope'. The scaling factor to be applied to short contexts (<
  79. `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
  80. size divided by the number of attention heads divided by 2
  81. `long_factor` (`List[float]`, *optional*):
  82. Only used with 'longrope'. The scaling factor to be applied to long contexts (<
  83. `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
  84. size divided by the number of attention heads divided by 2
  85. `low_freq_factor` (`float`, *optional*):
  86. Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
  87. `high_freq_factor` (`float`, *optional*):
  88. Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
  89. initializer_range (`float`, *optional*, defaults to 0.02):
  90. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  91. """
  92. model_type = "dia_encoder"
  93. def __init__(
  94. self,
  95. max_position_embeddings: int = 1024,
  96. num_hidden_layers: int = 12,
  97. hidden_size: int = 1024,
  98. num_attention_heads: int = 16,
  99. num_key_value_heads: int = 16,
  100. head_dim: int = 128,
  101. intermediate_size: int = 4096,
  102. norm_eps: float = 1e-5,
  103. vocab_size: int = 256,
  104. hidden_act: str = "silu",
  105. rope_theta: float = 10000.0,
  106. rope_scaling: Optional[dict] = None,
  107. initializer_range: float = 0.02,
  108. **kwargs,
  109. ):
  110. self.max_position_embeddings = max_position_embeddings
  111. self.num_hidden_layers = num_hidden_layers
  112. self.hidden_size = hidden_size
  113. self.intermediate_size = intermediate_size
  114. self.num_attention_heads = num_attention_heads
  115. self.head_dim = head_dim
  116. self.norm_eps = norm_eps
  117. self.vocab_size = vocab_size
  118. self.num_key_value_heads = num_key_value_heads
  119. self.hidden_act = hidden_act
  120. self.rope_theta = rope_theta
  121. self.rope_scaling = rope_scaling
  122. # Validate the correctness of rotary position embeddings parameters
  123. # BC: if there is a 'type' field, copy it it to 'rope_type'.
  124. if self.rope_scaling is not None and "type" in self.rope_scaling:
  125. self.rope_scaling["rope_type"] = self.rope_scaling["type"]
  126. rope_config_validation(self)
  127. self.initializer_range = initializer_range
  128. super().__init__(**kwargs)
  129. class DiaDecoderConfig(PretrainedConfig):
  130. r"""
  131. This is the configuration class to store the configuration of a [`DiaDecoder`]. It is used to instantiate a Dia
  132. decoder according to the specified arguments, defining the decoder architecture.
  133. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  134. documentation from [`PretrainedConfig`] for more information.
  135. Args:
  136. max_position_embeddings (`int`, *optional*, defaults to 3072):
  137. The maximum sequence length that this model might ever be used with.
  138. num_hidden_layers (`int`, *optional*, defaults to 18):
  139. Number of hidden layers in the Transformer decoder.
  140. hidden_size (`int`, *optional*, defaults to 2048):
  141. Dimensionality of the decoder layers and the pooler layer.
  142. intermediate_size (`int`, *optional*, defaults to 8192):
  143. Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer decoder.
  144. num_attention_heads (`int`, *optional*, defaults to 16):
  145. Number of attention heads for each attention layer in the Transformer decoder.
  146. num_key_value_heads (`int`, *optional*, defaults to 4):
  147. Number of key and value heads for each attention layer in the Transformer decoder.
  148. head_dim (`int`, *optional*, defaults to 128):
  149. Dimensionality of the attention head.
  150. cross_num_attention_heads (`int`, *optional*, defaults to 16):
  151. Number of attention heads for each cross-attention layer in the Transformer decoder.
  152. cross_head_dim (`int`, *optional*, defaults to 128):
  153. Dimensionality of the cross-attention head.
  154. cross_num_key_value_heads (`int`, *optional*, defaults to 16):
  155. Number of key and value heads for each cross-attention layer in the Transformer decoder.
  156. cross_hidden_size (`int`, *optional*, defaults to 1024):
  157. Dimensionality of the cross-attention layers.
  158. norm_eps (`float`, *optional*, defaults to 1e-05):
  159. The epsilon used by the normalization layers.
  160. vocab_size (`int`, *optional*, defaults to 1028):
  161. Vocabulary size of the Dia model. Defines the number of different tokens that can be represented by the
  162. `inputs_ids` passed when calling [`DiaModel`].
  163. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
  164. The non-linear activation function (function or string) in the decoder. If string, `"gelu"`, `"relu"`,
  165. `"swish"` and `"gelu_new"` are supported.
  166. num_channels (`int`, *optional*, defaults to 9):
  167. Number of channels for the Dia decoder.
  168. rope_theta (`float`, *optional*, defaults to 10000.0):
  169. The base period of the RoPE embeddings.
  170. rope_scaling (`dict`, *optional*):
  171. Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
  172. and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
  173. accordingly.
  174. Expected contents:
  175. `rope_type` (`str`):
  176. The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
  177. 'llama3'], with 'default' being the original RoPE implementation.
  178. `factor` (`float`, *optional*):
  179. Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
  180. most scaling types, a `factor` of x will enable the model to handle sequences of length x *
  181. original maximum pre-trained length.
  182. `original_max_position_embeddings` (`int`, *optional*):
  183. Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
  184. pretraining.
  185. `attention_factor` (`float`, *optional*):
  186. Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
  187. computation. If unspecified, it defaults to value recommended by the implementation, using the
  188. `factor` field to infer the suggested value.
  189. `beta_fast` (`float`, *optional*):
  190. Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
  191. ramp function. If unspecified, it defaults to 32.
  192. `beta_slow` (`float`, *optional*):
  193. Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
  194. ramp function. If unspecified, it defaults to 1.
  195. `short_factor` (`List[float]`, *optional*):
  196. Only used with 'longrope'. The scaling factor to be applied to short contexts (<
  197. `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
  198. size divided by the number of attention heads divided by 2
  199. `long_factor` (`List[float]`, *optional*):
  200. Only used with 'longrope'. The scaling factor to be applied to long contexts (<
  201. `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
  202. size divided by the number of attention heads divided by 2
  203. `low_freq_factor` (`float`, *optional*):
  204. Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
  205. `high_freq_factor` (`float`, *optional*):
  206. Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
  207. initializer_range (`float`, *optional*, defaults to 0.02):
  208. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  209. use_cache (`bool`, *optional*, defaults to `True`):
  210. Whether or not the model should return the last key/values attentions (not used by all models).
  211. is_encoder_decoder (`bool`, *optional*, defaults to `True`):
  212. Indicating that this model is part of an encoder-decoder architecture.
  213. """
  214. model_type = "dia_decoder"
  215. def __init__(
  216. self,
  217. max_position_embeddings: int = 3072,
  218. num_hidden_layers: int = 18,
  219. hidden_size: int = 2048,
  220. intermediate_size: int = 8192,
  221. num_attention_heads: int = 16,
  222. num_key_value_heads: int = 4,
  223. head_dim: int = 128,
  224. cross_num_attention_heads: int = 16,
  225. cross_head_dim: int = 128,
  226. cross_num_key_value_heads: int = 16,
  227. cross_hidden_size: int = 1024,
  228. norm_eps: float = 1e-5,
  229. vocab_size: int = 1028,
  230. hidden_act: str = "silu",
  231. num_channels: int = 9,
  232. rope_theta: float = 10000.0,
  233. rope_scaling: Optional[dict] = None,
  234. initializer_range: float = 0.02,
  235. use_cache: bool = True,
  236. is_encoder_decoder: bool = True,
  237. **kwargs,
  238. ):
  239. self.max_position_embeddings = max_position_embeddings
  240. self.num_hidden_layers = num_hidden_layers
  241. self.hidden_size = hidden_size
  242. self.intermediate_size = intermediate_size
  243. self.num_attention_heads = num_attention_heads
  244. self.num_key_value_heads = num_key_value_heads
  245. self.head_dim = head_dim
  246. self.cross_num_key_value_heads = cross_num_key_value_heads
  247. self.cross_num_attention_heads = cross_num_attention_heads
  248. self.cross_head_dim = cross_head_dim
  249. self.cross_hidden_size = cross_hidden_size
  250. self.norm_eps = norm_eps
  251. self.vocab_size = vocab_size
  252. self.hidden_act = hidden_act
  253. self.num_channels = num_channels
  254. self.rope_theta = rope_theta
  255. self.rope_scaling = rope_scaling
  256. # Validate the correctness of rotary position embeddings parameters
  257. # BC: if there is a 'type' field, copy it it to 'rope_type'.
  258. if self.rope_scaling is not None and "type" in self.rope_scaling:
  259. self.rope_scaling["rope_type"] = self.rope_scaling["type"]
  260. rope_config_validation(self)
  261. self.initializer_range = initializer_range
  262. self.use_cache = use_cache
  263. super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
  264. class DiaConfig(PretrainedConfig):
  265. r"""
  266. This is the configuration class to store the configuration of a [`DiaModel`]. It is used to instantiate a
  267. Dia model according to the specified arguments, defining the model architecture. Instantiating a configuration
  268. with the defaults will yield a similar configuration to that of the
  269. [nari-labs/Dia-1.6B](https://huggingface.co/nari-labs/Dia-1.6B) architecture.
  270. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  271. documentation from [`PretrainedConfig`] for more information.
  272. Args:
  273. encoder_config (`DiaEncoderConfig`, *optional*):
  274. Configuration for the encoder part of the model. If not provided, a default `DiaEncoderConfig` will be used.
  275. decoder_config (`DiaDecoderConfig`, *optional*):
  276. Configuration for the decoder part of the model. If not provided, a default `DiaDecoderConfig` will be used.
  277. norm_eps (`float`, *optional*, defaults to 1e-05):
  278. The epsilon used by the normalization layers.
  279. is_encoder_decoder (`bool`, *optional*, defaults to `True`):
  280. Indicating that this model uses an encoder-decoder architecture.
  281. pad_token_id (`int`, *optional*, defaults to 1025):
  282. Padding token id.
  283. eos_token_id (`int`, *optional*, defaults to 1024):
  284. End of stream token id.
  285. bos_token_id (`int`, *optional*, defaults to 1026):
  286. Beginning of stream token id.
  287. delay_pattern (`list[int]`, *optional*, defaults to `[0, 8, 9, 10, 11, 12, 13, 14, 15]`):
  288. The delay pattern for the decoder. The length of this list must match `decoder_config.num_channels`.
  289. initializer_range (`float`, *optional*, defaults to 0.02):
  290. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  291. use_cache (`bool`, *optional*, defaults to `True`):
  292. Whether or not the model should return the last key/values attentions (not used by all models).
  293. Example:
  294. ```python
  295. >>> from transformers import DiaConfig, DiaModel
  296. >>> # Initializing a DiaConfig with default values
  297. >>> configuration = DiaConfig()
  298. >>> # Initializing a DiaModel (with random weights) from the configuration
  299. >>> model = DiaModel(configuration)
  300. >>> # Accessing the model configuration
  301. >>> configuration = model.config
  302. ```
  303. """
  304. model_type = "dia"
  305. keys_to_ignore_at_inference = ["past_key_values"]
  306. sub_configs = {"encoder_config": DiaEncoderConfig, "decoder_config": DiaDecoderConfig}
  307. def __init__(
  308. self,
  309. encoder_config: Optional[DiaEncoderConfig] = None,
  310. decoder_config: Optional[DiaDecoderConfig] = None,
  311. norm_eps: float = 1e-5,
  312. is_encoder_decoder: bool = True,
  313. pad_token_id: int = 1025,
  314. eos_token_id: int = 1024,
  315. bos_token_id: int = 1026,
  316. delay_pattern: Optional[list[int]] = None,
  317. initializer_range: float = 0.02,
  318. use_cache: bool = True,
  319. **kwargs,
  320. ):
  321. if isinstance(encoder_config, dict):
  322. encoder_config = DiaEncoderConfig(**encoder_config)
  323. if isinstance(decoder_config, dict):
  324. decoder_config = DiaDecoderConfig(**decoder_config)
  325. self.encoder_config = encoder_config if encoder_config is not None else DiaEncoderConfig()
  326. self.decoder_config = decoder_config if decoder_config is not None else DiaDecoderConfig()
  327. self.norm_eps = norm_eps
  328. self.delay_pattern = delay_pattern if delay_pattern is not None else [0, 8, 9, 10, 11, 12, 13, 14, 15]
  329. self.initializer_range = initializer_range
  330. self.use_cache = use_cache
  331. assert self.decoder_config.num_channels == len(self.delay_pattern), (
  332. "Number of channels must match delay pattern length."
  333. )
  334. super().__init__(
  335. pad_token_id=pad_token_id,
  336. eos_token_id=eos_token_id,
  337. bos_token_id=bos_token_id,
  338. is_encoder_decoder=is_encoder_decoder,
  339. **kwargs,
  340. )
  341. def get_text_config(self, *args, **kwargs):
  342. """Defaulting to audio config as it's the decoder in this case which is usually the text backbone"""
  343. return self.decoder_config
# Names exported by `from <this module> import *`: the three configuration classes defined in this file.
__all__ = ["DiaConfig", "DiaEncoderConfig", "DiaDecoderConfig"]