modeling_xcodec.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580
  1. # coding=utf-8
  2. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Transformers Xcodec model."""
  16. import math
  17. from dataclasses import dataclass
  18. from typing import Optional, Union
  19. import torch
  20. import torch.nn as nn
  21. import torch.nn.functional as F
  22. from ...modeling_utils import PreTrainedAudioTokenizerBase
  23. from ...utils import ModelOutput, auto_docstring
  24. from ..auto import AutoModel
  25. from .configuration_xcodec import XcodecConfig
@dataclass
class XcodecOutput(ModelOutput):
    """
    Container returned by `XcodecModel.forward`, holding both the discrete codes and the reconstructed audio.

    Args:
        audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`, *optional*):
            Discrete code indices computed using `model.encode`.
        audio_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`, *optional*):
            Decoded audio values obtained using the decoder part of Xcodec.
    """

    audio_codes: Optional[torch.LongTensor] = None
    audio_values: Optional[torch.FloatTensor] = None
@dataclass
class XcodecEncoderOutput(ModelOutput):
    """
    Container returned by `XcodecModel.encode` when `return_dict` is truthy.

    Args:
        audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`, *optional*):
            Discrete code indices computed using `model.encode`.
    """

    audio_codes: Optional[torch.LongTensor] = None
@dataclass
class XcodecDecoderOutput(ModelOutput):
    """
    Container returned by `XcodecModel.decode` when `return_dict` is truthy.

    Args:
        audio_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`, *optional*):
            Decoded audio values obtained using the decoder part of Xcodec.
    """

    audio_values: Optional[torch.FloatTensor] = None
  53. class ResidualUnit(nn.Module):
  54. """Residual block for SemanticEncoder and SemanticDecoder used in Xcodec."""
  55. def __init__(self, config: XcodecConfig, in_channels: int, out_channels: int, dilation: int):
  56. super().__init__()
  57. self.activation = nn.ELU()
  58. padding = ((config.unit_kernel_size - 1) // 2) * dilation
  59. self.conv1 = nn.Conv1d(
  60. in_channels,
  61. out_channels,
  62. config.unit_kernel_size,
  63. stride=1,
  64. padding=padding,
  65. dilation=dilation,
  66. groups=1,
  67. bias=False,
  68. )
  69. self.conv2 = nn.Conv1d(in_channels=out_channels, out_channels=out_channels, kernel_size=1, bias=False)
  70. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  71. output_tensor = self.activation(hidden_state)
  72. output_tensor = self.conv1(output_tensor)
  73. output_tensor = self.activation(output_tensor)
  74. output_tensor = self.conv2(output_tensor)
  75. return hidden_state + output_tensor
  76. class SemanticEncoderBlock(nn.Module):
  77. def __init__(self, config: XcodecConfig, in_channels: int, out_channels: int, stride: int):
  78. super().__init__()
  79. self.res_units = nn.ModuleList(
  80. [ResidualUnit(config, in_channels, in_channels, dilation) for dilation in config.block_dilations]
  81. )
  82. # special case: stride=1, do not use kernel=2
  83. kernel = 3 if stride == 1 else (2 * stride)
  84. padding = (kernel - 1) // 2
  85. self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel, stride=stride, padding=padding, bias=True)
  86. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  87. for unit in self.res_units:
  88. hidden_state = unit(hidden_state)
  89. hidden_state = self.conv(hidden_state)
  90. return hidden_state
  91. class SemanticEncoder(nn.Module):
  92. def __init__(self, config):
  93. super().__init__()
  94. if len(config.strides) != len(config.channel_ratios):
  95. raise ValueError("Number of strides must match the number of channel_ratios.")
  96. self.conv = nn.Conv1d(
  97. config.semantic_hidden_size,
  98. config.semantic_hidden_size,
  99. config.kernel_size,
  100. 1,
  101. config.kernel_size // 2,
  102. bias=False,
  103. )
  104. in_channels = config.semantic_hidden_size
  105. conv_blocks = []
  106. for i, stride in enumerate(config.strides):
  107. out_channels = int(config.semantic_hidden_size * config.channel_ratios[i])
  108. conv_blocks += [SemanticEncoderBlock(config, in_channels, out_channels, stride)]
  109. in_channels = out_channels
  110. self.conv_blocks = nn.ModuleList(conv_blocks)
  111. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  112. hidden_state = self.conv(hidden_state)
  113. for block in self.conv_blocks:
  114. hidden_state = block(hidden_state)
  115. return hidden_state
  116. class SemanticDecoderBlock(nn.Module):
  117. def __init__(self, config: XcodecConfig, in_channels: int, out_channels: int, stride: int):
  118. super().__init__()
  119. if stride == 1:
  120. self.conv = nn.Conv1d(
  121. in_channels,
  122. out_channels,
  123. kernel_size=3,
  124. stride=1,
  125. padding=1,
  126. bias=True,
  127. )
  128. else:
  129. kernel_size = 2 * stride
  130. padding = (stride + 1) // 2
  131. output_padding = 1 if stride % 2 == 1 else 0
  132. self.conv = nn.ConvTranspose1d(
  133. in_channels, out_channels, kernel_size, stride, padding, output_padding, bias=False
  134. )
  135. self.res_units = nn.ModuleList(
  136. [ResidualUnit(config, out_channels, out_channels, dilation) for dilation in config.block_dilations]
  137. )
  138. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  139. hidden_state = self.conv(hidden_state)
  140. for unit in self.res_units:
  141. hidden_state = unit(hidden_state)
  142. return hidden_state
  143. class SemanticDecoder(nn.Module):
  144. def __init__(self, config):
  145. super().__init__()
  146. self.conv1 = nn.Conv1d(
  147. in_channels=config.semantic_hidden_size,
  148. out_channels=int(config.semantic_hidden_size * config.channel_ratios[0]),
  149. kernel_size=config.kernel_size,
  150. stride=1,
  151. padding=config.kernel_size // 2,
  152. bias=False,
  153. )
  154. conv_blocks = []
  155. for i, stride in enumerate(config.strides):
  156. in_channels = int(config.semantic_hidden_size * config.channel_ratios[i])
  157. if i < (len(config.channel_ratios) - 1):
  158. out_channels = int(config.semantic_hidden_size * config.channel_ratios[i + 1])
  159. else:
  160. out_channels = config.semantic_hidden_size
  161. conv_blocks += [SemanticDecoderBlock(config, in_channels, out_channels, stride)]
  162. self.conv_blocks = nn.ModuleList(conv_blocks)
  163. self.conv2 = nn.Conv1d(
  164. config.semantic_hidden_size,
  165. config.semantic_hidden_size,
  166. config.kernel_size,
  167. stride=1,
  168. padding=config.kernel_size // 2,
  169. bias=False,
  170. )
  171. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  172. hidden_state = self.conv1(hidden_state)
  173. for block in self.conv_blocks:
  174. hidden_state = block(hidden_state)
  175. hidden_state = self.conv2(hidden_state)
  176. return hidden_state
  177. class XcodecEuclideanCodebook(nn.Module):
  178. """Codebook with Euclidean distance."""
  179. def __init__(self, config):
  180. super().__init__()
  181. embed = torch.zeros(config.codebook_size, config.codebook_dim)
  182. self.codebook_size = config.codebook_size
  183. self.register_buffer("inited", torch.Tensor([True]))
  184. self.register_buffer("cluster_size", torch.zeros(config.codebook_size))
  185. self.register_buffer("embed", embed)
  186. self.register_buffer("embed_avg", embed.clone())
  187. # Copied from transformers.models.encodec.modeling_encodec.EncodecEuclideanCodebook.quantize
  188. def quantize(self, hidden_states):
  189. embed = self.embed.t()
  190. scaled_states = hidden_states.pow(2).sum(1, keepdim=True)
  191. dist = -(scaled_states - 2 * hidden_states @ embed + embed.pow(2).sum(0, keepdim=True))
  192. embed_ind = dist.max(dim=-1).indices
  193. return embed_ind
  194. def encode(self, hidden_states):
  195. shape = hidden_states.shape
  196. hidden_states = hidden_states.reshape((-1, shape[-1]))
  197. embed_ind = self.quantize(hidden_states)
  198. embed_ind = embed_ind.view(*shape[:-1])
  199. return embed_ind
  200. def decode(self, embed_ind):
  201. quantized = F.embedding(embed_ind, self.embed)
  202. return quantized
class XcodecVectorQuantization(nn.Module):
    """
    Vector quantization implementation. Currently supports only euclidean distance.

    Thin wrapper around a single [`XcodecEuclideanCodebook`]: `encode` swaps dimensions 1 and 2 of its 3-D input
    before asking the codebook for indices, and `decode` applies the same swap to the code vectors looked up from
    the codebook.
    """

    def __init__(self, config: XcodecConfig):
        super().__init__()
        # The only state of this module is its codebook.
        self.codebook = XcodecEuclideanCodebook(config)

    # Copied from transformers.models.encodec.modeling_encodec.EncodecVectorQuantization.encode
    def encode(self, hidden_states):
        hidden_states = hidden_states.permute(0, 2, 1)
        embed_in = self.codebook.encode(hidden_states)
        return embed_in

    # Copied from transformers.models.encodec.modeling_encodec.EncodecVectorQuantization.decode
    def decode(self, embed_ind):
        quantize = self.codebook.decode(embed_ind)
        quantize = quantize.permute(0, 2, 1)
        return quantize
  220. class XcodecResidualVectorQuantization(nn.Module):
  221. """
  222. Residual vector quantization implementation. Follows Algorithm 1 in https://huggingface.co/papers/2107.03312
  223. """
  224. def __init__(self, config: XcodecConfig):
  225. super().__init__()
  226. self.quantizers = nn.ModuleList([XcodecVectorQuantization(config) for _ in range(config.num_quantizers)])
  227. self.frame_rate = config.frame_rate
  228. self.codebook_size = config.codebook_size
  229. self.num_quantizers = config.num_quantizers
  230. def get_bandwidth_per_quantizer(self):
  231. """Return bandwidth per quantizer."""
  232. return math.log2(self.codebook_size) * self.frame_rate / 1000
  233. def get_num_quantizers_for_bandwidth(self, bandwidth=None) -> int:
  234. """Return num_quantizers based on specified target bandwidth."""
  235. bw_per_q = self.get_bandwidth_per_quantizer()
  236. num_quantizers = self.num_quantizers
  237. if bandwidth is not None and bandwidth > 0.0:
  238. num_quantizers = int(max(1, math.floor(bandwidth / bw_per_q)))
  239. return num_quantizers
  240. def encode(self, embeddings: torch.Tensor, bandwidth=None) -> torch.Tensor:
  241. """
  242. Encode the input tensor into discrete indices using RVQ, with the number of quantizers selected based on the given bandwidth.
  243. Each quantizer /codebook residually quantizes the input and returns the nearest indices in terms of Euclidian distance.
  244. """
  245. num_quantizers = self.get_num_quantizers_for_bandwidth(bandwidth)
  246. residual = embeddings
  247. all_indices = []
  248. for quantizer in self.quantizers[:num_quantizers]:
  249. indices = quantizer.encode(residual)
  250. quantized = quantizer.decode(indices)
  251. residual = residual - quantized
  252. all_indices.append(indices)
  253. out_indices = torch.stack(all_indices)
  254. return out_indices
  255. def decode(self, codes: torch.Tensor) -> torch.Tensor:
  256. """Decode the given codes to their quantized representation."""
  257. quantized_out = torch.tensor(0.0, device=codes.device)
  258. for i, indices in enumerate(codes):
  259. quantizer = self.quantizers[i]
  260. quantized = quantizer.decode(indices)
  261. quantized_out = quantized_out + quantized
  262. return quantized_out
  263. @auto_docstring
  264. class XcodecPreTrainedModel(PreTrainedAudioTokenizerBase):
  265. """
  266. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  267. models.
  268. """
  269. config_class = XcodecConfig
  270. base_model_prefix = "xcodec"
  271. main_input_name = "input_values"
  272. def _init_weights(self, module):
  273. """Initialize the weights"""
  274. if isinstance(module, nn.Linear):
  275. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  276. if module.bias is not None:
  277. module.bias.data.zero_()
  278. elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
  279. module.bias.data.zero_()
  280. module.weight.data.fill_(1.0)
  281. elif isinstance(module, nn.Conv1d):
  282. nn.init.kaiming_normal_(module.weight)
  283. if module.bias is not None:
  284. k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
  285. nn.init.uniform_(module.bias, a=-k, b=k)
  286. elif module.__class__.__name__ == "Snake1d":
  287. module.alpha.data.fill_(1.0)
  288. elif isinstance(module, nn.ConvTranspose1d):
  289. module.reset_parameters()
  290. elif isinstance(module, nn.Embedding):
  291. module.weight.data.normal_(mean=0.0, std=0.02)
  292. elif isinstance(module, XcodecModel):
  293. # The conv1d are not handled correctly, as `self.acoustic_encoder/decoder` are initialized from a PreTrainedModel,
  294. # but then only the submodules are used (which are not PreTrainedModels...) -> here we reinit them as in DacModel
  295. for submodule in module.acoustic_encoder.modules():
  296. if isinstance(submodule, nn.Conv1d):
  297. nn.init.trunc_normal_(submodule.weight, std=0.02)
  298. nn.init.constant_(submodule.bias, 0)
  299. for submodule in module.acoustic_decoder.modules():
  300. if isinstance(submodule, nn.Conv1d):
  301. nn.init.trunc_normal_(submodule.weight, std=0.02)
  302. nn.init.constant_(submodule.bias, 0)
  303. def apply_weight_norm(self):
  304. """Apply weight norm in the acoustic encoder and decoder because the original checkpoint has weight norm applied."""
  305. weight_norm = torch.nn.utils.weight_norm
  306. if hasattr(torch.nn.utils.parametrizations, "weight_norm"):
  307. weight_norm = torch.nn.utils.parametrizations.weight_norm
  308. weight_norm(self.acoustic_encoder.conv1)
  309. weight_norm(self.acoustic_encoder.conv2)
  310. for block in self.acoustic_encoder.block:
  311. weight_norm(block.conv1)
  312. for res_unit in (block.res_unit1, block.res_unit2, block.res_unit3):
  313. weight_norm(res_unit.conv1)
  314. weight_norm(res_unit.conv2)
  315. weight_norm(self.acoustic_decoder.conv1, name="weight")
  316. weight_norm(self.acoustic_decoder.conv2, name="weight")
  317. for block in self.acoustic_decoder.block:
  318. weight_norm(block.conv_t1, name="weight")
  319. for res_unit in (block.res_unit1, block.res_unit2, block.res_unit3):
  320. weight_norm(res_unit.conv1, name="weight")
  321. weight_norm(res_unit.conv2, name="weight")
  322. def remove_weight_norm(self):
  323. """Remove the weight norm from the acoustic encoder and decoder."""
  324. for module in (self.acoustic_encoder, self.acoustic_decoder):
  325. for m in module.modules():
  326. try:
  327. torch.nn.utils.remove_weight_norm(m, name="weight")
  328. except (ValueError, AttributeError):
  329. pass
  330. if hasattr(m, "parametrizations") and "weight" in m.parametrizations:
  331. torch.nn.utils.parametrize.remove_parametrizations(m, "weight", leave_parametrized=True)
  332. @auto_docstring(custom_intro="""The Xcodec neural audio codec model.""")
  333. class XcodecModel(XcodecPreTrainedModel):
  334. def __init__(self, config):
  335. super().__init__(config)
  336. self.config = config
  337. self.pad = config.hop_length // 2
  338. acoustic_model = AutoModel.from_config(config.acoustic_model_config)
  339. self.acoustic_encoder = acoustic_model.encoder
  340. self.acoustic_decoder = acoustic_model.decoder
  341. self._adjust_dac_decoder(self.acoustic_decoder)
  342. self.encoder_semantic = SemanticEncoder(config)
  343. self.decoder_semantic = SemanticDecoder(config)
  344. self.semantic_model = AutoModel.from_config(config.semantic_model_config).eval()
  345. self.fc = nn.Linear(config.hidden_size, config.hidden_size)
  346. self.fc1 = nn.Linear(config.hidden_size, config.semantic_model_config.hidden_size)
  347. self.fc2 = nn.Linear(config.hidden_size, config.acoustic_model_config.hidden_size)
  348. self.quantizer = XcodecResidualVectorQuantization(config)
  349. # Initialize weights and apply final processing
  350. self.post_init()
  351. @staticmethod
  352. def _adjust_dac_decoder(decoder: nn.Module):
  353. r"""
  354. DAC implemented in Xcodec is slightly different from the HF version.
  355. DAC in Xcodec adjusts the output padding in every ConvTranspose1d in the decoder and removes
  356. the final `nn.Tanh` activation function.
  357. """
  358. for module in decoder.modules():
  359. if isinstance(module, nn.ConvTranspose1d):
  360. stride = module.stride[0] if isinstance(module.stride, tuple) else module.stride
  361. module.output_padding = (stride % 2,)
  362. if hasattr(decoder, "tanh") and isinstance(decoder.tanh, nn.Tanh):
  363. decoder.tanh = nn.Identity()
  364. def _extract_semantic_features(self, input_values: torch.FloatTensor) -> torch.FloatTensor:
  365. input_values = input_values[:, 0, :]
  366. input_values = F.pad(input_values, (self.pad, self.pad))
  367. with torch.no_grad():
  368. outputs = self.semantic_model(input_values, output_hidden_states=True)
  369. hidden_states = outputs.hidden_states
  370. stacked = torch.stack(hidden_states, dim=1)
  371. return stacked.mean(dim=1)
  372. @auto_docstring
  373. def encode(
  374. self,
  375. input_values: torch.Tensor,
  376. bandwidth: Optional[float] = None,
  377. return_dict: Optional[bool] = None,
  378. ) -> Union[torch.Tensor, XcodecEncoderOutput]:
  379. r"""
  380. input_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`):
  381. Float values of the input audio waveform.
  382. bandwidth (`float`, *optional*):
  383. The target bandwidth in (kbps) supports only values in `config.target_bandwidths`.
  384. Defaults to the highest available bandwidth `4.0` kbps.
  385. return_dict (`bool`, *optional*):
  386. Whether or not to return a [`~utils.ModelOutput`].
  387. Returns:
  388. `torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)` containing the discrete encoded audio codes.
  389. """
  390. return_dict = return_dict if return_dict is not None else self.config.return_dict
  391. channels = input_values.shape[1]
  392. if channels != 1:
  393. raise ValueError(f"Audio must be mono, but got {channels}")
  394. if bandwidth is None:
  395. bandwidth = self.config.target_bandwidths[-1]
  396. elif bandwidth not in self.config.target_bandwidths:
  397. raise ValueError(
  398. f"This model doesn't support the bandwidth {bandwidth}. Select one of {self.config.target_bandwidths}."
  399. )
  400. e_semantic_input = self._extract_semantic_features(input_values).detach()
  401. e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
  402. e_acoustic = self.acoustic_encoder(input_values)
  403. if e_acoustic.shape[2] != e_semantic.shape[2]:
  404. # make sure they line up if frames don't match
  405. e_acoustic = self.acoustic_encoder(F.pad(input_values[:, 0, :], (self.pad, self.pad)).unsqueeze(1))
  406. embeddings = torch.cat([e_acoustic, e_semantic], dim=1)
  407. embeddings = self.fc(embeddings.transpose(1, 2)).transpose(1, 2)
  408. audio_codes = self.quantizer.encode(embeddings, bandwidth)
  409. audio_codes = audio_codes.transpose(0, 1)
  410. if not return_dict:
  411. return audio_codes
  412. return XcodecEncoderOutput(audio_codes)
  413. @auto_docstring
  414. def decode(
  415. self,
  416. audio_codes: torch.Tensor,
  417. return_dict: Optional[bool] = None,
  418. ) -> Union[torch.Tensor, XcodecDecoderOutput]:
  419. r"""
  420. audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`):
  421. Discrete code indices computed using `model.encode`.
  422. return_dict (`bool`, *optional*):
  423. Whether or not to return a [`~utils.ModelOutput`]
  424. Returns:
  425. Decoded audio values of shape `(batch_size, channels, num_samples)` obtained using the decoder part of
  426. Xcodec.
  427. """
  428. return_dict = return_dict if return_dict is not None else self.config.return_dict
  429. audio_codes = audio_codes.transpose(0, 1)
  430. quantized = self.quantizer.decode(audio_codes)
  431. quantized_acoustic = self.fc2(quantized.transpose(1, 2)).transpose(1, 2)
  432. audio_values = self.acoustic_decoder(quantized_acoustic)
  433. if not return_dict:
  434. return audio_values
  435. return XcodecDecoderOutput(audio_values)
  436. @auto_docstring
  437. def forward(
  438. self,
  439. input_values: torch.Tensor,
  440. audio_codes: Optional[torch.Tensor] = None,
  441. bandwidth: Optional[float] = None,
  442. return_dict: Optional[bool] = None,
  443. ) -> Union[tuple[torch.Tensor, torch.Tensor], XcodecOutput]:
  444. r"""
  445. input_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`):
  446. The raw float values of the input audio waveform.
  447. audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`:
  448. Discrete code indices computed using `model.encode`.
  449. bandwidth (`float`, *optional*):
  450. Target bandwidth in kbps. Must be one of `config.target_bandwidths`. Defaults to the highest available bandwidth.
  451. bandwidth (`float`, *optional*):
  452. Target bandwidth in kbps. Must be one of `config.target_bandwidths`. Defaults to the highest available bandwidth.
  453. return_dict (`bool`, *optional*):
  454. Whether to return a [`XcodecOutput`] instead of a plain tuple.
  455. Returns:
  456. `XcodecOutput` or tuple `(audio_codes, audio_values)`:
  457. - `audio_codes` of shape `(batch_size, num_quantizers, codes_length)`: the quantized discrete codes.
  458. - `audio_values` of shape `(batch_size, channels, num_samples)`: the reconstructed audio waveform given the codes.
  459. Example:
  460. ```python
  461. >>> from datasets import load_dataset
  462. >>> from transformers import AutoFeatureExtractor, XcodecModel
  463. >>> model_id = "hf-audio/xcodec-hubert-librispeech"
  464. >>> model = XcodecModel.from_pretrained(model_id)
  465. >>> feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
  466. >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  467. >>> dataset = dataset.cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate))
  468. >>> audio_sample = dataset[0]['audio']['array']
  469. >>> inputs = feature_extractor(raw_audio=audio_sample, return_tensors="pt")
  470. >>> outputs = model(**inputs)
  471. >>> audio_codes = outputs.audio_codes
  472. >>> audio_values = outputs.audio_values
  473. ```
  474. """
  475. return_dict = return_dict if return_dict is not None else self.config.return_dict
  476. length = input_values.shape[-1]
  477. if audio_codes is None:
  478. audio_codes = self.encode(input_values, bandwidth, return_dict=False)
  479. audio_values = self.decode(audio_codes, return_dict=return_dict)[0][..., :length]
  480. if not return_dict:
  481. return (audio_codes, audio_values)
  482. return XcodecOutput(audio_codes=audio_codes, audio_values=audio_values)
# Names exported by `from <this module> import *`.
__all__ = ["XcodecModel", "XcodecPreTrainedModel"]