configuration_mimi.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. # coding=utf-8
  2. # Copyright 2024 Meta Platforms, Inc. and affiliates, and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Mimi model configuration"""
  16. import math
  17. import numpy as np
  18. from ...configuration_utils import PretrainedConfig
  19. from ...utils import logging
# Module-level logger obtained through the package's own `logging` utilities.
logger = logging.get_logger(__name__)
  21. class MimiConfig(PretrainedConfig):
  22. r"""
  23. This is the configuration class to store the configuration of an [`MimiModel`]. It is used to instantiate a
  24. Mimi model according to the specified arguments, defining the model architecture. Instantiating a configuration
  25. with the defaults will yield a similar configuration to that of the
  26. [kyutai/mimi](https://huggingface.co/kyutai/mimi) architecture.
  27. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  28. documentation from [`PretrainedConfig`] for more information.
  29. Args:
  30. sampling_rate (`int`, *optional*, defaults to 24000):
  31. The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz).
  32. frame_rate (`float`, *optional*):
  33. Should be computed from the other parameters, yet kept for backward compatibility.
  34. audio_channels (`int`, *optional*, defaults to 1):
  35. Number of channels in the audio data. Either 1 for mono or 2 for stereo.
  36. hidden_size (`int`, *optional*, defaults to 512):
  37. Intermediate representation dimension.
  38. num_filters (`int`, *optional*, defaults to 64):
  39. Number of convolution kernels of first `MimiConv1d` down sampling layer.
  40. num_residual_layers (`int`, *optional*, defaults to 1):
  41. Number of residual layers.
  42. upsampling_ratios (`Sequence[int]`, *optional*):
  43. Kernel size and stride ratios. The encoder uses downsampling ratios instead of upsampling ratios, hence it
  44. will use the ratios in the reverse order to the ones specified here that must match the decoder order.
  45. If not specified, will defaults to `[8, 6, 5, 4]`
  46. kernel_size (`int`, *optional*, defaults to 7):
  47. Kernel size for the initial convolution.
  48. last_kernel_size (`int`, *optional*, defaults to 3):
  49. Kernel size for the last convolution layer.
  50. residual_kernel_size (`int`, *optional*, defaults to 3):
  51. Kernel size for the residual layers.
  52. dilation_growth_rate (`int`, *optional*, defaults to 2):
  53. How much to increase the dilation with each layer.
  54. use_causal_conv (`bool`, *optional*, defaults to `True`):
  55. Whether to use fully causal convolution.
  56. pad_mode (`str`, *optional*, defaults to `"constant"`):
  57. Padding mode for the convolutions.
  58. compress (`int`, *optional*, defaults to 2):
  59. Reduced dimensionality in residual branches.
  60. trim_right_ratio (`float`, *optional*, defaults to 1.0):
  61. Ratio for trimming at the right of the transposed convolution under the `use_causal_conv = True` setup. If
  62. equal to 1.0, it means that all the trimming is done at the right.
  63. codebook_size (`int`, *optional*, defaults to 2048):
  64. Number of discret codes in each codebooks.
  65. codebook_dim (`int`, *optional*, defaults to 256):
  66. Dimension of the unquantized codebook vectors. If not defined, uses `hidden_size`.
  67. num_quantizers (`int`, *optional*, defaults to 32):
  68. Number of quantizer channels, or codebooks, in the quantizer.
  69. use_conv_shortcut (`bool`, *optional*, defaults to `False`):
  70. Whether to use a convolutional layer as the 'skip' connection in the `MimiResnetBlock` block. If False,
  71. an identity function will be used, giving a generic residual connection.
  72. vector_quantization_hidden_dimension (`int`, *optional*, defaults to 256):
  73. Intermediate representation dimension in the residual vector quantization space.
  74. num_semantic_quantizers (`int`, *optional*, defaults to 1):
  75. Number of semantic quantizer channels, or codebooks, in the semantic quantizer. Must be lower than `num_quantizers`.
  76. upsample_groups (`int`, *optional*, defaults to 512):
  77. If `frame_rate!=encodec_frame_rate`, indicates the number of groups used in the upsampling operation to go from one rate to another.
  78. num_hidden_layers (`int`, *optional*, defaults to 8):
  79. Number of hidden layers in the Transformer models.
  80. intermediate_size (`int`, *optional*, defaults to 2048):
  81. Dimension of the MLP representations.
  82. num_attention_heads (`int`, *optional*, defaults to 8):
  83. Number of attention heads for each attention layer in the Transformer encoder.
  84. num_key_value_heads (`int`, *optional*, defaults to 8):
  85. This is the number of key_value heads that should be used to implement Grouped Query Attention. If
  86. `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
  87. `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
  88. converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
  89. by meanpooling all the original heads within that group. For more details, check out [this
  90. paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`.
  91. head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
  92. The attention head dimension.
  93. hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
  94. The non-linear activation function (function or string) in the decoder.
  95. max_position_embeddings (`int`, *optional*, defaults to 8000):
  96. The maximum sequence length that this model might ever be used with. Mimi's sliding window attention
  97. allows sequence of up to 8000 tokens.
  98. initializer_range (`float`, *optional*, defaults to 0.02):
  99. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  100. norm_eps (`float`, *optional*, defaults to 1e-05):
  101. The epsilon used by the LayerNorm normalization layers.
  102. use_cache (`bool`, *optional*, defaults to `False`):
  103. Whether or not the model should return the last key/values attentions (not used by all models). Only
  104. relevant if `config.is_decoder=True`.
  105. use_streaming (`bool`, *optional*, defaults to `False`):
  106. Whether to use streaming mode. If `True`, the model encode method will return the padding cache that can be used in a subsequent call to the encode method.
  107. rope_theta (`float`, *optional*, defaults to 10000.0):
  108. The base period of the RoPE embeddings.
  109. sliding_window (`int`, *optional*, defaults to 250):
  110. Sliding window attention window size. If not specified, will default to `250`.
  111. attention_dropout (`float`, *optional*, defaults to 0.0):
  112. The dropout ratio for the attention probabilities.
  113. layer_scale_initial_scale (`float`, *optional*, defaults to 0.01):
  114. Initial scale of the residual rescaling operation done in the Transformer models.
  115. attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
  116. Whether to use a bias in the query, key, value and output projection layers during self-attention.
  117. Example:
  118. ```python
  119. >>> from transformers import MimiModel, MimiConfig
  120. >>> # Initializing a "kyutai/mimi" style configuration
  121. >>> configuration = MimiConfig()
  122. >>> # Initializing a model (with random weights) from the "kyutai/mimi" style configuration
  123. >>> model = MimiModel(configuration)
  124. >>> # Accessing the model configuration
  125. >>> configuration = model.config
  126. ```"""
  127. model_type = "mimi"
  128. def __init__(
  129. self,
  130. sampling_rate=24_000,
  131. frame_rate=None,
  132. audio_channels=1,
  133. hidden_size=512,
  134. num_filters=64,
  135. num_residual_layers=1,
  136. upsampling_ratios=None,
  137. kernel_size=7,
  138. last_kernel_size=3,
  139. residual_kernel_size=3,
  140. dilation_growth_rate=2,
  141. use_causal_conv=True,
  142. pad_mode="constant",
  143. compress=2,
  144. trim_right_ratio=1.0,
  145. codebook_size=2048,
  146. codebook_dim=256,
  147. num_quantizers=32,
  148. use_conv_shortcut=False,
  149. vector_quantization_hidden_dimension=256,
  150. num_semantic_quantizers=1,
  151. upsample_groups=512,
  152. num_hidden_layers=8,
  153. intermediate_size=2048,
  154. num_attention_heads=8,
  155. num_key_value_heads=8,
  156. head_dim=None,
  157. hidden_act="gelu",
  158. max_position_embeddings=8000,
  159. initializer_range=0.02,
  160. norm_eps=1e-5,
  161. use_cache=False,
  162. use_streaming=False,
  163. rope_theta=10000.0,
  164. sliding_window=250,
  165. attention_dropout=0.0,
  166. layer_scale_initial_scale=0.01,
  167. attention_bias=False,
  168. **kwargs,
  169. ):
  170. self.sampling_rate = sampling_rate
  171. self.audio_channels = audio_channels
  172. self.hidden_size = hidden_size
  173. self.num_filters = num_filters
  174. self.num_residual_layers = num_residual_layers
  175. self.upsampling_ratios = upsampling_ratios if upsampling_ratios else [8, 6, 5, 4]
  176. self.kernel_size = kernel_size
  177. self.last_kernel_size = last_kernel_size
  178. self.residual_kernel_size = residual_kernel_size
  179. self.dilation_growth_rate = dilation_growth_rate
  180. self.use_causal_conv = use_causal_conv
  181. self.pad_mode = pad_mode
  182. self.compress = compress
  183. self.trim_right_ratio = trim_right_ratio
  184. self.codebook_size = codebook_size
  185. self.codebook_dim = codebook_dim if codebook_dim is not None else hidden_size
  186. self.num_quantizers = num_quantizers
  187. self.use_conv_shortcut = use_conv_shortcut
  188. self.vector_quantization_hidden_dimension = vector_quantization_hidden_dimension
  189. self.upsample_groups = upsample_groups
  190. self.num_hidden_layers = num_hidden_layers
  191. self.intermediate_size = intermediate_size
  192. self.num_attention_heads = num_attention_heads
  193. self.num_key_value_heads = num_key_value_heads
  194. self.hidden_act = hidden_act
  195. self.max_position_embeddings = max_position_embeddings
  196. self.initializer_range = initializer_range
  197. self.norm_eps = norm_eps
  198. self.use_cache = use_cache
  199. self.use_streaming = use_streaming
  200. self.rope_theta = rope_theta
  201. self.sliding_window = sliding_window
  202. self.attention_dropout = attention_dropout
  203. self.head_dim = head_dim or hidden_size // num_attention_heads
  204. self.layer_scale_initial_scale = layer_scale_initial_scale
  205. self.attention_bias = attention_bias
  206. # Handle backward compatibility for frame_rate:
  207. # If frame_rate is explicitly provided, use it (backward compatibility)
  208. # Otherwise, compute it from other parameters (correctly)
  209. if frame_rate is not None:
  210. self._frame_rate = frame_rate
  211. else:
  212. self._frame_rate = None
  213. if num_semantic_quantizers >= self.num_quantizers:
  214. raise ValueError(
  215. f"The number of semantic quantizers should be lower than the total number of quantizers {self.num_quantizers}, but is currently {num_semantic_quantizers}."
  216. )
  217. self.num_semantic_quantizers = num_semantic_quantizers
  218. super().__init__(**kwargs)
  219. @property
  220. def encodec_frame_rate(self) -> int:
  221. hop_length = np.prod(self.upsampling_ratios)
  222. return math.ceil(self.sampling_rate / hop_length)
  223. @property
  224. def num_codebooks(self) -> int:
  225. # alias to num_quantizers
  226. return self.num_quantizers
  227. @property
  228. def frame_size(self) -> int:
  229. # 1. we need each encoder conv stride
  230. # first conv
  231. strides = [1]
  232. # layer convs
  233. for ratio in reversed(self.upsampling_ratios):
  234. for j in range(self.num_residual_layers):
  235. len_kernel_sizes = len(self.residual_kernel_size) if isinstance(self.residual_kernel_size, list) else 1
  236. strides.extend([1] * (len_kernel_sizes + 1))
  237. if self.use_conv_shortcut: # skip connection
  238. strides.append(1)
  239. strides.append(ratio)
  240. # last conv
  241. strides.append(1)
  242. # downsampling layer
  243. strides.append(2)
  244. return math.prod(strides)
  245. @property
  246. def frame_rate(self) -> float:
  247. # handle backward compatibility
  248. if self._frame_rate is not None:
  249. return self._frame_rate
  250. return self.sampling_rate / self.frame_size
# Public names exported by a star-import of this module: only the configuration class.
__all__ = ["MimiConfig"]