configuration_moshi.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. # coding=utf-8
  2. # Copyright 2024 Meta AI and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Moshi model configuration"""
  16. from ...configuration_utils import PretrainedConfig
  17. from ...utils import logging
  18. from ..auto.configuration_auto import AutoConfig
  19. logger = logging.get_logger(__name__)
  20. class MoshiDepthConfig(PretrainedConfig):
  21. r"""
  22. This is the configuration class to store the configuration of a [`MoshiDepthDecoder`]. It is used to instantiate a
  23. Moshi depth decoder model according to the specified arguments, defining the Moshi depth decoder config.
  24. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  25. documentation from [`PretrainedConfig`] for more information.
  26. Args:
  27. vocab_size (`int`, *optional*, defaults to 32000):
  28. Vocabulary size of the MoshiDepthDecoder model. Defines the number of different tokens that can be
  29. represented by the `inputs_ids` passed when calling [`MoshiDepthDecoder`].
  30. hidden_size (`int`, *optional*, defaults to 1024):
  31. Dimensionality of the layers and the pooler layer of the depth decoder.
  32. input_size (`int`, *optional*, defaults to 4096):
  33. Dimensionality of the input hidden states. Used to connect the main decoder to the depth decoder.
  34. num_hidden_layers (`int`, *optional*, defaults to 6):
  35. Number of depth decoder layers.
  36. num_attention_heads (`int`, *optional*, defaults to 16):
  37. Number of attention heads for each attention layer in the depth decoder block.
  38. num_key_value_heads (`int`, *optional*):
  39. This is the number of key_value heads that should be used to implement Grouped Query Attention. If
  40. `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
  41. `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
  42. converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
  43. by meanpooling all the original heads within that group. For more details, check out [this
  44. paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `num_attention_heads`.
  45. audio_vocab_size (`int`, *optional*, defaults to 2048):
  46. Vocabulary size of the audio part of model. Defines the number of different tokens that can be
  47. represented by the `audio_codes` passed when calling the Moshi models.
  48. max_position_embeddings (`int`, *optional*, defaults to 9):
  49. The maximum sequence length that this model might ever be used with. Typically, set this to something large
  50. just in case (e.g., 512 or 1024 or 2048).
  51. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
  52. The non-linear activation function (function or string) in the depth decoder.
  53. head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
  54. The attention head dimension.
  55. initializer_range (`float`, *optional*, defaults to 0.02):
  56. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  57. use_cache (`bool`, *optional*, defaults to `True`):
  58. Whether or not the model should return the last key/values attentions (not used by all models). Only
  59. relevant if `config.is_decoder=True`.
  60. sliding_window (`int`, *optional*, defaults to 8):
  61. Sliding window attention window size. If not specified, will default to `8`.
  62. attention_dropout (`float`, *optional*, defaults to 0.0):
  63. The dropout ratio for the attention probabilities.
  64. ffn_dim (`int`, *optional*, defaults to 5632):
  65. Dimensionality of the "intermediate" (often named feed-forward) layer in the depth decoder block. Must be even.
  66. rms_norm_eps (`float`, *optional*, defaults to 1e-08):
  67. The epsilon used by the rms normalization layers.
  68. num_codebooks (`int`, *optional*, defaults to 8):
  69. The number of audio codebooks for each audio channels.
  70. tie_word_embeddings (`bool`, *optional*, defaults to `False`):
  71. Whether to tie weight embeddings
  72. kwargs (*optional*):
  73. Dictionary of keyword arguments. Notably:
  74. - **audio_encoder_config** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
  75. defines the audio encoder config.
  76. Example:
  77. ```python
  78. >>> from transformers import (
  79. ... MoshiDepthConfig,
  80. ... MoshiDepthDecoder,
  81. ... )
  82. >>> configuration = MoshiDepthConfig()
  83. >>> # Initializing a MoshiDepthDecoder (with random weights) from the kmhf/hf-moshiko style configuration
  84. >>> model = MoshiDepthDecoder(configuration)
  85. >>> # Accessing the model configuration
  86. >>> configuration = model.config
  87. ```"""
  88. model_type = "moshi_depth"
  89. keys_to_ignore_at_inference = ["past_key_values"]
  90. def __init__(
  91. self,
  92. vocab_size=32000,
  93. hidden_size=1024,
  94. input_size=4096,
  95. num_hidden_layers=6,
  96. num_attention_heads=16,
  97. num_key_value_heads=None,
  98. audio_vocab_size=2048,
  99. max_position_embeddings=9,
  100. hidden_act="silu",
  101. head_dim=None,
  102. initializer_range=0.02,
  103. use_cache=True,
  104. sliding_window=8,
  105. attention_dropout=0.0,
  106. ffn_dim=5632,
  107. rms_norm_eps=1e-8,
  108. num_codebooks=8,
  109. tie_word_embeddings=False,
  110. **kwargs,
  111. ):
  112. self.vocab_size = vocab_size
  113. self.hidden_size = hidden_size
  114. self.input_size = input_size
  115. self.num_hidden_layers = num_hidden_layers
  116. self.num_attention_heads = num_attention_heads
  117. self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
  118. self.max_position_embeddings = max_position_embeddings
  119. self.hidden_act = hidden_act
  120. self.head_dim = head_dim or hidden_size // num_attention_heads
  121. self.initializer_range = initializer_range
  122. self.use_cache = use_cache
  123. self.sliding_window = sliding_window
  124. self.attention_dropout = attention_dropout
  125. if ffn_dim % 2 == 1:
  126. raise ValueError(f"`ffn_dim={ffn_dim}` must be even.")
  127. self.ffn_dim = ffn_dim
  128. self.rms_norm_eps = rms_norm_eps
  129. self.num_codebooks = num_codebooks
  130. self.audio_vocab_size = audio_vocab_size
  131. super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
  132. class MoshiConfig(PretrainedConfig):
  133. r"""
  134. This is the configuration class to store the configuration of a [`MoshiModel`]. It is used to instantiate a
  135. Moshi model according to the specified arguments, defining the audio encoder, Moshi depth decoder and Moshi decoder
  136. configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the Moshiko model,
  137. e.g. [kmhf/hf-moshiko](https://huggingface.co/kmhf/hf-moshiko)
  138. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  139. documentation from [`PretrainedConfig`] for more information.
  140. Args:
  141. vocab_size (`int`, *optional*, defaults to 32000):
  142. Vocabulary size of the MoshiDecoder model. Defines the number of different tokens that can be
  143. represented by the `inputs_ids` passed when calling [`MoshiDecoder`].
  144. hidden_size (`int`, *optional*, defaults to 4096):
  145. Dimensionality of the layers and the pooler layer of the main decoder.
  146. num_hidden_layers (`int`, *optional*, defaults to 32):
  147. Number of decoder layers.
  148. num_attention_heads (`int`, *optional*, defaults to 32):
  149. Number of attention heads for each attention layer in the main decoder block.
  150. num_key_value_heads (`int`, *optional*):
  151. This is the number of key_value heads that should be used to implement Grouped Query Attention. If
  152. `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
  153. `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
  154. converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
  155. by meanpooling all the original heads within that group. For more details, check out [this
  156. paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `num_attention_heads`.
  157. audio_vocab_size (`int`, *optional*):
  158. Vocabulary size of the audio part of model. Defines the number of different tokens that can be
  159. represented by the `audio_codes` passed when calling the Moshi models.
  160. max_position_embeddings (`int`, *optional*, defaults to 3000):
  161. The maximum sequence length that this model might ever be used with. Typically, set this to something large
  162. just in case (e.g., 512 or 1024 or 2048).
  163. rope_theta (`float`, *optional*, defaults to 10000.0):
  164. The base period of the RoPE embeddings.
  165. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
  166. The non-linear activation function (function or string) in the decoder.
  167. head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
  168. The attention head dimension.
  169. initializer_range (`float`, *optional*, defaults to 0.02):
  170. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  171. use_cache (`bool`, *optional*, defaults to `True`):
  172. Whether or not the model should return the last key/values attentions (not used by all models). Only
  173. relevant if `config.is_decoder=True`.
  174. sliding_window (`int`, *optional*, defaults to 3000):
  175. Sliding window attention window size. If not specified, will default to `3000`.
  176. attention_dropout (`float`, *optional*, defaults to 0.0):
  177. The dropout ratio for the attention probabilities.
  178. ffn_dim (`int`, *optional*, defaults to 22528):
  179. Dimensionality of the "intermediate" (often named feed-forward) layer in the main decoder block. Must be even.
  180. rms_norm_eps (`float`, *optional*, defaults to 1e-08):
  181. The epsilon used by the rms normalization layers.
  182. num_codebooks (`int`, *optional*, defaults to 8):
  183. The number of audio codebooks for each audio channels.
  184. tie_word_embeddings (`bool`, *optional*, defaults to `False`):
  185. Whether to tie weight embeddings
  186. kwargs (*optional*):
  187. Dictionary of keyword arguments. Notably:
  188. - **audio_encoder_config** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
  189. defines the audio encoder config.
  190. - **depth__config** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that
  191. defines the depth decoder config.
  192. Example:
  193. ```python
  194. >>> from transformers import (
  195. ... MoshiConfig,
  196. ... MoshiForConditionalGeneration,
  197. ... )
  198. >>> configuration = MoshiConfig()
  199. >>> # Initializing a MoshiForConditionalGeneration (with random weights) from the kmhf/hf-moshiko style configuration
  200. >>> model = MoshiForConditionalGeneration(configuration)
  201. >>> # Accessing the model configuration
  202. >>> configuration = model.config
  203. >>> # Saving the model, including its configuration
  204. >>> model.save_pretrained("kmhf/hf-moshiko")
  205. >>> # loading model and config from pretrained folder
  206. >>> moshi_config = MoshiConfig.from_pretrained("kmhf/hf-moshiko")
  207. >>> model = MoshiForConditionalGeneration.from_pretrained("kmhf/hf-moshiko", config=moshi_config)
  208. ```"""
  209. model_type = "moshi"
  210. keys_to_ignore_at_inference = ["past_key_values"]
  211. sub_configs = {"audio_encoder_config": AutoConfig, "depth_decoder_config": MoshiDepthConfig}
  212. def __init__(
  213. self,
  214. vocab_size=32000,
  215. hidden_size=4096,
  216. num_hidden_layers=32,
  217. num_attention_heads=32,
  218. num_key_value_heads=None,
  219. audio_vocab_size=None,
  220. max_position_embeddings=3000,
  221. rope_theta=10000.0,
  222. hidden_act="silu",
  223. head_dim=None,
  224. initializer_range=0.02,
  225. use_cache=True,
  226. sliding_window=3000,
  227. attention_dropout=0.0,
  228. ffn_dim=22528,
  229. rms_norm_eps=1e-8,
  230. num_codebooks=8,
  231. tie_word_embeddings=False,
  232. **kwargs,
  233. ):
  234. self.vocab_size = vocab_size
  235. self.hidden_size = hidden_size
  236. self.num_hidden_layers = num_hidden_layers
  237. self.num_attention_heads = num_attention_heads
  238. self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
  239. self.max_position_embeddings = max_position_embeddings
  240. self.rope_theta = rope_theta
  241. self.hidden_act = hidden_act
  242. self.head_dim = head_dim or hidden_size // num_attention_heads
  243. self.initializer_range = initializer_range
  244. self.use_cache = use_cache
  245. self.sliding_window = sliding_window
  246. self.attention_dropout = attention_dropout
  247. if ffn_dim % 2 == 1:
  248. raise ValueError(f"`ffn_dim={ffn_dim}` must be even.")
  249. self.ffn_dim = ffn_dim
  250. self.rms_norm_eps = rms_norm_eps
  251. self.num_codebooks = num_codebooks
  252. audio_encoder_config = kwargs.pop("audio_encoder_config", {})
  253. audio_encoder_model_type = audio_encoder_config.pop("model_type", "mimi")
  254. self.audio_encoder_config = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config)
  255. if self.num_codebooks > self.audio_encoder_config.num_codebooks:
  256. raise ValueError(
  257. f"`num_codebooks={num_codebooks}` is greater than the maximum number of codebooks that the audio encoder can deal with ({self.audio_encoder_config.num_codebooks}). Please lower it."
  258. )
  259. self.audio_vocab_size = (
  260. self.audio_encoder_config.codebook_size if audio_vocab_size is None else audio_vocab_size
  261. )
  262. depth_decoder_config = kwargs.pop("depth_decoder_config", {})
  263. depth_decoder_config.update(
  264. {
  265. "audio_vocab_size": self.audio_vocab_size,
  266. "input_size": hidden_size,
  267. "vocab_size": vocab_size,
  268. "num_codebooks": num_codebooks,
  269. }
  270. )
  271. self.depth_decoder_config = MoshiDepthConfig(**depth_decoder_config)
  272. super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
  273. @property
  274. def sampling_rate(self):
  275. return self.audio_encoder_config.sampling_rate
  276. @classmethod
  277. def from_audio_encoder_config(
  278. cls,
  279. audio_encoder_config: PretrainedConfig,
  280. **kwargs,
  281. ):
  282. r"""
  283. Instantiate a [`MoshiConfig`] (or a derived class) from an audio encoder configuration.
  284. Returns:
  285. [`MoshiConfig`]: An instance of a configuration object
  286. """
  287. return cls(
  288. audio_encoder_config=audio_encoder_config.to_dict(),
  289. **kwargs,
  290. )
  291. __all__ = ["MoshiConfig", "MoshiDepthConfig"]