configuration_encodec.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. # coding=utf-8
  2. # Copyright 2023 Meta Platforms, Inc. and affiliates, and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """EnCodec model configuration"""
  16. import math
  17. from typing import Optional
  18. import numpy as np
  19. from ...configuration_utils import PretrainedConfig
  20. from ...utils import logging
  # Module-level logger, named after this module, obtained from the package's own `logging` utilities.
  21. logger = logging.get_logger(__name__)
  22. class EncodecConfig(PretrainedConfig):
  23. r"""
  24. This is the configuration class to store the configuration of an [`EncodecModel`]. It is used to instantiate a
  25. Encodec model according to the specified arguments, defining the model architecture. Instantiating a configuration
  26. with the defaults will yield a similar configuration to that of the
  27. [facebook/encodec_24khz](https://huggingface.co/facebook/encodec_24khz) architecture.
  28. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  29. documentation from [`PretrainedConfig`] for more information.
  30. Args:
  31. target_bandwidths (`list[float]`, *optional*, defaults to `[1.5, 3.0, 6.0, 12.0, 24.0]`):
  32. The range of different bandwidths the model can encode audio with.
  33. sampling_rate (`int`, *optional*, defaults to 24000):
  34. The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz).
  35. audio_channels (`int`, *optional*, defaults to 1):
  36. Number of channels in the audio data. Either 1 for mono or 2 for stereo.
  37. normalize (`bool`, *optional*, defaults to `False`):
  38. Whether the audio shall be normalized when passed.
  39. chunk_length_s (`float`, *optional*):
  40. If defined the audio is pre-processed into chunks of lengths `chunk_length_s` and then encoded.
  41. overlap (`float`, *optional*):
  42. Defines the overlap between each chunk. It is used to compute the `chunk_stride` using the following
  43. formulae : `int((1.0 - self.overlap) * self.chunk_length)`.
  44. hidden_size (`int`, *optional*, defaults to 128):
  45. Intermediate representation dimension.
  46. num_filters (`int`, *optional*, defaults to 32):
  47. Number of convolution kernels of first `EncodecConv1d` down sampling layer.
  48. num_residual_layers (`int`, *optional*, defaults to 1):
  49. Number of residual layers.
  50. upsampling_ratios (`Sequence[int]` , *optional*, defaults to `[8, 5, 4, 2]`):
  51. Kernel size and stride ratios. The encoder uses downsampling ratios instead of upsampling ratios, hence it
  52. will use the ratios in the reverse order to the ones specified here that must match the decoder order.
  53. norm_type (`str`, *optional*, defaults to `"weight_norm"`):
  54. Normalization method. Should be in `["weight_norm", "time_group_norm"]`
  55. kernel_size (`int`, *optional*, defaults to 7):
  56. Kernel size for the initial convolution.
  57. last_kernel_size (`int`, *optional*, defaults to 7):
  58. Kernel size for the last convolution layer.
  59. residual_kernel_size (`int`, *optional*, defaults to 3):
  60. Kernel size for the residual layers.
  61. dilation_growth_rate (`int`, *optional*, defaults to 2):
  62. How much to increase the dilation with each layer.
  63. use_causal_conv (`bool`, *optional*, defaults to `True`):
  64. Whether to use fully causal convolution.
  65. pad_mode (`str`, *optional*, defaults to `"reflect"`):
  66. Padding mode for the convolutions.
  67. compress (`int`, *optional*, defaults to 2):
  68. Reduced dimensionality in residual branches (from Demucs v3).
  69. num_lstm_layers (`int`, *optional*, defaults to 2):
  70. Number of LSTM layers at the end of the encoder.
  71. trim_right_ratio (`float`, *optional*, defaults to 1.0):
  72. Ratio for trimming at the right of the transposed convolution under the `use_causal_conv = True` setup. If
  73. equal to 1.0, it means that all the trimming is done at the right.
  74. codebook_size (`int`, *optional*, defaults to 1024):
  75. Number of discret codes that make up VQVAE.
  76. codebook_dim (`int`, *optional*):
  77. Dimension of the codebook vectors. If not defined, uses `hidden_size`.
  78. use_conv_shortcut (`bool`, *optional*, defaults to `True`):
  79. Whether to use a convolutional layer as the 'skip' connection in the `EncodecResnetBlock` block. If False,
  80. an identity function will be used, giving a generic residual connection.
  81. Example:
  82. ```python
  83. >>> from transformers import EncodecModel, EncodecConfig
  84. >>> # Initializing a "facebook/encodec_24khz" style configuration
  85. >>> configuration = EncodecConfig()
  86. >>> # Initializing a model (with random weights) from the "facebook/encodec_24khz" style configuration
  87. >>> model = EncodecModel(configuration)
  88. >>> # Accessing the model configuration
  89. >>> configuration = model.config
  90. ```"""
  91. model_type = "encodec"
  92. def __init__(
  93. self,
  94. target_bandwidths=[1.5, 3.0, 6.0, 12.0, 24.0],
  95. sampling_rate=24_000,
  96. audio_channels=1,
  97. normalize=False,
  98. chunk_length_s=None,
  99. overlap=None,
  100. hidden_size=128,
  101. num_filters=32,
  102. num_residual_layers=1,
  103. upsampling_ratios=[8, 5, 4, 2],
  104. norm_type="weight_norm",
  105. kernel_size=7,
  106. last_kernel_size=7,
  107. residual_kernel_size=3,
  108. dilation_growth_rate=2,
  109. use_causal_conv=True,
  110. pad_mode="reflect",
  111. compress=2,
  112. num_lstm_layers=2,
  113. trim_right_ratio=1.0,
  114. codebook_size=1024,
  115. codebook_dim=None,
  116. use_conv_shortcut=True,
  117. **kwargs,
  118. ):
  119. self.target_bandwidths = target_bandwidths
  120. self.sampling_rate = sampling_rate
  121. self.audio_channels = audio_channels
  122. self.normalize = normalize
  123. self.chunk_length_s = chunk_length_s
  124. self.overlap = overlap
  125. self.hidden_size = hidden_size
  126. self.num_filters = num_filters
  127. self.num_residual_layers = num_residual_layers
  128. self.upsampling_ratios = upsampling_ratios
  129. self.norm_type = norm_type
  130. self.kernel_size = kernel_size
  131. self.last_kernel_size = last_kernel_size
  132. self.residual_kernel_size = residual_kernel_size
  133. self.dilation_growth_rate = dilation_growth_rate
  134. self.use_causal_conv = use_causal_conv
  135. self.pad_mode = pad_mode
  136. self.compress = compress
  137. self.num_lstm_layers = num_lstm_layers
  138. self.trim_right_ratio = trim_right_ratio
  139. self.codebook_size = codebook_size
  140. self.codebook_dim = codebook_dim if codebook_dim is not None else hidden_size
  141. self.use_conv_shortcut = use_conv_shortcut
  142. if self.norm_type not in ["weight_norm", "time_group_norm"]:
  143. raise ValueError(
  144. f'self.norm_type must be one of `"weight_norm"`, `"time_group_norm"`), got {self.norm_type}'
  145. )
  146. super().__init__(**kwargs)
  147. # This is a property because you might want to change the chunk_length_s on the fly
  148. @property
  149. def chunk_length(self) -> Optional[int]:
  150. if self.chunk_length_s is None:
  151. return None
  152. else:
  153. return int(self.chunk_length_s * self.sampling_rate)
  154. # This is a property because you might want to change the chunk_length_s on the fly
  155. @property
  156. def chunk_stride(self) -> Optional[int]:
  157. if self.chunk_length_s is None or self.overlap is None:
  158. return None
  159. else:
  160. return max(1, int((1.0 - self.overlap) * self.chunk_length))
  161. @property
  162. def hop_length(self) -> int:
  163. return int(np.prod(self.upsampling_ratios))
  164. @property
  165. def codebook_nbits(self) -> int:
  166. return math.ceil(math.log2(self.codebook_size))
  167. @property
  168. def frame_rate(self) -> int:
  169. return math.ceil(self.sampling_rate / self.hop_length)
  170. @property
  171. def num_quantizers(self) -> int:
  172. return int(1000 * self.target_bandwidths[-1] // (self.frame_rate * self.codebook_nbits))
  # Public API of this module: only the configuration class is exported.
  173. __all__ = ["EncodecConfig"]