modeling_dac.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685
  1. # coding=utf-8
  2. # Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Transformers DAC model."""
  16. import math
  17. from dataclasses import dataclass
  18. from typing import Optional
  19. import numpy as np
  20. import torch
  21. import torch.nn as nn
  22. import torch.nn.functional as F
  23. from ...modeling_utils import PreTrainedAudioTokenizerBase
  24. from ...utils import ModelOutput, auto_docstring
  25. from .configuration_dac import DacConfig
# Output container returned by `DacModel.forward`. It is built positionally there, so the field order below
# is part of the interface.
@dataclass
@auto_docstring
class DacOutput(ModelOutput):
    r"""
    loss (`torch.Tensor`, *optional*):
        Loss from the encoder model, comprising the weighted combination of the commitment and codebook losses.
    audio_values (`torch.Tensor` of shape `(batch_size, input_length)`, *optional*):
        Reconstructed audio data, truncated to at most the length of the input audio.
    quantized_representation (`torch.Tensor` of shape `(batch_size, hidden_size, time_steps)`, *optional*):
        Quantized continuous representation of input.
    audio_codes (`torch.LongTensor` of shape `(batch_size, num_codebooks, time_steps)`, *optional*):
        Codebook indices for each codebook (quantized discrete representation of input).
    projected_latents (`torch.Tensor` of shape `(batch_size, num_codebooks * codebook_dim, time_steps)`, *optional*):
        Projected latents (continuous representation of input before quantization).
    """

    loss: Optional[torch.FloatTensor] = None
    audio_values: Optional[torch.FloatTensor] = None
    quantized_representation: Optional[torch.FloatTensor] = None
    audio_codes: Optional[torch.LongTensor] = None
    projected_latents: Optional[torch.FloatTensor] = None
  46. @dataclass
  47. @auto_docstring
  48. class DacEncoderOutput(ModelOutput):
  49. r"""
  50. loss (`torch.Tensor`):
  51. Loss from the encoder model, comprising the weighted combination of the commitment and codebook losses.
  52. quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`, *optional*):
  53. Quantized continuous representation of input.
  54. audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`, *optional*):
  55. Codebook indices for each codebook (quantized discrete representation of input).
  56. projected_latents (`torch.Tensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`, *optional*):
  57. Projected latents (continuous representation of input before quantization).
  58. """
  59. loss: Optional[torch.FloatTensor] = None
  60. quantized_representation: Optional[torch.FloatTensor] = None
  61. audio_codes: Optional[torch.FloatTensor] = None
  62. projected_latents: Optional[torch.FloatTensor] = None
# Output container returned by `DacModel.decode`; it holds only the reconstructed waveform.
# The class body below is kept in sync with Encodec via the "Copied from" marker, so edit it there first.
@dataclass
@auto_docstring
# Copied from transformers.models.encodec.modeling_encodec.EncodecDecoderOutput with Encodec->Dac, segment_length->input_length
class DacDecoderOutput(ModelOutput):
    r"""
    audio_values (`torch.FloatTensor` of shape `(batch_size, input_length)`, *optional*):
        Decoded audio values, obtained using the decoder part of Dac.
    """

    audio_values: Optional[torch.FloatTensor] = None
  72. class Snake1d(nn.Module):
  73. """
  74. A 1-dimensional Snake activation function module.
  75. """
  76. def __init__(self, hidden_dim):
  77. super().__init__()
  78. self.alpha = nn.Parameter(torch.ones(1, hidden_dim, 1))
  79. def forward(self, hidden_states):
  80. shape = hidden_states.shape
  81. hidden_states = hidden_states.reshape(shape[0], shape[1], -1)
  82. hidden_states = hidden_states + (self.alpha + 1e-9).reciprocal() * torch.sin(self.alpha * hidden_states).pow(2)
  83. hidden_states = hidden_states.reshape(shape)
  84. return hidden_states
  85. class DacVectorQuantize(nn.Module):
  86. """
  87. Implementation of VQ similar to Karpathy's repo (https://github.com/karpathy/deep-vector-quantization)
  88. Additionally uses following tricks from improved VQGAN
  89. (https://huggingface.co/papers/2110.04627):
  90. 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
  91. for improved codebook usage
  92. 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
  93. improves training stability
  94. """
  95. def __init__(self, config: DacConfig):
  96. super().__init__()
  97. self.codebook_dim = config.codebook_dim
  98. self.in_proj = nn.Conv1d(config.hidden_size, config.codebook_dim, kernel_size=1)
  99. self.out_proj = nn.Conv1d(config.codebook_dim, config.hidden_size, kernel_size=1)
  100. self.codebook = nn.Embedding(config.codebook_size, config.codebook_dim)
  101. def forward(self, hidden_state):
  102. """
  103. Quantizes the input tensor using a fixed codebook and returns the corresponding codebook vectors.
  104. Args:
  105. hidden_state (`torch.FloatTensor` of shape `(batch_size, dimension, time_steps)`):
  106. Input tensor.
  107. Returns:
  108. quantized_representation (`torch.Tensor`of shape `(batch_size, dimension, time_steps)`):
  109. Quantized continuous representation of input.
  110. commitment_loss (`torch.FloatTensor`of shape `(1)`):
  111. Commitment loss to train encoder to predict vectors closer to codebook entries.
  112. codebook_loss (`torch.FloatTensor`of shape `(1)`):
  113. Codebook loss to update the codebook.
  114. audio_codes (`torch.LongTensor` of shape `(batch_size, time_steps)`):
  115. Codebook indices for each codebook, quantized discrete representation of input.
  116. projected_latents (torch.FloatTensor of shape `(batch_size, num_codebooks * dimension, time_steps)`):
  117. Projected latents (continuous representation of input before quantization).
  118. """
  119. projected_latents = self.in_proj(hidden_state)
  120. quantized_representation, audio_codes = self.decode_latents(projected_latents)
  121. commitment_loss = F.mse_loss(projected_latents, quantized_representation.detach(), reduction="mean")
  122. codebook_loss = F.mse_loss(quantized_representation, projected_latents.detach(), reduction="mean")
  123. # noop in forward pass, straight-through gradient estimator in backward pass
  124. quantized_representation = projected_latents + (quantized_representation - projected_latents).detach()
  125. quantized_representation = self.out_proj(quantized_representation)
  126. return quantized_representation, commitment_loss, codebook_loss, audio_codes, projected_latents
  127. def decode_latents(self, hidden_states):
  128. batch_size, hidden_dim, sequence_length = hidden_states.shape
  129. encodings = hidden_states.permute(0, 2, 1).reshape(batch_size * sequence_length, hidden_dim)
  130. codebook = self.codebook.weight # codebook: (N x D)
  131. # L2 normalize encodings and codebook (ViT-VQGAN)
  132. encodings = F.normalize(encodings)
  133. codebook = F.normalize(codebook)
  134. # Compute euclidean distance with codebook
  135. l2_norm = encodings.pow(2).sum(1, keepdim=True)
  136. dist = -(l2_norm - 2 * encodings @ codebook.t()) + codebook.pow(2).sum(1, keepdim=True).t()
  137. indices = dist.max(1)[1]
  138. indices = indices.reshape(hidden_states.size(0), -1)
  139. quantized_representation = self.codebook(indices).transpose(1, 2)
  140. return quantized_representation, indices
  141. class DacResidualUnit(nn.Module):
  142. """
  143. A residual unit composed of Snake1d and weight-normalized Conv1d layers with dilations.
  144. """
  145. def __init__(self, dimension: int = 16, dilation: int = 1):
  146. super().__init__()
  147. pad = ((7 - 1) * dilation) // 2
  148. self.snake1 = Snake1d(dimension)
  149. self.conv1 = nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad)
  150. self.snake2 = Snake1d(dimension)
  151. self.conv2 = nn.Conv1d(dimension, dimension, kernel_size=1)
  152. def forward(self, hidden_state):
  153. """
  154. Forward pass through the residual unit.
  155. Args:
  156. hidden_state (`torch.Tensor` of shape `(batch_size, channels, time_steps)`):
  157. Input tensor .
  158. Returns:
  159. output_tensor (`torch.Tensor` of shape `(batch_size, channels, time_steps)`):
  160. Input tensor after passing through the residual unit.
  161. """
  162. output_tensor = hidden_state
  163. output_tensor = self.conv1(self.snake1(output_tensor))
  164. output_tensor = self.conv2(self.snake2(output_tensor))
  165. padding = (hidden_state.shape[-1] - output_tensor.shape[-1]) // 2
  166. if padding > 0:
  167. hidden_state = hidden_state[..., padding:-padding]
  168. output_tensor = hidden_state + output_tensor
  169. return output_tensor
  170. class DacEncoderBlock(nn.Module):
  171. """Encoder block used in DAC encoder."""
  172. def __init__(self, config: DacConfig, stride: int = 1, stride_index: int = 1):
  173. super().__init__()
  174. dimension = config.encoder_hidden_size * 2**stride_index
  175. self.res_unit1 = DacResidualUnit(dimension // 2, dilation=1)
  176. self.res_unit2 = DacResidualUnit(dimension // 2, dilation=3)
  177. self.res_unit3 = DacResidualUnit(dimension // 2, dilation=9)
  178. self.snake1 = Snake1d(dimension // 2)
  179. self.conv1 = nn.Conv1d(
  180. dimension // 2, dimension, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)
  181. )
  182. def forward(self, hidden_state):
  183. hidden_state = self.res_unit1(hidden_state)
  184. hidden_state = self.res_unit2(hidden_state)
  185. hidden_state = self.snake1(self.res_unit3(hidden_state))
  186. hidden_state = self.conv1(hidden_state)
  187. return hidden_state
  188. class DacDecoderBlock(nn.Module):
  189. """Decoder block used in DAC decoder."""
  190. def __init__(self, config: DacConfig, stride: int = 1, stride_index: int = 1):
  191. super().__init__()
  192. input_dim = config.decoder_hidden_size // 2**stride_index
  193. output_dim = config.decoder_hidden_size // 2 ** (stride_index + 1)
  194. self.snake1 = Snake1d(input_dim)
  195. self.conv_t1 = nn.ConvTranspose1d(
  196. input_dim,
  197. output_dim,
  198. kernel_size=2 * stride,
  199. stride=stride,
  200. padding=math.ceil(stride / 2),
  201. )
  202. self.res_unit1 = DacResidualUnit(output_dim, dilation=1)
  203. self.res_unit2 = DacResidualUnit(output_dim, dilation=3)
  204. self.res_unit3 = DacResidualUnit(output_dim, dilation=9)
  205. def forward(self, hidden_state):
  206. hidden_state = self.snake1(hidden_state)
  207. hidden_state = self.conv_t1(hidden_state)
  208. hidden_state = self.res_unit1(hidden_state)
  209. hidden_state = self.res_unit2(hidden_state)
  210. hidden_state = self.res_unit3(hidden_state)
  211. return hidden_state
  212. class DacResidualVectorQuantize(nn.Module):
  213. """
  214. ResidualVectorQuantize block - Introduced in SoundStream: An end2end neural audio codec (https://huggingface.co/papers/2107.03312)
  215. """
  216. def __init__(self, config: DacConfig):
  217. super().__init__()
  218. n_codebooks = config.n_codebooks
  219. quantizer_dropout = config.quantizer_dropout
  220. self.n_codebooks = n_codebooks
  221. self.quantizers = nn.ModuleList([DacVectorQuantize(config) for i in range(config.n_codebooks)])
  222. self.quantizer_dropout = quantizer_dropout
  223. def forward(self, hidden_state, n_quantizers: Optional[int] = None):
  224. """
  225. Quantizes the input tensor using a fixed set of codebooks and returns corresponding codebook vectors.
  226. Args:
  227. hidden_state (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
  228. Input tensor to be quantized.
  229. n_quantizers (`int`, *optional*):
  230. Number of quantizers to use. If specified and `self.quantizer_dropout` is True,
  231. this argument is ignored during training, and a random number of quantizers is used.
  232. Returns:
  233. quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
  234. Quantized continuous representation of input.
  235. audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`):
  236. Codebook indices for each codebook (quantized discrete representation of input).
  237. projected_latents (`torch.Tensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`):
  238. Projected latents (continuous representation of input before quantization).
  239. commitment_loss (`torch.Tensor` of shape `(1)`):
  240. Commitment loss to train the encoder to predict vectors closer to codebook entries.
  241. codebook_loss (`torch.Tensor` of shape `(1)`):
  242. Codebook loss to update the codebook.
  243. """
  244. quantized_representation = 0
  245. residual = hidden_state
  246. commitment_loss = 0
  247. codebook_loss = 0
  248. audio_codes = []
  249. projected_latents = []
  250. n_quantizers = n_quantizers if n_quantizers is not None else self.n_codebooks
  251. if self.training:
  252. n_quantizers = torch.ones((hidden_state.shape[0],)) * self.n_codebooks + 1
  253. dropout = torch.randint(1, self.n_codebooks + 1, (hidden_state.shape[0],))
  254. n_dropout = int(hidden_state.shape[0] * self.quantizer_dropout)
  255. n_quantizers[:n_dropout] = dropout[:n_dropout]
  256. n_quantizers = n_quantizers.to(hidden_state.device)
  257. for i, quantizer in enumerate(self.quantizers):
  258. if self.training is False and i >= n_quantizers:
  259. break
  260. quantized_representation_i, commitment_loss_i, codebook_loss_i, indices_i, projected_latents_i = quantizer(
  261. residual
  262. )
  263. # Create mask to apply quantizer dropout
  264. mask = torch.full((hidden_state.shape[0],), fill_value=i, device=hidden_state.device) < n_quantizers
  265. quantized_representation = quantized_representation + quantized_representation_i * mask[:, None, None]
  266. residual = residual - quantized_representation_i
  267. # Sum losses
  268. commitment_loss += commitment_loss_i * mask
  269. codebook_loss += codebook_loss_i * mask
  270. audio_codes.append(indices_i)
  271. projected_latents.append(projected_latents_i)
  272. audio_codes = torch.stack(audio_codes, dim=1)
  273. projected_latents = torch.cat(projected_latents, dim=1)
  274. return quantized_representation, audio_codes, projected_latents, commitment_loss, codebook_loss
  275. def from_codes(self, audio_codes: torch.Tensor):
  276. """
  277. Reconstructs the continuous representation from quantized codes.
  278. Args:
  279. audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`):
  280. Quantized discrete representation of input.
  281. Returns:
  282. quantized_representation (`torch.Tensor`):
  283. Quantized continuous representation of input.
  284. projected_latents (`torch.Tensor`):
  285. List of projected latents (continuous representations of input before quantization)
  286. for each codebook.
  287. audio_codes (`torch.Tensor`):
  288. Codebook indices for each codebook.
  289. """
  290. quantized_representation = 0.0
  291. projected_latents = []
  292. n_codebooks = audio_codes.shape[1]
  293. for i in range(n_codebooks):
  294. projected_latents_i = self.quantizers[i].codebook(audio_codes[:, i, :]).transpose(1, 2)
  295. projected_latents.append(projected_latents_i)
  296. quantized_representation += self.quantizers[i].out_proj(projected_latents_i)
  297. return quantized_representation, torch.cat(projected_latents, dim=1), audio_codes
  298. def from_latents(self, latents: torch.Tensor):
  299. """Reconstructs the quantized representation from unquantized latents.
  300. Args:
  301. latents (`torch.Tensor` of shape `(batch_size, total_latent_dimension, time_steps)`):
  302. Continuous representation of input after projection.
  303. Returns:
  304. quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
  305. Quantized representation of the full-projected space.
  306. quantized_latents (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
  307. Quantized representation of the latent space (continuous representation before quantization).
  308. """
  309. quantized_representation = 0
  310. quantized_latents = []
  311. codes = []
  312. codebook_dims_tensor = torch.tensor([0] + [q.codebook_dim for q in self.quantizers])
  313. dims = torch.cumsum(codebook_dims_tensor, dim=0)
  314. n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
  315. for i in range(n_codebooks):
  316. hidden_dim_j, hidden_dim_k = dims[i], dims[i + 1]
  317. quantized_latents_i, codes_i = self.quantizers[i].decode_latents(latents[:, hidden_dim_j:hidden_dim_k, :])
  318. quantized_latents.append(quantized_latents_i)
  319. codes.append(codes_i)
  320. quantized_representation_i = self.quantizers[i].out_proj(quantized_latents_i)
  321. quantized_representation = quantized_representation + quantized_representation_i
  322. return quantized_representation, torch.cat(quantized_latents, dim=1)
  323. class DacDecoder(nn.Module):
  324. """DAC Decoder"""
  325. def __init__(self, config: DacConfig):
  326. super().__init__()
  327. input_channel = config.hidden_size
  328. channels = config.decoder_hidden_size
  329. strides = config.upsampling_ratios
  330. # Add first conv layer
  331. self.conv1 = nn.Conv1d(input_channel, channels, kernel_size=7, padding=3)
  332. # Add upsampling + MRF blocks
  333. block = []
  334. for stride_index, stride in enumerate(strides):
  335. block += [DacDecoderBlock(config, stride, stride_index)]
  336. self.block = nn.ModuleList(block)
  337. output_dim = config.decoder_hidden_size // 2 ** (stride_index + 1)
  338. self.snake1 = Snake1d(output_dim)
  339. self.conv2 = nn.Conv1d(output_dim, 1, kernel_size=7, padding=3)
  340. self.tanh = nn.Tanh()
  341. def forward(self, hidden_state):
  342. hidden_state = self.conv1(hidden_state)
  343. for layer in self.block:
  344. hidden_state = layer(hidden_state)
  345. hidden_state = self.snake1(hidden_state)
  346. hidden_state = self.conv2(hidden_state)
  347. hidden_state = self.tanh(hidden_state)
  348. return hidden_state
  349. class DacEncoder(nn.Module):
  350. """DAC Encoder"""
  351. def __init__(self, config: DacConfig):
  352. super().__init__()
  353. strides = config.downsampling_ratios
  354. # Create first convolution
  355. self.conv1 = nn.Conv1d(1, config.encoder_hidden_size, kernel_size=7, padding=3)
  356. self.block = []
  357. # Create EncoderBlocks that double channels as they downsample by `stride`
  358. for stride_index, stride in enumerate(strides):
  359. stride_index = stride_index + 1
  360. self.block += [DacEncoderBlock(config, stride=stride, stride_index=stride_index)]
  361. self.block = nn.ModuleList(self.block)
  362. d_model = config.encoder_hidden_size * 2**stride_index
  363. self.snake1 = Snake1d(d_model)
  364. self.conv2 = nn.Conv1d(d_model, config.hidden_size, kernel_size=3, padding=1)
  365. def forward(self, hidden_state):
  366. hidden_state = self.conv1(hidden_state)
  367. for module in self.block:
  368. hidden_state = module(hidden_state)
  369. hidden_state = self.snake1(hidden_state)
  370. hidden_state = self.conv2(hidden_state)
  371. return hidden_state
  372. @auto_docstring
  373. class DacPreTrainedModel(PreTrainedAudioTokenizerBase):
  374. config: DacConfig
  375. base_model_prefix = "dac"
  376. main_input_name = "input_values"
  377. def _init_weights(self, module):
  378. if isinstance(module, nn.Conv1d):
  379. nn.init.trunc_normal_(module.weight, std=0.02)
  380. nn.init.constant_(module.bias, 0)
  381. elif isinstance(module, Snake1d):
  382. module.alpha.data.fill_(1.0)
  383. elif isinstance(module, nn.ConvTranspose1d):
  384. module.reset_parameters()
  385. elif isinstance(module, nn.Embedding):
  386. module.weight.data.normal_(mean=0.0, std=0.02)
  387. def apply_weight_norm(self):
  388. weight_norm = nn.utils.weight_norm
  389. if hasattr(nn.utils.parametrizations, "weight_norm"):
  390. weight_norm = nn.utils.parametrizations.weight_norm
  391. for layer in self.quantizer.quantizers:
  392. weight_norm(layer.in_proj)
  393. weight_norm(layer.out_proj)
  394. weight_norm(self.encoder.conv1)
  395. weight_norm(self.encoder.conv2)
  396. for layer in self.encoder.block:
  397. weight_norm(layer.conv1)
  398. weight_norm(layer.res_unit1.conv1)
  399. weight_norm(layer.res_unit1.conv2)
  400. weight_norm(layer.res_unit2.conv1)
  401. weight_norm(layer.res_unit2.conv2)
  402. weight_norm(layer.res_unit3.conv1)
  403. weight_norm(layer.res_unit3.conv2)
  404. weight_norm(self.decoder.conv1)
  405. weight_norm(self.decoder.conv2)
  406. for layer in self.decoder.block:
  407. weight_norm(layer.conv_t1)
  408. weight_norm(layer.res_unit1.conv1)
  409. weight_norm(layer.res_unit1.conv2)
  410. weight_norm(layer.res_unit2.conv1)
  411. weight_norm(layer.res_unit2.conv2)
  412. weight_norm(layer.res_unit3.conv1)
  413. weight_norm(layer.res_unit3.conv2)
  414. def remove_weight_norm(self):
  415. for layer in self.quantizer.quantizers:
  416. nn.utils.remove_weight_norm(layer.in_proj)
  417. nn.utils.remove_weight_norm(layer.out_proj)
  418. nn.utils.remove_weight_norm(self.encoder.conv1)
  419. nn.utils.remove_weight_norm(self.encoder.conv2)
  420. for layer in self.encoder.block:
  421. nn.utils.remove_weight_norm(layer.conv1)
  422. nn.utils.remove_weight_norm(layer.res_unit1.conv1)
  423. nn.utils.remove_weight_norm(layer.res_unit1.conv2)
  424. nn.utils.remove_weight_norm(layer.res_unit2.conv1)
  425. nn.utils.remove_weight_norm(layer.res_unit2.conv2)
  426. nn.utils.remove_weight_norm(layer.res_unit3.conv1)
  427. nn.utils.remove_weight_norm(layer.res_unit3.conv2)
  428. nn.utils.remove_weight_norm(self.decoder.conv1)
  429. nn.utils.remove_weight_norm(self.decoder.conv2)
  430. for layer in self.decoder.block:
  431. nn.utils.remove_weight_norm(layer.conv_t1)
  432. nn.utils.remove_weight_norm(layer.res_unit1.conv1)
  433. nn.utils.remove_weight_norm(layer.res_unit1.conv2)
  434. nn.utils.remove_weight_norm(layer.res_unit2.conv1)
  435. nn.utils.remove_weight_norm(layer.res_unit2.conv2)
  436. nn.utils.remove_weight_norm(layer.res_unit3.conv1)
  437. nn.utils.remove_weight_norm(layer.res_unit3.conv2)
  438. @auto_docstring(
  439. custom_intro="""
  440. The DAC (Descript Audio Codec) model.
  441. """
  442. )
  443. class DacModel(DacPreTrainedModel):
  444. def __init__(self, config: DacConfig):
  445. super().__init__(config)
  446. self.config = config
  447. self.encoder = DacEncoder(config)
  448. self.decoder = DacDecoder(config)
  449. self.quantizer = DacResidualVectorQuantize(config)
  450. self.bits_per_codebook = int(math.log2(self.config.codebook_size))
  451. if 2**self.bits_per_codebook != self.config.codebook_size:
  452. raise ValueError("The codebook_size must be a power of 2.")
  453. # Initialize weights and apply final processing
  454. self.post_init()
  455. @auto_docstring
  456. def encode(
  457. self,
  458. input_values: torch.Tensor,
  459. n_quantizers: Optional[int] = None,
  460. return_dict: Optional[bool] = None,
  461. ):
  462. r"""
  463. input_values (`torch.Tensor of shape `(batch_size, 1, time_steps)`):
  464. Input audio data to encode,
  465. n_quantizers (int, *optional*):
  466. Number of quantizers to use. If None, all quantizers are used. Default is None.
  467. """
  468. return_dict = return_dict if return_dict is not None else self.config.return_dict
  469. quantized_representation = self.encoder(input_values)
  470. quantized_representation, audio_codes, projected_latents, commitment_loss, codebook_loss = self.quantizer(
  471. quantized_representation, n_quantizers
  472. )
  473. loss = self.config.commitment_loss_weight * commitment_loss + self.config.codebook_loss_weight * codebook_loss
  474. if not return_dict:
  475. return (loss, quantized_representation, audio_codes, projected_latents)
  476. return DacEncoderOutput(loss, quantized_representation, audio_codes, projected_latents)
  477. @auto_docstring
  478. def decode(
  479. self,
  480. quantized_representation: Optional[torch.Tensor] = None,
  481. audio_codes: Optional[torch.Tensor] = None,
  482. return_dict: Optional[bool] = None,
  483. ):
  484. r"""
  485. quantized_representation (torch.Tensor of shape `(batch_size, dimension, time_steps)`, *optional*):
  486. Quantized continuous representation of input.
  487. audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`, *optional*):
  488. The codebook indices for each codebook, representing the quantized discrete
  489. representation of the input. This parameter should be provided if you want
  490. to decode directly from the audio codes (it will overwrite quantized_representation).
  491. return_dict (`bool`, *optional*, defaults to `True`):
  492. Whether to return a [`DacDecoderOutput`] instead of a plain tuple.
  493. """
  494. if quantized_representation is None and audio_codes is None:
  495. raise ValueError("Either `quantized_representation` or `audio_codes` must be provided.")
  496. return_dict = return_dict if return_dict is not None else self.config.return_dict
  497. if audio_codes is not None:
  498. quantized_representation = self.quantizer.from_codes(audio_codes)[0]
  499. audio_values = self.decoder(quantized_representation).squeeze(1)
  500. if not return_dict:
  501. return (audio_values,)
  502. return DacDecoderOutput(audio_values)
  503. @auto_docstring
  504. def forward(
  505. self,
  506. input_values: torch.Tensor,
  507. n_quantizers: Optional[int] = None,
  508. return_dict: Optional[bool] = None,
  509. ):
  510. r"""
  511. input_values (`torch.Tensor` of shape `(batch_size, 1, time_steps)`):
  512. Audio data to encode.
  513. n_quantizers (`int`, *optional*):
  514. Number of quantizers to use. If `None`, all quantizers are used. Default is `None`.
  515. Examples:
  516. ```python
  517. >>> from datasets import load_dataset, Audio
  518. >>> from transformers import DacModel, AutoProcessor
  519. >>> librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  520. >>> model = DacModel.from_pretrained("descript/dac_16khz")
  521. >>> processor = AutoProcessor.from_pretrained("descript/dac_16khz")
  522. >>> librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
  523. >>> audio_sample = librispeech_dummy[-1]["audio"]["array"]
  524. >>> inputs = processor(raw_audio=audio_sample, sampling_rate=processor.sampling_rate, return_tensors="pt")
  525. >>> encoder_outputs = model.encode(inputs["input_values"])
  526. >>> # Get the intermediate audio codes
  527. >>> audio_codes = encoder_outputs.audio_codes
  528. >>> # Reconstruct the audio from its quantized representation
  529. >>> audio_values = model.decode(encoder_outputs.quantized_representation)
  530. >>> # or the equivalent with a forward pass
  531. >>> audio_values = model(inputs["input_values"]).audio_values
  532. ```"""
  533. return_dict = return_dict if return_dict is not None else self.config.return_dict
  534. length = input_values.shape[-1]
  535. loss, quantized_representation, audio_codes, projected_latents = self.encode(
  536. input_values, n_quantizers, return_dict=False
  537. )
  538. audio_values = self.decode(quantized_representation, return_dict=False)[0][..., :length]
  539. if not return_dict:
  540. return (loss, audio_values, quantized_representation, audio_codes, projected_latents)
  541. return DacOutput(loss, audio_values, quantized_representation, audio_codes, projected_latents)
  542. __all__ = ["DacModel", "DacPreTrainedModel"]